├── .gitignore ├── LICENSE ├── README.md ├── agents ├── __init__.py └── td3.py ├── approximators ├── __init__.py ├── policy.py └── rl_solution.py ├── cfgs ├── agent │ └── td3.yaml ├── approximator │ ├── hyperzero.yaml │ ├── hyperzero_without_q.yaml │ ├── hyperzero_without_td.yaml │ ├── meta_policy.yaml │ ├── meta_rl.yaml │ ├── meta_rl_td.yaml │ ├── mlp_policy.yaml │ ├── mlp_rl.yaml │ ├── mlp_rl_td.yaml │ └── pearl_policy.yaml ├── config.yaml ├── config_rl_approximator.yaml ├── dynamics │ ├── cartpole.yaml │ ├── cheetah.yaml │ ├── default.yaml │ ├── finger.yaml │ └── walker.yaml ├── obs │ └── states.yaml ├── reward │ ├── cartpole_default.yaml │ ├── cheetah_default.yaml │ ├── finger_default.yaml │ ├── overwrite_all.yaml │ └── walker_default.yaml └── task │ ├── cheetah_run.yaml │ ├── easy.yaml │ ├── finger_spin.yaml │ ├── medium.yaml │ └── walker_walk.yaml ├── eval.py ├── eval_many_agents.py ├── eval_many_approximators.py ├── models ├── __init__.py ├── core.py ├── hypenet_core.py └── rl_regressor.py ├── requirements.txt ├── train.py ├── train_rl_regressor.py └── utils ├── __init__.py ├── dataloader.py ├── dataset.py ├── dmc.py ├── logger.py ├── plots.py ├── replay_buffer.py ├── utils.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | runs 4 | .idea/ 5 | .DS_Store 6 | plots/ 7 | eval_plots/ 8 | images/ 9 | exp/ 10 | exp_multirun/ 11 | results/ 12 | results_multirun/ 13 | results_approximator/ 14 | results_approximator_multirun/ 15 | rollout_data*/ 16 | rollout_data_comparison*/ 17 | video_logs*/ 18 | video_demos*/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) SAMSUNG ELECTRONICS CANADA INC. (“SECA”). 2 | 3 | Attribution-NonCommercial 4.0 International 4 | 5 | ======================================================================= 6 | 7 | Creative Commons Corporation ("Creative Commons") is not a law firm and 8 | does not provide legal services or legal advice. Distribution of 9 | Creative Commons public licenses does not create a lawyer-client or 10 | other relationship. Creative Commons makes its licenses and related 11 | information available on an "as-is" basis. Creative Commons gives no 12 | warranties regarding its licenses, any material licensed under their 13 | terms and conditions, or any related information. Creative Commons 14 | disclaims all liability for damages resulting from their use to the 15 | fullest extent possible. 16 | 17 | Using Creative Commons Public Licenses 18 | 19 | Creative Commons public licenses provide a standard set of terms and 20 | conditions that creators and other rights holders may use to share 21 | original works of authorship and other material subject to copyright 22 | and certain other rights specified in the public license below. The 23 | following considerations are for informational purposes only, are not 24 | exhaustive, and do not form part of our licenses. 25 | 26 | Considerations for licensors: Our public licenses are 27 | intended for use by those authorized to give the public 28 | permission to use material in ways otherwise restricted by 29 | copyright and certain other rights. Our licenses are 30 | irrevocable. Licensors should read and understand the terms 31 | and conditions of the license they choose before applying it. 
32 | Licensors should also secure all rights necessary before 33 | applying our licenses so that the public can reuse the 34 | material as expected. Licensors should clearly mark any 35 | material not subject to the license. This includes other CC- 36 | licensed material, or material used under an exception or 37 | limitation to copyright. More considerations for licensors: 38 | wiki.creativecommons.org/Considerations_for_licensors 39 | 40 | Considerations for the public: By using one of our public 41 | licenses, a licensor grants the public permission to use the 42 | licensed material under specified terms and conditions. If 43 | the licensor's permission is not necessary for any reason--for 44 | example, because of any applicable exception or limitation to 45 | copyright--then that use is not regulated by the license. Our 46 | licenses grant only permissions under copyright and certain 47 | other rights that a licensor has authority to grant. Use of 48 | the licensed material may still be restricted for other 49 | reasons, including because others have copyright or other 50 | rights in the material. A licensor may make special requests, 51 | such as asking that all changes be marked or described. 52 | Although not required by our licenses, you are encouraged to 53 | respect those requests where reasonable. More considerations 54 | for the public: 55 | wiki.creativecommons.org/Considerations_for_licensees 56 | 57 | ======================================================================= 58 | 59 | Creative Commons Attribution-NonCommercial 4.0 International Public 60 | License 61 | 62 | By exercising the Licensed Rights (defined below), You accept and agree 63 | to be bound by the terms and conditions of this Creative Commons 64 | Attribution-NonCommercial 4.0 International Public License ("Public 65 | License"). To the extent this Public License may be interpreted as a 66 | contract, You are granted the Licensed Rights in consideration of Your 67 | acceptance of these terms and conditions, and the Licensor grants You 68 | such rights in consideration of benefits the Licensor receives from 69 | making the Licensed Material available under these terms and 70 | conditions. 71 | 72 | 73 | Section 1 -- Definitions. 74 | 75 | a. Adapted Material means material subject to Copyright and Similar 76 | Rights that is derived from or based upon the Licensed Material 77 | and in which the Licensed Material is translated, altered, 78 | arranged, transformed, or otherwise modified in a manner requiring 79 | permission under the Copyright and Similar Rights held by the 80 | Licensor. For purposes of this Public License, where the Licensed 81 | Material is a musical work, performance, or sound recording, 82 | Adapted Material is always produced where the Licensed Material is 83 | synched in timed relation with a moving image. 84 | 85 | b. Adapter's License means the license You apply to Your Copyright 86 | and Similar Rights in Your contributions to Adapted Material in 87 | accordance with the terms and conditions of this Public License. 88 | 89 | c. Copyright and Similar Rights means copyright and/or similar rights 90 | closely related to copyright including, without limitation, 91 | performance, broadcast, sound recording, and Sui Generis Database 92 | Rights, without regard to how the rights are labeled or 93 | categorized. For purposes of this Public License, the rights 94 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 95 | Rights. 96 | d. 
Effective Technological Measures means those measures that, in the 97 | absence of proper authority, may not be circumvented under laws 98 | fulfilling obligations under Article 11 of the WIPO Copyright 99 | Treaty adopted on December 20, 1996, and/or similar international 100 | agreements. 101 | 102 | e. Exceptions and Limitations means fair use, fair dealing, and/or 103 | any other exception or limitation to Copyright and Similar Rights 104 | that applies to Your use of the Licensed Material. 105 | 106 | f. Licensed Material means the artistic or literary work, database, 107 | or other material to which the Licensor applied this Public 108 | License. 109 | 110 | g. Licensed Rights means the rights granted to You subject to the 111 | terms and conditions of this Public License, which are limited to 112 | all Copyright and Similar Rights that apply to Your use of the 113 | Licensed Material and that the Licensor has authority to license. 114 | 115 | h. Licensor means the individual(s) or entity(ies) granting rights 116 | under this Public License. 117 | 118 | i. NonCommercial means not primarily intended for or directed towards 119 | commercial advantage or monetary compensation. For purposes of 120 | this Public License, the exchange of the Licensed Material for 121 | other material subject to Copyright and Similar Rights by digital 122 | file-sharing or similar means is NonCommercial provided there is 123 | no payment of monetary compensation in connection with the 124 | exchange. 125 | 126 | j. Share means to provide material to the public by any means or 127 | process that requires permission under the Licensed Rights, such 128 | as reproduction, public display, public performance, distribution, 129 | dissemination, communication, or importation, and to make material 130 | available to the public including in ways that members of the 131 | public may access the material from a place and at a time 132 | individually chosen by them. 133 | 134 | k. Sui Generis Database Rights means rights other than copyright 135 | resulting from Directive 96/9/EC of the European Parliament and of 136 | the Council of 11 March 1996 on the legal protection of databases, 137 | as amended and/or succeeded, as well as other essentially 138 | equivalent rights anywhere in the world. 139 | 140 | l. You means the individual or entity exercising the Licensed Rights 141 | under this Public License. Your has a corresponding meaning. 142 | 143 | 144 | Section 2 -- Scope. 145 | 146 | a. License grant. 147 | 148 | 1. Subject to the terms and conditions of this Public License, 149 | the Licensor hereby grants You a worldwide, royalty-free, 150 | non-sublicensable, non-exclusive, irrevocable license to 151 | exercise the Licensed Rights in the Licensed Material to: 152 | 153 | a. reproduce and Share the Licensed Material, in whole or 154 | in part, for NonCommercial purposes only; and 155 | 156 | b. produce, reproduce, and Share Adapted Material for 157 | NonCommercial purposes only. 158 | 159 | 2. Exceptions and Limitations. For the avoidance of doubt, where 160 | Exceptions and Limitations apply to Your use, this Public 161 | License does not apply, and You do not need to comply with 162 | its terms and conditions. 163 | 164 | 3. Term. The term of this Public License is specified in Section 165 | 6(a). 166 | 167 | 4. Media and formats; technical modifications allowed. 
The 168 | Licensor authorizes You to exercise the Licensed Rights in 169 | all media and formats whether now known or hereafter created, 170 | and to make technical modifications necessary to do so. The 171 | Licensor waives and/or agrees not to assert any right or 172 | authority to forbid You from making technical modifications 173 | necessary to exercise the Licensed Rights, including 174 | technical modifications necessary to circumvent Effective 175 | Technological Measures. For purposes of this Public License, 176 | simply making modifications authorized by this Section 2(a) 177 | (4) never produces Adapted Material. 178 | 179 | 5. Downstream recipients. 180 | 181 | a. Offer from the Licensor -- Licensed Material. Every 182 | recipient of the Licensed Material automatically 183 | receives an offer from the Licensor to exercise the 184 | Licensed Rights under the terms and conditions of this 185 | Public License. 186 | 187 | b. No downstream restrictions. You may not offer or impose 188 | any additional or different terms or conditions on, or 189 | apply any Effective Technological Measures to, the 190 | Licensed Material if doing so restricts exercise of the 191 | Licensed Rights by any recipient of the Licensed 192 | Material. 193 | 194 | 6. No endorsement. Nothing in this Public License constitutes or 195 | may be construed as permission to assert or imply that You 196 | are, or that Your use of the Licensed Material is, connected 197 | with, or sponsored, endorsed, or granted official status by, 198 | the Licensor or others designated to receive attribution as 199 | provided in Section 3(a)(1)(A)(i). 200 | 201 | b. Other rights. 202 | 203 | 1. Moral rights, such as the right of integrity, are not 204 | licensed under this Public License, nor are publicity, 205 | privacy, and/or other similar personality rights; however, to 206 | the extent possible, the Licensor waives and/or agrees not to 207 | assert any such rights held by the Licensor to the limited 208 | extent necessary to allow You to exercise the Licensed 209 | Rights, but not otherwise. 210 | 211 | 2. Patent and trademark rights are not licensed under this 212 | Public License. 213 | 214 | 3. To the extent possible, the Licensor waives any right to 215 | collect royalties from You for the exercise of the Licensed 216 | Rights, whether directly or through a collecting society 217 | under any voluntary or waivable statutory or compulsory 218 | licensing scheme. In all other cases the Licensor expressly 219 | reserves any right to collect such royalties, including when 220 | the Licensed Material is used other than for NonCommercial 221 | purposes. 222 | 223 | 224 | Section 3 -- License Conditions. 225 | 226 | Your exercise of the Licensed Rights is expressly made subject to the 227 | following conditions. 228 | 229 | a. Attribution. 230 | 231 | 1. If You Share the Licensed Material (including in modified 232 | form), You must: 233 | 234 | a. retain the following if it is supplied by the Licensor 235 | with the Licensed Material: 236 | 237 | i. identification of the creator(s) of the Licensed 238 | Material and any others designated to receive 239 | attribution, in any reasonable manner requested by 240 | the Licensor (including by pseudonym if 241 | designated); 242 | 243 | ii. a copyright notice; 244 | 245 | iii. a notice that refers to this Public License; 246 | 247 | iv. a notice that refers to the disclaimer of 248 | warranties; 249 | 250 | v. 
a URI or hyperlink to the Licensed Material to the 251 | extent reasonably practicable; 252 | 253 | b. indicate if You modified the Licensed Material and 254 | retain an indication of any previous modifications; and 255 | 256 | c. indicate the Licensed Material is licensed under this 257 | Public License, and include the text of, or the URI or 258 | hyperlink to, this Public License. 259 | 260 | 2. You may satisfy the conditions in Section 3(a)(1) in any 261 | reasonable manner based on the medium, means, and context in 262 | which You Share the Licensed Material. For example, it may be 263 | reasonable to satisfy the conditions by providing a URI or 264 | hyperlink to a resource that includes the required 265 | information. 266 | 267 | 3. If requested by the Licensor, You must remove any of the 268 | information required by Section 3(a)(1)(A) to the extent 269 | reasonably practicable. 270 | 271 | 4. If You Share Adapted Material You produce, the Adapter's 272 | License You apply must not prevent recipients of the Adapted 273 | Material from complying with this Public License. 274 | 275 | 276 | Section 4 -- Sui Generis Database Rights. 277 | 278 | Where the Licensed Rights include Sui Generis Database Rights that 279 | apply to Your use of the Licensed Material: 280 | 281 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 282 | to extract, reuse, reproduce, and Share all or a substantial 283 | portion of the contents of the database for NonCommercial purposes 284 | only; 285 | 286 | b. if You include all or a substantial portion of the database 287 | contents in a database in which You have Sui Generis Database 288 | Rights, then the database in which You have Sui Generis Database 289 | Rights (but not its individual contents) is Adapted Material; and 290 | 291 | c. You must comply with the conditions in Section 3(a) if You Share 292 | all or a substantial portion of the contents of the database. 293 | 294 | For the avoidance of doubt, this Section 4 supplements and does not 295 | replace Your obligations under this Public License where the Licensed 296 | Rights include other Copyright and Similar Rights. 297 | 298 | 299 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 300 | 301 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 302 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 303 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 304 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 305 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 306 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 307 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 308 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 309 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 310 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 311 | 312 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 313 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 314 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 315 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 316 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 317 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 318 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 319 | DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 320 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 321 | 322 | c. The disclaimer of warranties and limitation of liability provided 323 | above shall be interpreted in a manner that, to the extent 324 | possible, most closely approximates an absolute disclaimer and 325 | waiver of all liability. 326 | 327 | 328 | Section 6 -- Term and Termination. 329 | 330 | a. This Public License applies for the term of the Copyright and 331 | Similar Rights licensed here. However, if You fail to comply with 332 | this Public License, then Your rights under this Public License 333 | terminate automatically. 334 | 335 | b. Where Your right to use the Licensed Material has terminated under 336 | Section 6(a), it reinstates: 337 | 338 | 1. automatically as of the date the violation is cured, provided 339 | it is cured within 30 days of Your discovery of the 340 | violation; or 341 | 342 | 2. upon express reinstatement by the Licensor. 343 | 344 | For the avoidance of doubt, this Section 6(b) does not affect any 345 | right the Licensor may have to seek remedies for Your violations 346 | of this Public License. 347 | 348 | c. For the avoidance of doubt, the Licensor may also offer the 349 | Licensed Material under separate terms or conditions or stop 350 | distributing the Licensed Material at any time; however, doing so 351 | will not terminate this Public License. 352 | 353 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 354 | License. 355 | 356 | 357 | Section 7 -- Other Terms and Conditions. 358 | 359 | a. The Licensor shall not be bound by any additional or different 360 | terms or conditions communicated by You unless expressly agreed. 361 | 362 | b. Any arrangements, understandings, or agreements regarding the 363 | Licensed Material not stated herein are separate from and 364 | independent of the terms and conditions of this Public License. 365 | 366 | 367 | Section 8 -- Interpretation. 368 | 369 | a. For the avoidance of doubt, this Public License does not, and 370 | shall not be interpreted to, reduce, limit, restrict, or impose 371 | conditions on any use of the Licensed Material that could lawfully 372 | be made without permission under this Public License. 373 | 374 | b. To the extent possible, if any provision of this Public License is 375 | deemed unenforceable, it shall be automatically reformed to the 376 | minimum extent necessary to make it enforceable. If the provision 377 | cannot be reformed, it shall be severed from this Public License 378 | without affecting the enforceability of the remaining terms and 379 | conditions. 380 | 381 | c. No term or condition of this Public License will be waived and no 382 | failure to comply consented to unless expressly agreed to by the 383 | Licensor. 384 | 385 | d. Nothing in this Public License constitutes or may be interpreted 386 | as a limitation upon, or waiver of, any privileges and immunities 387 | that apply to the Licensor or You, including from the legal 388 | processes of any jurisdiction or authority. 389 | 390 | ======================================================================= 391 | 392 | Creative Commons is not a party to its public 393 | licenses. Notwithstanding, Creative Commons may elect to apply one of 394 | its public licenses to material it publishes and in those instances 395 | will be considered the “Licensor.” The text of the Creative Commons 396 | public licenses is dedicated to the public domain under the CC0 Public 397 | Domain Dedication. 
Except for the limited purpose of indicating that 398 | material is shared under a Creative Commons public license or as 399 | otherwise permitted by the Creative Commons policies published at 400 | creativecommons.org/policies, Creative Commons does not authorize the 401 | use of the trademark "Creative Commons" or any other trademark or logo 402 | of Creative Commons without its prior written consent including, 403 | without limitation, in connection with any unauthorized modifications 404 | to any of its public licenses or any other arrangements, 405 | understandings, or agreements concerning use of licensed material. For 406 | the avoidance of doubt, this paragraph does not form part of the 407 | public licenses. 408 | 409 | Creative Commons may be contacted at creativecommons.org. 410 | 411 | 412 | ======================================================================= 413 | Below is the license for DrQ-v2 414 | https://github.com/facebookresearch/drqv2 415 | 416 | MIT License 417 | 418 | Copyright (c) Facebook, Inc. and its affiliates. 419 | 420 | Permission is hereby granted, free of charge, to any person obtaining a copy 421 | of this software and associated documentation files (the "Software"), to deal 422 | in the Software without restriction, including without limitation the rights 423 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 424 | copies of the Software, and to permit persons to whom the Software is 425 | furnished to do so, subject to the following conditions: 426 | 427 | The above copyright notice and this permission notice shall be included in all 428 | copies or substantial portions of the Software. 429 | 430 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 431 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 432 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 433 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 434 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 435 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 436 | SOFTWARE. 437 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hypernetworks for Zero-shot Transfer in Reinforcement Learning 2 | Author's PyTorch implementation of HyperZero. If you use our code, please cite our [AAAI 2023 paper](https://arxiv.org/abs/2211.15457): 3 | 4 | ```bib 5 | @article{rezaei2022hypernetworks, 6 | title={Hypernetworks for Zero-shot Transfer in Reinforcement Learning}, 7 | author={Rezaei-Shoshtari, Sahand and Morissette, Charlotte and Hogan, Francois Robert and Dudek, Gregory and Meger, David}, 8 | journal={arXiv preprint arXiv:2211.15457}, 9 | year={2022} 10 | } 11 | ``` 12 | 13 | ## Setup 14 | * We recommend using a conda virtual environment to run the code. 15 | Create the virtual environment: 16 | ```commandline 17 | conda create -n hyperzero_env python=3.9 18 | conda activate hyperzero_env 19 | pip install --upgrade pip 20 | ``` 21 | * This package requires [Contextual Control Suite](https://github.com/SAIC-MONTREAL/contextual-control-suite) 22 | to run. First, install that package following its instructions.
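  A typical way to do this is to clone that repository and pip-install it locally; this is only a sketch of the usual workflow, so defer to the Contextual Control Suite README if its instructions differ:
```commandline
git clone https://github.com/SAIC-MONTREAL/contextual-control-suite.git
pip install -e contextual-control-suite
```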
23 | * Clone this package and install the rest of its requirements: 24 | ```commandline 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Instructions 29 | The training of HyperZero is done in two steps: 30 | 1. Obtaining the near-optimal rollout dataset 31 | 1. [First option](https://github.com/SAIC-MONTREAL/hyperzero#option-a-training-rl-agents) is to train the RL agents yourself and then collect the rollouts. 32 | 2. [Second option](https://github.com/SAIC-MONTREAL/hyperzero#option-b-using-the-published-dataset) is to use our published dataset. 33 | 2. Training HyperZero on the dataset. 34 | 35 | ### Step 1: Obtaining the Near-optimal Rollout Dataset 36 | 37 | #### Option A: Training RL Agents 38 | * To train standard RL on a [Contextual Control Suite](https://github.com/SAIC-MONTREAL/contextual-control-suite) environment with default reward and dynamics parameters, use: 39 | ```commandline 40 | python train.py agent@_global_=td3 task@_global_=cheetah_run reward@_global_=cheetah_default dynamics@_global_=default 41 | ``` 42 | * We use [Hydra](https://github.com/facebookresearch/hydra) to specify configs. To sweep over context parameters, 43 | use the `--multirun` argument. For example, the following command sweeps over reward margins of `range(0.5,10.1,0.5)` 44 | with a linear reward function and default dynamics parameters: 45 | ```commandline 46 | python train.py --multirun agent@_global_=td3 task@_global_=cheetah_run reward@_global_=overwrite_all reward_parameters.ALL.margin='range(0.5,10.1,.5)' reward_parameters.ALL.sigmoid='linear' dynamics@_global_=default 47 | ``` 48 | * As another example, the following command sweeps over a grid of reward margins of `range(1,5.1,1)` and dynamics 49 | parameters of `range(0.3,0.71,0.05)`: 50 | ```commandline 51 | python train.py --multirun agent@_global_=td3 task@_global_=cheetah_run reward@_global_=overwrite_all reward_parameters.ALL.margin='range(1,5.1,1)' reward_parameters.ALL.sigmoid='linear' dynamics@_global_=cheetah dynamics_parameters.length='range(0.3,0.71,0.05)' 52 | ``` 53 | * **Note:** Be mindful! These commands launch a lot of training scripts! 54 | 55 | * To evaluate the RL agents and generate the dataset used for training HyperZero, you can use [eval.py](eval.py). 56 | A helper script is set up to load each trained RL agent and generate a set of `.npy` files to be later loaded 57 | by [RLSolutionDataset](utils/dataset.py): 58 | ```commandline 59 | python eval_many_agents.py --rootdir <rootdir> --domain_task cheetah_run 60 | ``` 61 | 62 | #### Option B: Using the Published Dataset 63 | * Instead of training RL agents, you can download our published dataset from 64 | [here](https://mcgill-my.sharepoint.com/:f:/g/personal/sahand_rezaei-shoshtari_mail_mcgill_ca/EhDgTXh3v-pIhTHZXM1xaz0BMWT-N8jNheVm2156mhbZdA?e=hMu4N1). 65 | This dataset was used to generate some of the results in our [AAAI 2023 paper](https://arxiv.org/abs/2211.15457). 66 | * Simply extract the dataset to the desired location and proceed to Step 2. 67 | 68 | ### Step 2: Training HyperZero 69 | * Finally, to train HyperZero (or the baselines), use the following command. It trains and saves the RL regressor: 70 | ```commandline 71 | python train_rl_regressor.py rollout_dir=<rollout_dir> domain_task=cheetah_run approximator@_global_=hyperzero input_to_model=rew 72 | ``` 73 | * The argument `input_to_model` specifies the MDP context that is used to generate the policies. It can take `rew`, 74 | `dyn`, or `rew_dyn`.
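  For example, to condition the regressor on both the reward and dynamics context, only that argument changes (the rollout directory below is a placeholder):
```commandline
python train_rl_regressor.py rollout_dir=<rollout_dir> domain_task=cheetah_run approximator@_global_=hyperzero input_to_model=rew_dyn
```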
75 | * Or, to train several RL regressors at once, use: 76 | ```commandline 77 | python train_rl_regressor.py --multirun rollout_dir=<rollout_dir> domain_task=cheetah_run input_to_model=rew 78 | ``` 79 | * To visualize the RL regressor and roll out the policy, you can use [eval.py](eval.py). A helper script is set up that 80 | loads several RL regressors and evaluates them: 81 | ```commandline 82 | python eval_many_approximators.py --rootdir <rootdir> --approximator_rootdir <approximator_rootdir> --rollout_dir <rollout_dir> --domain_task cheetah_run 83 | ``` 84 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAIC-MONTREAL/hyperzero/ab0508a73c09940d8c98267af8ae021d834915d0/agents/__init__.py -------------------------------------------------------------------------------- /agents/td3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3) 3 | https://arxiv.org/abs/1802.09477 4 | """ 5 | 6 | import hydra 7 | import copy 8 | import numpy as np 9 | from pathlib import Path 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from models.core import DeterministicActor, Critic 15 | import utils.utils as utils 16 | 17 | 18 | class TD3Agent: 19 | def __init__(self, obs_shape, action_shape, device, lr, hidden_dim, 20 | critic_target_tau, num_expl_steps, update_every_steps, 21 | stddev_schedule, stddev_clip): 22 | 23 | self.device = device 24 | self.critic_target_tau = critic_target_tau 25 | self.update_every_steps = update_every_steps 26 | self.num_expl_steps = num_expl_steps 27 | self.stddev_schedule = stddev_schedule 28 | self.stddev_clip = stddev_clip 29 | self.action_dim = action_shape[0] 30 | self.hidden_dim = hidden_dim 31 | self.lr = lr 32 | 33 | # models 34 | self.actor = DeterministicActor(obs_shape[0], action_shape[0], hidden_dim).to(self.device) 35 | self.actor_target = copy.deepcopy(self.actor) 36 | 37 | self.critic = Critic(obs_shape[0], action_shape[0], hidden_dim).to(self.device) 38 | self.critic_target = copy.deepcopy(self.critic) 39 | 40 | # optimizers 41 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr) 42 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr) 43 | 44 | self.train() 45 | self.actor_target.train() 46 | self.critic_target.train() 47 | 48 | def train(self, training=True): 49 | self.training = training 50 | self.actor.train(training) 51 | self.critic.train(training) 52 | 53 | def act(self, obs, step, eval_mode): 54 | obs = torch.as_tensor(obs, device=self.device) 55 | stddev = utils.schedule(self.stddev_schedule, step) 56 | action = self.actor(obs.float().unsqueeze(0)) 57 | if eval_mode: 58 | action = action.cpu().numpy()[0] 59 | else: 60 | action = action.cpu().numpy()[0] + np.random.normal(0, stddev, size=self.action_dim) 61 | if step < self.num_expl_steps: 62 | action = np.random.uniform(-1.0, 1.0, size=self.action_dim) 63 | return action.astype(np.float32) 64 | 65 | def observe(self, obs, action): 66 | obs = torch.as_tensor(obs, device=self.device).float().unsqueeze(0) 67 | action = torch.as_tensor(action, device=self.device).float().unsqueeze(0) 68 | 69 | q, _ = self.critic(obs, action) 70 | 71 | return { 72 | 'state': obs.cpu().numpy()[0], 73 | 'value': q.cpu().numpy()[0] 74 | } 75 | 76 | def update_critic(self, obs, action, reward, discount, next_obs, step): 77 |
metrics = dict() 78 | 79 | with torch.no_grad(): 80 | # Select action according to policy and add clipped noise 81 | stddev = utils.schedule(self.stddev_schedule, step) 82 | noise = (torch.randn_like(action) * stddev).clamp(-self.stddev_clip, self.stddev_clip) 83 | 84 | next_action = (self.actor_target(next_obs) + noise).clamp(-1.0, 1.0) 85 | 86 | # Compute the target Q value 87 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 88 | target_Q = torch.min(target_Q1, target_Q2) 89 | target_Q = reward + discount * target_Q 90 | 91 | # Get current Q estimates 92 | current_Q1, current_Q2 = self.critic(obs, action) 93 | 94 | # Compute critic loss 95 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 96 | 97 | metrics['critic_target_q'] = target_Q.mean().item() 98 | metrics['critic_q1'] = current_Q1.mean().item() 99 | metrics['critic_q2'] = current_Q2.mean().item() 100 | metrics['critic_loss'] = critic_loss.item() 101 | 102 | # Optimize the critic 103 | self.critic_optimizer.zero_grad(set_to_none=True) 104 | critic_loss.backward() 105 | self.critic_optimizer.step() 106 | 107 | return metrics 108 | 109 | def update_actor(self, obs, step): 110 | metrics = dict() 111 | 112 | # Compute actor loss 113 | actor_loss = -self.critic.Q1(obs, self.actor(obs)).mean() 114 | 115 | # Optimize the actor 116 | self.actor_optimizer.zero_grad(set_to_none=True) 117 | actor_loss.backward() 118 | self.actor_optimizer.step() 119 | 120 | metrics['actor_loss'] = actor_loss.item() 121 | 122 | return metrics 123 | 124 | def update(self, replay_iter, step): 125 | metrics = dict() 126 | 127 | batch = next(replay_iter) 128 | obs, action, reward, discount, next_obs, _ = utils.to_torch( 129 | batch, self.device) 130 | 131 | obs = obs.float() 132 | next_obs = next_obs.float() 133 | 134 | metrics['batch_reward'] = reward.mean().item() 135 | 136 | # update critic 137 | metrics.update(self.update_critic(obs, action, reward, discount, next_obs, step)) 138 | 139 | # update actor (delayed) 140 | if step % self.update_every_steps == 0: 141 | metrics.update(self.update_actor(obs.detach(), step)) 142 | 143 | # update target networks 144 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau) 145 | utils.soft_update_params(self.actor, self.actor_target, self.critic_target_tau) 146 | 147 | return metrics 148 | 149 | def save(self, model_dir, step): 150 | model_save_dir = Path(f'{model_dir}/step_{str(step).zfill(8)}') 151 | model_save_dir.mkdir(exist_ok=True, parents=True) 152 | 153 | torch.save(self.actor.state_dict(), f'{model_save_dir}/actor.pt') 154 | torch.save(self.critic.state_dict(), f'{model_save_dir}/critic.pt') 155 | 156 | def load(self, model_dir, step): 157 | print(f"Loading the model from {model_dir}, step: {step}") 158 | model_load_dir = Path(f'{model_dir}/step_{str(step).zfill(8)}') 159 | 160 | self.actor.load_state_dict( 161 | torch.load(f'{model_load_dir}/actor.pt', map_location=self.device) 162 | ) 163 | self.critic.load_state_dict( 164 | torch.load(f'{model_load_dir}/critic.pt', map_location=self.device) 165 | ) 166 | -------------------------------------------------------------------------------- /approximators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAIC-MONTREAL/hyperzero/ab0508a73c09940d8c98267af8ae021d834915d0/approximators/__init__.py -------------------------------------------------------------------------------- /approximators/policy.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | 5 | import hydra 6 | from pathlib import Path 7 | from collections import defaultdict 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import learn2learn as l2l 12 | 13 | import utils.utils as utils 14 | from models.rl_regressor import MLPActionPredictor, HyperPolicy, MLPContextEncoder 15 | from utils.dataloader import FastTensorDataLoader 16 | 17 | 18 | class PolicyApproximator: 19 | """ 20 | Approximates a family of near-optimal policies. 21 | Uses either a conditional MLP or a hypernetwork. 22 | """ 23 | def __init__(self, model, input_dim, state_dim, 24 | action_dim, device, lr, embed_dim, 25 | hidden_dim, noise_clip, use_clipped_noise): 26 | self.device = device 27 | self.lr = lr 28 | self.use_clipped_noise = use_clipped_noise 29 | self.noise_clip = noise_clip 30 | 31 | # model 32 | if model == 'mlp': 33 | model_fn = MLPActionPredictor 34 | elif model == 'hyper': 35 | model_fn = HyperPolicy 36 | else: 37 | raise NotImplementedError 38 | 39 | self.policy = model_fn(input_dim, 40 | state_dim, 41 | action_dim, 42 | embed_dim, 43 | hidden_dim).to(device) 44 | 45 | # optimizer 46 | self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr) 47 | 48 | self.train() 49 | self.policy.train() 50 | 51 | def train(self, training=True): 52 | self.training = training 53 | self.policy.train(training) 54 | 55 | def act(self, input_param, obs): 56 | input_param = torch.as_tensor(input_param, dtype=torch.float32, device=self.device).unsqueeze(0) 57 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0) 58 | action = self.policy(input_param, obs).cpu().numpy()[0] 59 | return action.astype(np.float32) 60 | 61 | def eval(self, data_loader): 62 | metrics = defaultdict(lambda: 0) 63 | 64 | num_batches = len(data_loader) 65 | self.train(True) 66 | 67 | for batch_idx, batch in enumerate(data_loader): 68 | input_param, state, action, next_state, reward, discount, value = batch 69 | 70 | predicted_action = self.policy(input_param, state) 71 | loss = F.mse_loss(predicted_action, action) 72 | 73 | metrics['valid/loss_action_pred'] += loss.item() 74 | metrics['valid/loss_total'] += loss.item() 75 | 76 | for k in metrics.keys(): 77 | metrics[k] /= num_batches 78 | return metrics 79 | 80 | def update(self, data_loader): 81 | metrics = defaultdict(lambda: 0) 82 | 83 | num_batches = len(data_loader) 84 | self.train(True) 85 | 86 | for batch_idx, batch in enumerate(data_loader): 87 | input_param, state, action, next_state, reward, discount, value = batch 88 | 89 | if self.use_clipped_noise: 90 | # Add clipped noise to the input param 91 | input_param_noise = (torch.randn_like(input_param)).clamp(-self.noise_clip, self.noise_clip) 92 | input_param += input_param_noise 93 | 94 | predicted_action = self.policy(input_param, state) 95 | 96 | loss = F.mse_loss(predicted_action, action) 97 | 98 | self.policy_optimizer.zero_grad() 99 | loss.backward() 100 | self.policy_optimizer.step() 101 | 102 | metrics['train/loss_action_pred'] += loss.item() 103 | metrics['train/loss_total'] += loss.item() 104 | 105 | for k in metrics.keys(): 106 | metrics[k] /= num_batches 107 | return metrics 108 | 109 | def save(self, model_dir, name): 110 | model_save_dir = Path(f'{model_dir}/step_{str(name).zfill(8)}') 111 | model_save_dir.mkdir(exist_ok=True, parents=True) 112 | 113 | torch.save(self.policy.state_dict(), f'{model_save_dir}/policy.pt') 114 | 115 | def 
load(self, model_dir, name): 116 | print(f"Loading the model from {model_dir}, name: {name}") 117 | model_load_dir = Path(f'{model_dir}/step_{str(name).zfill(8)}') 118 | 119 | self.policy.load_state_dict( 120 | torch.load(f'{model_load_dir}/policy.pt', map_location=self.device) 121 | ) 122 | 123 | 124 | class MetaPolicyApproximator(PolicyApproximator): 125 | """ 126 | Approximates a family of near-optimal policies. 127 | Uses either MAML or PEARL. 128 | """ 129 | def __init__(self, model, input_dim, state_dim, 130 | action_dim, device, lr, fast_lr, embed_dim, 131 | hidden_dim, noise_clip, use_clipped_noise, 132 | adaptation_steps, use_pearl, kl_lambda): 133 | super().__init__(model, input_dim, state_dim, 134 | action_dim, device, lr, embed_dim, 135 | hidden_dim, noise_clip, use_clipped_noise) 136 | assert model == 'mlp', "MAML only works with MLP policy" 137 | assert not use_clipped_noise, "MAML/PEARL cannot use clipped noise." 138 | del self.policy 139 | del self.policy_optimizer 140 | 141 | self.adaptation_steps = adaptation_steps 142 | self.use_pearl = use_pearl 143 | self.kl_lambda = kl_lambda 144 | self.prev_action = None 145 | self.prev_state = None 146 | self.action_dim = action_dim 147 | self.state_dim = state_dim 148 | 149 | if self.use_pearl: 150 | # PEARL baseline 151 | context_dim = embed_dim 152 | self.context_encoder = MLPContextEncoder(state_dim, 153 | action_dim, 154 | embed_dim, 155 | hidden_dim).to(device) 156 | self.context_encoder_optimizer = torch.optim.Adam(self.context_encoder.parameters(), lr=lr) 157 | else: 158 | # MAML baseline 159 | context_dim = input_dim 160 | 161 | policy = MLPActionPredictor(context_dim, 162 | state_dim, 163 | action_dim, 164 | embed_dim, 165 | hidden_dim).to(device) 166 | self.policy = l2l.algorithms.MAML(policy, lr=fast_lr, first_order=False) 167 | 168 | # optimizer 169 | self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr) 170 | 171 | self.train() 172 | self.policy.train() 173 | 174 | def fast_adapt(self, learner, batch): 175 | eval_metrics = dict() 176 | batch_size = batch[0].shape[0] 177 | 178 | # separate data into adaptation/evaluation sets 179 | adaptation_indices = np.zeros(batch_size, dtype=bool) 180 | adaptation_indices[np.arange(batch_size // 2)] = True 181 | evaluation_indices = torch.from_numpy(~adaptation_indices) 182 | adaptation_indices = torch.from_numpy(adaptation_indices) 183 | 184 | adapt_input_param, adapt_state, adapt_action, adapt_next_state, adapt_reward, adapt_discount, adapt_value = utils.select_indices( 185 | batch, adaptation_indices) 186 | eval_input_param, eval_state, eval_action, eval_next_state, eval_reward, eval_discount, eval_value = utils.select_indices( 187 | batch, evaluation_indices) 188 | 189 | # adapt the model 190 | for step in range(self.adaptation_steps): 191 | if self.use_pearl: 192 | adapt_context, _, _ = self.context_encoder(adapt_state, adapt_action) 193 | adapt_predicted_action = learner(adapt_context, adapt_state) 194 | else: 195 | adapt_predicted_action = learner(adapt_input_param, adapt_state) 196 | adapt_loss = F.mse_loss(adapt_predicted_action, adapt_action) 197 | learner.adapt(adapt_loss) 198 | 199 | # evaluate the adapted model 200 | if self.use_pearl: 201 | eval_context, eval_mu, eval_log_var = self.context_encoder(eval_state, eval_action) 202 | eval_predicted_action = learner(eval_context, eval_state) 203 | eval_loss = F.mse_loss(eval_predicted_action, eval_action) 204 | 205 | # compute KL loss 206 | kl_div = self.compute_kl_div(eval_mu, eval_log_var) 207 | kl_loss 
= self.kl_lambda * kl_div 208 | eval_loss += kl_loss 209 | eval_metrics['loss_kl'] = kl_loss.item() 210 | 211 | else: 212 | eval_predicted_action = learner(eval_input_param, eval_state) 213 | eval_loss = F.mse_loss(eval_predicted_action, eval_action) 214 | 215 | eval_metrics['loss_action_pred'] = eval_loss.item() 216 | eval_metrics['loss_total'] = eval_loss.item() 217 | 218 | return eval_loss, eval_metrics 219 | 220 | def compute_kl_div(self, mu, log_var): 221 | assert self.use_pearl 222 | kl_div = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1) 223 | return kl_div.mean() 224 | 225 | def act(self, input_param, obs): 226 | if self.use_pearl: 227 | # uses previous action and state to infer the context 228 | if self.prev_action is None: 229 | self.prev_action = torch.zeros(1, self.action_dim).to(self.device) 230 | if self.prev_state is None: 231 | self.prev_state = torch.zeros(1, self.state_dim).to(self.device) 232 | 233 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0) 234 | context, _, _ = self.context_encoder(self.prev_state, self.prev_action) 235 | action = self.policy(context, obs) 236 | 237 | self.prev_action = action.clone() 238 | self.prev_state = obs.clone() 239 | 240 | action = action.cpu().numpy()[0] 241 | else: 242 | input_param = torch.as_tensor(input_param, dtype=torch.float32, device=self.device).unsqueeze(0) 243 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0) 244 | action = self.policy(input_param, obs).cpu().numpy()[0] 245 | 246 | return action.astype(np.float32) 247 | 248 | def eval(self, data_loader): 249 | metrics = defaultdict(lambda: 0) 250 | 251 | num_tasks = data_loader.n_tasks 252 | self.train(True) 253 | 254 | for task_idx in range(num_tasks): 255 | learner = self.policy.clone() 256 | batch = data_loader.sample(task_idx) 257 | 258 | _, meta_loss_logs = self.fast_adapt(learner, batch) 259 | 260 | for k, v in meta_loss_logs.items(): 261 | metrics[f'valid/{k}'] += v 262 | 263 | for k in metrics.keys(): 264 | metrics[k] /= num_tasks 265 | return metrics 266 | 267 | def finetune(self, input_param, state, action): 268 | input_param = torch.as_tensor(input_param, dtype=torch.float32, device=self.device) 269 | state = torch.as_tensor(state, dtype=torch.float32, device=self.device) 270 | action = torch.as_tensor(action, dtype=torch.float32, device=self.device) 271 | 272 | data_loader = FastTensorDataLoader(input_param, state, action, batch_size=512, shuffle=True, device=self.device) 273 | finetuner = torch.optim.Adam(self.policy.parameters(), lr=0.001) 274 | 275 | # Updates the actual policy! 
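        # NOTE: unlike update(), which adapts a temporary MAML clone of the policy, finetune() optimizes self.policy itself with a separate Adam optimizer.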
276 | for step in range(self.adaptation_steps): 277 | for batch_idx, batch in enumerate(data_loader): 278 | batch_input_param, batch_state, batch_action = batch 279 | if self.use_pearl: 280 | adapt_context, _, _ = self.context_encoder(batch_state, batch_action) 281 | adapt_predicted_action = self.policy(adapt_context, batch_state) 282 | else: 283 | adapt_predicted_action = self.policy(batch_input_param, batch_state) 284 | adapt_loss = F.mse_loss(adapt_predicted_action, batch_action) 285 | 286 | # self.policy.adapt(adapt_loss) 287 | finetuner.zero_grad() 288 | adapt_loss.backward() 289 | finetuner.step() 290 | 291 | def update(self, data_loader): 292 | metrics = defaultdict(lambda: 0) 293 | 294 | num_tasks = data_loader.n_tasks 295 | self.train(True) 296 | 297 | self.policy_optimizer.zero_grad() 298 | if self.use_pearl: 299 | self.context_encoder_optimizer.zero_grad() 300 | 301 | for task_idx in range(num_tasks): 302 | # compute meta-training loss 303 | learner = self.policy.clone() 304 | batch = data_loader.sample(task_idx) 305 | 306 | meta_loss, meta_loss_logs = self.fast_adapt(learner, batch) 307 | meta_loss.backward() 308 | 309 | for k, v in meta_loss_logs.items(): 310 | metrics[f'train/{k}'] += v 311 | 312 | for p in self.policy.parameters(): 313 | p.grad.data.mul_(1.0 / num_tasks) 314 | self.policy_optimizer.step() 315 | 316 | if self.use_pearl: 317 | for p in self.context_encoder.parameters(): 318 | p.grad.data.mul_(1.0 / num_tasks) 319 | self.context_encoder_optimizer.step() 320 | 321 | for k in metrics.keys(): 322 | metrics[k] /= num_tasks 323 | return metrics 324 | -------------------------------------------------------------------------------- /approximators/rl_solution.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | from pathlib import Path 4 | from collections import defaultdict 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import learn2learn as l2l 9 | 10 | from models.rl_regressor import MLPRLSolution, HyperRLSolution 11 | import utils.utils as utils 12 | 13 | 14 | class RLApproximator: 15 | """ 16 | Approximates a family of near-optimal RL solutions. 17 | Uses either a conditional MLP or a hypernetwork.
18 | """ 19 | def __init__(self, model, input_dim, state_dim, action_dim, 20 | device, lr, embed_dim, hidden_dim, noise_clip, 21 | use_clipped_noise, use_td, td_weight, value_weight): 22 | self.device = device 23 | self.lr = lr 24 | self.model = model 25 | self.use_td_error = use_td 26 | self.use_clipped_noise = use_clipped_noise 27 | self.noise_clip = noise_clip 28 | self.td_weight = td_weight 29 | self.value_weight = value_weight 30 | 31 | # model 32 | if model == 'mlp': 33 | model_fn = MLPRLSolution 34 | elif model == 'hyper': 35 | model_fn = HyperRLSolution 36 | else: 37 | raise NotImplementedError 38 | 39 | self.rl_net = model_fn(input_dim, 40 | state_dim, 41 | action_dim, 42 | embed_dim, 43 | hidden_dim).to(device) 44 | 45 | # optimizer 46 | self.rl_net_optimizer = torch.optim.Adam(self.rl_net.parameters(), lr=lr) 47 | 48 | self.train() 49 | self.rl_net.train() 50 | 51 | def train(self, training=True): 52 | self.training = training 53 | self.rl_net.train(training) 54 | 55 | def act(self, input_param, obs): 56 | input_param = torch.as_tensor(input_param, dtype=torch.float32, device=self.device).unsqueeze(0) 57 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0) 58 | task_emb = self.rl_net.embed_task(input_param) 59 | action = self.rl_net.predict_action(task_emb, obs).cpu().numpy()[0] 60 | return action.astype(np.float32) 61 | 62 | def q(self, input_param, obs, action): 63 | input_param = torch.as_tensor(input_param, dtype=torch.float32, device=self.device).unsqueeze(0) 64 | obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0) 65 | action = torch.as_tensor(action, dtype=torch.float32, device=self.device).unsqueeze(0) 66 | task_emb = self.rl_net.embed_task(input_param) 67 | q = self.rl_net.predict_q_value(task_emb, obs, action).cpu().numpy()[0] 68 | return q.astype(np.float32) 69 | 70 | def eval(self, data_loader): 71 | metrics = defaultdict(lambda: 0) 72 | 73 | num_batches = len(data_loader) 74 | self.train(True) 75 | 76 | for batch_idx, batch in enumerate(data_loader): 77 | input_param, state, action, next_state, reward, discount, value = batch 78 | 79 | task_emb, predicted_action, predicted_value = self.rl_net(input_param, state, action) 80 | 81 | loss_action = F.mse_loss(predicted_action, action) 82 | loss_value = F.mse_loss(predicted_value, value) 83 | loss = loss_action + self.value_weight * loss_value 84 | # evaluate the TD error in any case 85 | loss_td = self.get_td_error(task_emb, next_state, reward, discount, value) 86 | 87 | if self.use_td_error: 88 | loss += self.td_weight * loss_td 89 | 90 | metrics['valid/loss_action_pred'] += loss_action.item() 91 | metrics['valid/loss_value_pred'] += self.value_weight * loss_value.item() 92 | metrics['valid/loss_td'] += self.td_weight * loss_td.item() 93 | metrics['valid/loss_total'] += loss.item() 94 | 95 | for k in metrics.keys(): 96 | metrics[k] /= num_batches 97 | return metrics 98 | 99 | def update(self, data_loader): 100 | metrics = defaultdict(lambda: 0) 101 | 102 | num_batches = len(data_loader) 103 | self.train(True) 104 | 105 | for batch_idx, batch in enumerate(data_loader): 106 | input_param, state, action, next_state, reward, discount, value = batch 107 | 108 | if self.use_clipped_noise: 109 | # Add clipped noise to the reward param 110 | input_param_noise = (torch.randn_like(input_param)).clamp(-self.noise_clip, self.noise_clip) 111 | input_param += input_param_noise 112 | 113 | task_emb, predicted_action, predicted_value = self.rl_net(input_param, state, action) 
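        # HyperZero regression objective: match the near-optimal action and Q-value stored in the rollout dataset; a TD-consistency term is added below when use_td is enabled.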
114 | 115 | loss_action = F.mse_loss(predicted_action, action) 116 | loss_value = F.mse_loss(predicted_value, value) 117 | loss = loss_action + self.value_weight * loss_value 118 | 119 | if self.use_td_error: 120 | loss_td = self.get_td_error(task_emb, next_state, reward, discount, value) 121 | loss += self.td_weight * loss_td 122 | metrics['train/loss_td'] += self.td_weight * loss_td.item() 123 | 124 | self.rl_net_optimizer.zero_grad() 125 | loss.backward() 126 | self.rl_net_optimizer.step() 127 | 128 | metrics['train/loss_action_pred'] += loss_action.item() 129 | metrics['train/loss_value_pred'] += self.value_weight * loss_value.item() 130 | metrics['train/loss_total'] += loss.item() 131 | 132 | for k in metrics.keys(): 133 | metrics[k] /= num_batches 134 | return metrics 135 | 136 | def get_td_error(self, task_emb, next_state, reward, discount, q): 137 | with torch.no_grad(): 138 | next_action = self.rl_net.predict_action(task_emb, next_state) 139 | target_q = self.rl_net.predict_q_value(task_emb, next_state, next_action) 140 | target_q = reward + discount * target_q 141 | 142 | td_error = F.mse_loss(q, target_q) 143 | return td_error 144 | 145 | def save(self, model_dir, name): 146 | model_save_dir = Path(f'{model_dir}/step_{str(name).zfill(8)}') 147 | model_save_dir.mkdir(exist_ok=True, parents=True) 148 | 149 | torch.save(self.rl_net.state_dict(), f'{model_save_dir}/rl_net.pt') 150 | 151 | def load(self, model_dir, name): 152 | print(f"Loading the model from {model_dir}, name: {name}") 153 | model_load_dir = Path(f'{model_dir}/step_{str(name).zfill(8)}') 154 | 155 | self.rl_net.load_state_dict( 156 | torch.load(f'{model_load_dir}/rl_net.pt', map_location=self.device) 157 | ) 158 | 159 | 160 | class MetaRLApproximator(RLApproximator): 161 | """ 162 | Approximates a family of near-optimal policies. Uses MAML. 163 | """ 164 | def __init__(self, model, input_dim, state_dim, action_dim, 165 | device, lr, fast_lr, embed_dim, hidden_dim, noise_clip, 166 | use_clipped_noise, use_td, td_weight, value_weight, 167 | adaptation_steps): 168 | super().__init__(model, input_dim, state_dim, action_dim, 169 | device, lr, embed_dim, hidden_dim, noise_clip, 170 | use_clipped_noise, use_td, td_weight, value_weight) 171 | assert model == 'mlp', "MAML only works with MLP RL approximator" 172 | assert not use_clipped_noise, "MAML cannot use clipped noise." 
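        # drop the plain regressor and optimizer built by the parent class; they are recreated below with the model wrapped in l2l.algorithms.MAML so that fast_adapt() can clone and adapt it per task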
173 | del self.rl_net 174 | del self.rl_net_optimizer 175 | 176 | self.adaptation_steps = adaptation_steps 177 | 178 | # MAML model 179 | policy = MLPRLSolution(input_dim, 180 | state_dim, 181 | action_dim, 182 | embed_dim, 183 | hidden_dim).to(device) 184 | self.rl_net = l2l.algorithms.MAML(policy, lr=fast_lr, first_order=False) 185 | 186 | # optimizer 187 | self.rl_net_optimizer = torch.optim.Adam(self.rl_net.parameters(), lr=lr) 188 | 189 | self.train() 190 | self.rl_net.train() 191 | 192 | def fast_adapt(self, learner, batch): 193 | eval_metrics = dict() 194 | batch_size = batch[0].shape[0] 195 | 196 | # separate data into adaptation/evaluation sets 197 | adaptation_indices = np.zeros(batch_size, dtype=bool) 198 | adaptation_indices[np.arange(batch_size // 2)] = True 199 | evaluation_indices = torch.from_numpy(~adaptation_indices) 200 | adaptation_indices = torch.from_numpy(adaptation_indices) 201 | 202 | adapt_input_param, adapt_state, adapt_action, adapt_next_state, adapt_reward, adapt_discount, adapt_value = utils.select_indices( 203 | batch, adaptation_indices) 204 | eval_input_param, eval_state, eval_action, eval_next_state, eval_reward, eval_discount, eval_value = utils.select_indices( 205 | batch, evaluation_indices) 206 | 207 | # adapt the model 208 | for step in range(self.adaptation_steps): 209 | adapt_task_emb, adapt_predicted_action, adapt_predicted_value = learner(adapt_input_param, adapt_state, adapt_action) 210 | 211 | adapt_action_loss = F.mse_loss(adapt_predicted_action, adapt_action) 212 | adapt_value_loss = F.mse_loss(adapt_predicted_value, adapt_value) 213 | adapt_loss = adapt_action_loss + self.value_weight * adapt_value_loss 214 | 215 | if self.use_td_error: 216 | adapt_loss_td = self.get_td_error(adapt_task_emb, adapt_next_state, adapt_reward, adapt_discount, adapt_value) 217 | adapt_loss += self.td_weight * adapt_loss_td 218 | 219 | learner.adapt(adapt_loss) 220 | 221 | # evaluate the adapted model 222 | eval_task_emb, eval_predicted_action, eval_predicted_value = learner(eval_input_param, eval_state, eval_action) 223 | 224 | eval_action_loss = F.mse_loss(eval_predicted_action, eval_action) 225 | eval_value_loss = F.mse_loss(eval_predicted_value, eval_value) 226 | eval_loss = eval_action_loss + self.value_weight * eval_value_loss 227 | 228 | if self.use_td_error: 229 | eval_loss_td = self.get_td_error(eval_task_emb, eval_next_state, eval_reward, eval_discount, eval_value) 230 | eval_loss += self.td_weight * eval_loss_td 231 | 232 | eval_metrics['loss_action_pred'] = eval_action_loss.item() 233 | eval_metrics['loss_value_pred'] = eval_value_loss.item() 234 | eval_metrics['loss_total'] = eval_loss.item() 235 | 236 | return eval_loss, eval_metrics 237 | 238 | def eval(self, data_loader): 239 | metrics = defaultdict(lambda: 0) 240 | 241 | num_tasks = data_loader.n_tasks 242 | num_batches = len(data_loader) 243 | self.train(True) 244 | 245 | for batch_idx in range(num_batches): 246 | for task_idx in range(num_tasks): 247 | learner = self.rl_net.clone() 248 | batch = data_loader.sample(task_idx) 249 | 250 | _, meta_loss_logs = self.fast_adapt(learner, batch) 251 | 252 | for k, v in meta_loss_logs.items(): 253 | metrics[f'valid/{k}'] += v 254 | 255 | for k in metrics.keys(): 256 | metrics[k] /= (num_tasks * num_batches) 257 | return metrics 258 | 259 | def update(self, data_loader): 260 | metrics = defaultdict(lambda: 0) 261 | 262 | num_tasks = data_loader.n_tasks 263 | num_batches = len(data_loader) 264 | self.train(True) 265 | 266 | for batch_idx in
range(num_batches): 267 | self.rl_net_optimizer.zero_grad() 268 | for task_idx in range(num_tasks): 269 | # compute meta-training loss 270 | learner = self.rl_net.clone() 271 | batch = data_loader.sample(task_idx) 272 | 273 | meta_loss, meta_loss_logs = self.fast_adapt(learner, batch) 274 | meta_loss.backward() 275 | 276 | for k, v in meta_loss_logs.items(): 277 | metrics[f'train/{k}'] += v 278 | 279 | for p in self.rl_net.parameters(): 280 | p.grad.data.mul_(1.0 / num_tasks) 281 | self.rl_net_optimizer.step() 282 | 283 | for k in metrics.keys(): 284 | metrics[k] /= (num_tasks * num_batches) 285 | return metrics -------------------------------------------------------------------------------- /cfgs/agent/td3.yaml: -------------------------------------------------------------------------------- 1 | agent_name: td3 2 | 3 | agent: 4 | _target_: agents.td3.TD3Agent 5 | obs_shape: ??? # to be specified later 6 | action_shape: ??? # to be specified later 7 | device: ${device} 8 | lr: ${lr} 9 | critic_target_tau: 0.01 10 | update_every_steps: 2 11 | num_expl_steps: 2000 12 | hidden_dim: 256 13 | stddev_schedule: ${stddev_schedule} 14 | stddev_clip: 0.3 15 | -------------------------------------------------------------------------------- /cfgs/approximator/hyperzero.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: hyperzero 2 | 3 | approximator: 4 | _target_: approximators.rl_solution.RLApproximator 5 | model: hyper 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? # to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | hidden_dim: ${hidden_dim} 12 | embed_dim: ${embed_dim} 13 | use_td: true 14 | use_clipped_noise: false 15 | noise_clip: ${noise_clip} 16 | value_weight: ${value_weight} 17 | td_weight: ${td_weight} -------------------------------------------------------------------------------- /cfgs/approximator/hyperzero_without_q.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: hyperzero_without_q 2 | 3 | approximator: 4 | _target_: approximators.policy.PolicyApproximator 5 | model: hyper 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? # to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | hidden_dim: ${hidden_dim} 12 | embed_dim: ${embed_dim} 13 | use_clipped_noise: false 14 | noise_clip: ${noise_clip} 15 | -------------------------------------------------------------------------------- /cfgs/approximator/hyperzero_without_td.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: hyperzero_without_td 2 | 3 | approximator: 4 | _target_: approximators.rl_solution.RLApproximator 5 | model: hyper 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? 
# to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | hidden_dim: ${hidden_dim} 12 | embed_dim: ${embed_dim} 13 | use_td: false 14 | use_clipped_noise: false 15 | noise_clip: ${noise_clip} 16 | value_weight: ${value_weight} 17 | td_weight: ${td_weight} -------------------------------------------------------------------------------- /cfgs/approximator/meta_policy.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: meta_policy 2 | 3 | approximator: 4 | _target_: approximators.policy.MetaPolicyApproximator 5 | model: mlp 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? # to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | fast_lr: 0.01 12 | hidden_dim: ${hidden_dim} 13 | embed_dim: ${embed_dim} 14 | use_clipped_noise: false 15 | noise_clip: ${noise_clip} 16 | adaptation_steps: ${adaptation_steps} 17 | use_pearl: false 18 | kl_lambda: 0.1 -------------------------------------------------------------------------------- /cfgs/approximator/meta_rl.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: meta_rl 2 | 3 | approximator: 4 | _target_: approximators.rl_solution.MetaRLApproximator 5 | model: mlp 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? # to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | fast_lr: 0.01 12 | hidden_dim: ${hidden_dim} 13 | embed_dim: ${embed_dim} 14 | use_td: false 15 | use_clipped_noise: false 16 | noise_clip: ${noise_clip} 17 | value_weight: ${value_weight} 18 | td_weight: ${td_weight} 19 | adaptation_steps: ${adaptation_steps} -------------------------------------------------------------------------------- /cfgs/approximator/meta_rl_td.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: meta_rl_td 2 | 3 | approximator: 4 | _target_: approximators.rl_solution.MetaRLApproximator 5 | model: mlp 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? # to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | fast_lr: 0.01 12 | hidden_dim: ${hidden_dim} 13 | embed_dim: ${embed_dim} 14 | use_td: true 15 | use_clipped_noise: false 16 | noise_clip: ${noise_clip} 17 | value_weight: ${value_weight} 18 | td_weight: ${td_weight} 19 | adaptation_steps: ${adaptation_steps} -------------------------------------------------------------------------------- /cfgs/approximator/mlp_policy.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: mlp_policy 2 | 3 | approximator: 4 | _target_: approximators.policy.PolicyApproximator 5 | model: mlp 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? # to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | hidden_dim: ${hidden_dim} 12 | embed_dim: ${embed_dim} 13 | use_clipped_noise: false 14 | noise_clip: ${noise_clip} -------------------------------------------------------------------------------- /cfgs/approximator/mlp_rl.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: mlp_rl 2 | 3 | approximator: 4 | _target_: approximators.rl_solution.RLApproximator 5 | model: mlp 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? 
# to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | hidden_dim: ${hidden_dim} 12 | embed_dim: ${embed_dim} 13 | use_td: false 14 | use_clipped_noise: false 15 | noise_clip: ${noise_clip} 16 | value_weight: ${value_weight} 17 | td_weight: ${td_weight} -------------------------------------------------------------------------------- /cfgs/approximator/mlp_rl_td.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: mlp_rl_td 2 | 3 | approximator: 4 | _target_: approximators.rl_solution.RLApproximator 5 | model: mlp 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? # to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | hidden_dim: ${hidden_dim} 12 | embed_dim: ${embed_dim} 13 | use_td: true 14 | use_clipped_noise: false 15 | noise_clip: ${noise_clip} 16 | value_weight: ${value_weight} 17 | td_weight: ${td_weight} -------------------------------------------------------------------------------- /cfgs/approximator/pearl_policy.yaml: -------------------------------------------------------------------------------- 1 | approximator_name: pearl_policy 2 | 3 | approximator: 4 | _target_: approximators.policy.MetaPolicyApproximator 5 | model: mlp 6 | input_dim: ??? # to be specified later 7 | state_dim: ??? # to be specified later 8 | action_dim: ??? # to be specified later 9 | device: ${device} 10 | lr: 0.0001 11 | fast_lr: 0.01 12 | hidden_dim: ${hidden_dim} 13 | embed_dim: ${embed_dim} 14 | use_clipped_noise: false 15 | noise_clip: ${noise_clip} 16 | adaptation_steps: ${adaptation_steps} 17 | use_pearl: true 18 | kl_lambda: 0.01 -------------------------------------------------------------------------------- /cfgs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task@_global_: cheetah_run 4 | - obs@_global_: states 5 | - agent@_global_: td3 6 | - reward@_global_: cheetah_default 7 | - dynamics@_global_: default 8 | - override hydra/launcher: submitit_local 9 | 10 | # task settings 11 | discount: 0.99 12 | # train settings 13 | num_seed_frames: 4000 14 | # eval 15 | eval_every_frames: 10000 16 | num_eval_episodes: 10 17 | # plot 18 | plot_every_frames: 100000 19 | # save 20 | save_every_frames: 100000 21 | # snapshot 22 | save_snapshot: true 23 | # replay buffer 24 | replay_buffer_size: 1000000 25 | replay_buffer_num_workers: 4 26 | nstep: 1 27 | batch_size: 256 28 | # misc 29 | seed: 2 30 | device: cuda 31 | save_video: true 32 | save_train_video: false 33 | # experiment 34 | experiment: '' 35 | # agent 36 | lr: 1e-4 37 | feature_dim: 50 38 | 39 | hydra: 40 | job: 41 | chdir: True 42 | run: 43 | dir: ./results/${now:%Y.%m.%d}/${task_name}_${agent_name}_seed_${seed}_${reward_name}_${dynamics_name}_${experiment}_${now:%H-%M-%S}/ 44 | sweep: 45 | dir: ./results/${now:%Y.%m.%d}/${task_name}_${agent_name}_seed_${seed}_${reward_name}_${dynamics_name}_${experiment}_${now:%H-%M-%S}/ 46 | subdir: ${hydra.job.num} 47 | # sweeper: 48 | # params: 49 | # reward_parameters.ALL.margin: range(0.2,3.1,0.2) 50 | launcher: 51 | timeout_min: 4300 52 | cpus_per_task: 10 53 | gpus_per_node: 8 54 | tasks_per_node: 1 55 | mem_gb: 64 56 | nodes: 1 57 | submitit_folder: ./results_multirun/${now:%Y.%m.%d}/${task_name}_${agent_name}_seed_${seed}_${reward_name}_${experiment}_${now:%H-%M-%S}/ 58 | -------------------------------------------------------------------------------- /cfgs/config_rl_approximator.yaml: 
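The approximator configs above all leave input_dim, state_dim and action_dim as ??? and pull device, hidden_dim, embed_dim and noise_clip from the top-level config; those fields are filled in programmatically before instantiation (see make_approximator in train_rl_regressor.py). A minimal sketch of that pattern follows; the dimensions are illustrative assumptions, not values taken from the repo:

import hydra
from omegaconf import OmegaConf

# Sketch only: fill the '???' placeholders, then let Hydra build the object
# from its _target_. The dimensions below are assumed for illustration.
cfg = OmegaConf.create({
    '_target_': 'approximators.policy.PolicyApproximator',
    'model': 'mlp',
    'input_dim': 1,        # e.g. a single reward-margin parameter
    'state_dim': 17,       # assumed state dimension
    'action_dim': 6,       # assumed action dimension
    'device': 'cpu',
    'lr': 1e-4,
    'hidden_dim': 256,
    'embed_dim': 256,
    'use_clipped_noise': False,
    'noise_clip': 0.02,
})
approximator = hydra.utils.instantiate(cfg)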
-------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - approximator@_global_: mlp_policy 4 | - override hydra/launcher: submitit_local 5 | 6 | # dataset dir 7 | rollout_dir: ~/workspace/hyperzero_private/rollout_data 8 | domain_task: cheetah_run 9 | # train settings 10 | num_train_epochs: 500 11 | batch_size: 512 12 | test_fraction: 0.2 13 | # save 14 | save_every_frames: 20 15 | # snapshot 16 | save_snapshot: true 17 | # misc 18 | seed: 2 19 | device: cuda 20 | # experiment 21 | experiment: '' 22 | # approximator 23 | input_to_model: 'rew' # options: 'rew', 'dyn', 'rew_dyn' 24 | noise_clip: 0.02 25 | value_weight: 0.01 26 | td_weight: 0.01 27 | hidden_dim: 256 28 | embed_dim: 256 29 | k_shot: 10 30 | adaptation_steps: 5 31 | 32 | hydra: 33 | job: 34 | chdir: True 35 | run: 36 | dir: ./results_approximator/${input_to_model}/seed_${seed}/${domain_task}/${domain_task}_${approximator_name}_${input_to_model}_seed_${seed}_${experiment}_${now:%Y.%m.%d-%H-%M-%S}/ 37 | sweep: 38 | dir: ./results_approximator/${input_to_model}/seed_${seed}/${domain_task}/${domain_task}_${approximator_name}_${input_to_model}_seed_${seed}_${experiment}_${now:%Y.%m.%d-%H-%M-%S}/ 39 | subdir: ${hydra.job.num} 40 | sweeper: 41 | params: 42 | approximator@_global_: mlp_policy,mlp_rl,mlp_rl_td,hyperzero,hyperzero_without_q,hyperzero_without_td,meta_policy 43 | launcher: 44 | timeout_min: 4300 45 | cpus_per_task: 10 46 | gpus_per_node: 8 47 | tasks_per_node: 1 48 | mem_gb: 64 49 | nodes: 1 50 | submitit_folder: ./results_approximator_multirun/${input_to_model}/seed_${seed}/${domain_task}/${domain_task}_${approximator_name}_${input_to_model}_seed_${seed}_${experiment}_${now:%Y.%m.%d-%H-%M-%S}/ 51 | -------------------------------------------------------------------------------- /cfgs/dynamics/cartpole.yaml: -------------------------------------------------------------------------------- 1 | dynamics_parameters: 2 | use_default: false 3 | mass: 0.1 4 | size: 0.045 5 | length: 1 6 | 7 | dynamics_name: dyn_${dynamics_parameters.mass}_${dynamics_parameters.size}_${dynamics_parameters.length} -------------------------------------------------------------------------------- /cfgs/dynamics/cheetah.yaml: -------------------------------------------------------------------------------- 1 | dynamics_parameters: 2 | use_default: false 3 | length: 0.5 4 | 5 | dynamics_name: dyn_${dynamics_parameters.length} -------------------------------------------------------------------------------- /cfgs/dynamics/default.yaml: -------------------------------------------------------------------------------- 1 | dynamics_parameters: 2 | use_default: true 3 | 4 | dynamics_name: dyn_default -------------------------------------------------------------------------------- /cfgs/dynamics/finger.yaml: -------------------------------------------------------------------------------- 1 | dynamics_parameters: 2 | use_default: false 3 | length: 0.16 4 | 5 | dynamics_name: dyn_${dynamics_parameters.length} -------------------------------------------------------------------------------- /cfgs/dynamics/walker.yaml: -------------------------------------------------------------------------------- 1 | dynamics_parameters: 2 | use_default: false 3 | length: 0.3 4 | 5 | dynamics_name: dyn_${dynamics_parameters.length} -------------------------------------------------------------------------------- /cfgs/obs/states.yaml: -------------------------------------------------------------------------------- 1 | 
pixel_obs: false 2 | frame_stack: 1 3 | action_repeat: 1 -------------------------------------------------------------------------------- /cfgs/reward/cartpole_default.yaml: -------------------------------------------------------------------------------- 1 | reward_parameters: 2 | centered: 3 | sigmoid: gaussian 4 | margin: 2 5 | value_at_margin: 0.1 6 | small_control: 7 | sigmoid: quadratic 8 | margin: 1 9 | value_at_margin: 0 10 | small_velocity: 11 | sigmoid: gaussian 12 | margin: 5 13 | value_at_margin: 0.1 14 | 15 | reward_name: default_${reward_parameters.centered.margin}_${reward_parameters.small_velocity.margin}_${reward_parameters.small_control.margin} -------------------------------------------------------------------------------- /cfgs/reward/cheetah_default.yaml: -------------------------------------------------------------------------------- 1 | reward_parameters: 2 | speed: 3 | bounds: [10, .inf] 4 | margin: 10 5 | value_at_margin: 0 6 | sigmoid: linear 7 | 8 | reward_name: default 9 | -------------------------------------------------------------------------------- /cfgs/reward/finger_default.yaml: -------------------------------------------------------------------------------- 1 | reward_parameters: 2 | spin: 3 | bounds: [15, .inf] 4 | margin: 15 5 | value_at_margin: 0 6 | sigmoid: linear 7 | 8 | reward_name: default 9 | -------------------------------------------------------------------------------- /cfgs/reward/overwrite_all.yaml: -------------------------------------------------------------------------------- 1 | reward_parameters: 2 | ALL: 3 | sigmoid: gaussian 4 | margin: 1 5 | value_at_margin: 0.1 6 | 7 | reward_name: overwrite_all-${reward_parameters.ALL.sigmoid}-${reward_parameters.ALL.margin} 8 | -------------------------------------------------------------------------------- /cfgs/reward/walker_default.yaml: -------------------------------------------------------------------------------- 1 | reward_parameters: 2 | horizontal_velocity: 3 | sigmoid: linear 4 | margin: 0 5 | value_at_margin: 0.1 6 | 7 | reward_name: default -------------------------------------------------------------------------------- /cfgs/task/cheetah_run.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: cheetah_run 6 | -------------------------------------------------------------------------------- /cfgs/task/easy.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 1100000 2 | stddev_schedule: 'linear(1.0,0.1,100000)' -------------------------------------------------------------------------------- /cfgs/task/finger_spin.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: finger_spin 6 | -------------------------------------------------------------------------------- /cfgs/task/medium.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 1100000 2 | stddev_schedule: 'linear(1.0,0.1,100000)' -------------------------------------------------------------------------------- /cfgs/task/walker_walk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: walker_walk 6 | nstep: 1 7 | batch_size: 512 8 | -------------------------------------------------------------------------------- /eval.py: 
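The reward and dynamics configs above compose their reward_name and dynamics_name tags through OmegaConf string interpolation. A small sketch of how one of these names resolves, using the values from overwrite_all.yaml:

from omegaconf import OmegaConf

# Sketch: resolving the ${...} interpolations used in the reward configs.
cfg = OmegaConf.create({
    'reward_parameters': {'ALL': {'sigmoid': 'gaussian', 'margin': 1, 'value_at_margin': 0.1}},
    'reward_name': 'overwrite_all-${reward_parameters.ALL.sigmoid}-${reward_parameters.ALL.margin}',
})
print(OmegaConf.to_container(cfg, resolve=True)['reward_name'])  # overwrite_all-gaussian-1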
-------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings('ignore', category=DeprecationWarning) 4 | 5 | import os 6 | import platform 7 | 8 | if platform.system() == 'Linux': 9 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 10 | os.environ['MUJOCO_GL'] = 'egl' 11 | 12 | import argparse 13 | from pathlib import Path 14 | 15 | import hydra 16 | import numpy as np 17 | import torch 18 | import omegaconf 19 | from omegaconf import OmegaConf 20 | from collections import defaultdict 21 | 22 | import utils.dmc as dmc 23 | import utils.utils as utils 24 | import utils.plots as plots 25 | from train import make_agent 26 | from train_rl_regressor import make_approximator 27 | from utils.video import VideoRecorder 28 | 29 | torch.backends.cudnn.benchmark = True 30 | # device = 'cuda' if torch.cuda.is_available() else 'cpu' 31 | device = 'cpu' 32 | 33 | 34 | class Workspace: 35 | def __init__(self, cfg, work_dir, args): 36 | self.work_dir = Path(work_dir) 37 | self.base_name = self.work_dir.parents[0].name 38 | 39 | self.cfg = cfg 40 | self.args = args 41 | utils.set_seed_everywhere(cfg.seed) 42 | self.device = torch.device(device) 43 | 44 | # Video dir 45 | self.video_dir = Path(args.video_dir).joinpath(f'{self.work_dir.parents[1].name}') 46 | self.video_dir.mkdir(exist_ok=True, parents=True) 47 | 48 | self.setup() 49 | 50 | # create and load the RL agent 51 | self.agent = make_agent(self.eval_env_rl_agent.observation_spec(), 52 | self.eval_env_rl_agent.action_spec(), 53 | self.cfg.agent, 54 | device=device) 55 | self.step_to_load = args.step_to_load if args.step_to_load != 0 else utils.get_last_model(self.agent_model_dir) 56 | self.agent.load(self.agent_model_dir, self.step_to_load) 57 | 58 | if args.rl_regressor_workdir is not None and args.rl_regressor_workdir != 'None': 59 | # create and load the RL regressor 60 | rl_regressor_cfg_path = Path(args.rl_regressor_workdir).joinpath('cfg.yaml') 61 | rl_regressor_cfg = OmegaConf.load(rl_regressor_cfg_path) 62 | self.rl_regressor_name = rl_regressor_cfg.approximator_name 63 | self.input_to_regressor = rl_regressor_cfg.input_to_model 64 | self.rl_regressor_seed = rl_regressor_cfg.seed 65 | self.is_meta_learning = True if 'meta' in self.rl_regressor_name else False 66 | 67 | # Approximated RL rollout dir 68 | self.rollout_comparison_data = Path(f"{args.rollout_dir}_comparison").joinpath(self.input_to_regressor, 69 | str(self.rl_regressor_seed), 70 | self.cfg.task_name) 71 | self.rollout_comparison_data.mkdir(exist_ok=True, parents=True) 72 | 73 | # Overwrite video dir and video recorder 74 | self.video_dir = Path(args.video_dir).joinpath(self.input_to_regressor, 75 | str(self.rl_regressor_seed), self.cfg.task_name) 76 | self.video_dir.mkdir(exist_ok=True, parents=True) 77 | self.video_recorder = VideoRecorder( 78 | self.video_dir, 79 | fps=60 // self.cfg.action_repeat 80 | ) 81 | 82 | rl_regressor_work_dir = Path(args.rl_regressor_workdir) 83 | rl_regressor_model_dir = rl_regressor_work_dir / 'models' 84 | 85 | if self.input_to_regressor == 'rew': 86 | input_dim = self._get_reward_param_dim() 87 | elif self.input_to_regressor == 'dyn': 88 | input_dim = self._get_dynamics_param_dim() 89 | elif self.input_to_regressor == 'rew_dyn': 90 | input_dim = self._get_reward_dynamics_param_dim() 91 | else: 92 | raise NotImplementedError 93 | 94 | self.rl_regressor = make_approximator(input_dim, 95 | self.eval_env_rl_agent.observation_spec().shape[0], 96 | 
self.eval_env_rl_agent.action_spec().shape[0], 97 | rl_regressor_cfg.approximator, 98 | device=device) 99 | regressor_step_to_load = utils.get_last_model(rl_regressor_model_dir) 100 | # regressor_step_to_load = 'best_total' 101 | self.rl_regressor.load(rl_regressor_model_dir, regressor_step_to_load) 102 | if not hasattr(self.rl_regressor, 'act'): 103 | print("RL regressor does not have the policy.") 104 | self.rl_regressor = None 105 | else: 106 | print("Did not load the RL regressor.") 107 | 108 | # RL Rollout dir 109 | self.rollout_dir = Path(args.rollout_dir).joinpath(self.cfg.task_name) 110 | self.rollout_dir.mkdir(exist_ok=True, parents=True) 111 | 112 | self.rl_regressor = None 113 | self.rl_regressor_name = '' 114 | 115 | def setup(self): 116 | # get the reward parameters 117 | reward_parameters = OmegaConf.to_container(self.cfg.reward_parameters) 118 | 119 | # get the dynamics parameters 120 | try: 121 | dynamics_parameters = OmegaConf.to_container(self.cfg.dynamics_parameters) 122 | except omegaconf.errors.ConfigAttributeError: 123 | dynamics_parameters = {'use_default': True} 124 | 125 | # create envs with equal but independent random generators 126 | rg_1 = np.random.RandomState(self.cfg.seed) 127 | rg_2 = np.random.RandomState(self.cfg.seed) 128 | 129 | self.eval_env_rl_agent = dmc.make(self.cfg.task_name, self.cfg.frame_stack, 130 | self.cfg.action_repeat, reward_parameters, 131 | dynamics_parameters, rg_1, self.cfg.pixel_obs) 132 | self.eval_env_rl_approx = dmc.make(self.cfg.task_name, self.cfg.frame_stack, 133 | self.cfg.action_repeat, reward_parameters, 134 | dynamics_parameters, rg_2, self.cfg.pixel_obs) 135 | 136 | try: 137 | _module = self.eval_env_rl_agent.task.__module__ 138 | self.domain = _module.rpartition('.')[-1] 139 | except AttributeError: 140 | self.domain = None 141 | 142 | self.video_recorder = VideoRecorder( 143 | self.video_dir, 144 | fps=60 // self.cfg.action_repeat 145 | ) 146 | 147 | self.plot_dir = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), 'eval_plots', 148 | f'{self.work_dir.parents[1].name}'))) 149 | self.plot_dir.mkdir(exist_ok=True, parents=True) 150 | self.agent_model_dir = self.work_dir / 'models' 151 | 152 | def rollout(self, n_episodes=1, use_approximator=False): 153 | rollout_data = defaultdict(list) 154 | env = self.eval_env_rl_approx if use_approximator else self.eval_env_rl_agent 155 | 156 | for episode in range(n_episodes): 157 | episode_rollout = defaultdict(list) 158 | time_step = env.reset() 159 | self.video_recorder.init(env, enabled=(episode == 0 and self.args.eval_mode == 'comparison_data')) 160 | 161 | while not time_step.last(): 162 | with torch.no_grad(), utils.eval_mode(self.agent): 163 | reward_param = self._get_reward_param() 164 | dynamics_param = self._get_dynamics_param() 165 | reward_dynamics_param = self._get_reward_dynamics_param() 166 | 167 | if use_approximator: 168 | if self.input_to_regressor == 'rew': 169 | action = self.rl_regressor.act(reward_param, time_step.observation) 170 | elif self.input_to_regressor == 'dyn': 171 | action = self.rl_regressor.act(dynamics_param, time_step.observation) 172 | elif self.input_to_regressor == 'rew_dyn': 173 | action = self.rl_regressor.act(reward_dynamics_param, time_step.observation) 174 | else: 175 | raise NotImplementedError 176 | else: 177 | action = self.agent.act(time_step.observation, self.step_to_load, eval_mode=True) 178 | observed = self.agent.observe(time_step.observation, action) 179 | episode_rollout['value'].append(observed['value']) 180 | 181 | # 
save the trajectory 182 | episode_rollout['reward_param'].append(reward_param) 183 | episode_rollout['dynamics_param'].append(dynamics_param) 184 | episode_rollout['state'].append(time_step.observation) 185 | episode_rollout['discount'].append(time_step.discount) 186 | episode_rollout['action'].append(action) 187 | episode_rollout['physics_qpos'].append(env._physics.data.qpos.copy()) 188 | episode_rollout['physics_qvel'].append(env._physics.data.qvel.copy()) 189 | 190 | time_step = env.step(action) 191 | self.video_recorder.record(env) 192 | 193 | episode_rollout['reward'].append([time_step.reward]) 194 | episode_rollout['next_state'].append(time_step.observation) 195 | 196 | if use_approximator: 197 | video_name = f'{self.base_name}_{self.rl_regressor_name}-approxseed-{self.rl_regressor_seed}.mp4' 198 | else: 199 | video_name = f'{self.base_name}.mp4' 200 | self.video_recorder.save(video_name) 201 | 202 | # concatenate across the current episode 203 | for k, v in episode_rollout.items(): 204 | rollout_data[k].append(np.stack(v)) 205 | 206 | # concatenate across all episodes 207 | for k, v in rollout_data.items(): 208 | rollout_data[k] = np.stack(v) 209 | 210 | return rollout_data 211 | 212 | def finetune_meta_policy(self, data): 213 | n_episodes = 10 214 | state = data['state'][0:n_episodes, :, :].reshape(n_episodes * 1000, -1) 215 | action = data['action'][0:n_episodes, :, :].reshape(n_episodes * 1000, -1) 216 | batch_size = state.shape[0] 217 | 218 | if self.input_to_regressor == 'rew': 219 | input_param = self._get_reward_param() 220 | elif self.input_to_regressor == 'dyn': 221 | input_param = self._get_dynamics_param() 222 | elif self.input_to_regressor == 'rew_dyn': 223 | input_param = self._get_reward_dynamics_param() 224 | else: 225 | raise NotImplementedError 226 | input_param = np.repeat(input_param, batch_size).reshape(batch_size, -1) 227 | 228 | self.rl_regressor.finetune(input_param, state, action) 229 | 230 | def save_rollout(self, rollout_data, dir, name=''): 231 | path = f"{dir}/{name}.npy" 232 | np.save(path, rollout_data, allow_pickle=True) 233 | 234 | def _get_reward_param(self): 235 | reward_param = [self.cfg.reward_parameters.ALL.margin] 236 | return reward_param 237 | 238 | def _get_reward_param_dim(self): 239 | reward_param_dim = 1 240 | return reward_param_dim 241 | 242 | def _get_dynamics_param(self): 243 | try: 244 | dynamics_param = [self.cfg.dynamics_parameters.length] 245 | except omegaconf.errors.ConfigAttributeError: 246 | dynamics_param = [0] 247 | return dynamics_param 248 | 249 | def _get_dynamics_param_dim(self): 250 | # for now only a single param is changed for all experiments 251 | dynamics_param_dim = 1 252 | return dynamics_param_dim 253 | 254 | def _get_reward_dynamics_param(self): 255 | try: 256 | dynamics_param = [self.cfg.reward_parameters.ALL.margin, 257 | self.cfg.dynamics_parameters.length] 258 | except omegaconf.errors.ConfigAttributeError: 259 | dynamics_param = [0, 0] 260 | return dynamics_param 261 | 262 | def _get_reward_dynamics_param_dim(self): 263 | reward_dynamics_dim = 2 264 | return reward_dynamics_dim 265 | 266 | 267 | def main(): 268 | parser = argparse.ArgumentParser() 269 | parser.add_argument('--workdir', type=str, default='results') 270 | parser.add_argument('--eval_mode', choices=['sl_data', 'comparison_data'], default='comparison_data') 271 | parser.add_argument('--rl_regressor_workdir', type=str, default=None) 272 | parser.add_argument('--step_to_load', type=int, default=0) 273 | parser.add_argument('--n_episodes', 
type=int, default=10) 274 | parser.add_argument('--vis', action='store_true', default=False) 275 | parser.add_argument('--rollout_dir', type=str, default='rollout_data') 276 | parser.add_argument('--video_dir', type=str, default='video_logs') 277 | args = parser.parse_args() 278 | 279 | cfg_path = Path(args.workdir).joinpath('cfg.yaml') 280 | cfg = OmegaConf.load(cfg_path) 281 | 282 | workspace = Workspace(cfg, args.workdir, args) 283 | 284 | # Generate the data used for supervised learning 285 | if args.eval_mode == 'sl_data': 286 | sl_rollout_fname = f"{workspace.base_name}_rollout" 287 | rl_rollout_data = workspace.rollout(n_episodes=args.n_episodes, use_approximator=False) 288 | workspace.save_rollout(rl_rollout_data, 289 | dir=workspace.rollout_dir, 290 | name=sl_rollout_fname) 291 | 292 | # Rollout RL agent and the approximator 293 | if args.eval_mode == 'comparison_data': 294 | assert workspace.rl_regressor is not None, "RL approximator is not loaded." 295 | 296 | # Rollout RL agent if not saved and not MAML 297 | agent_rollout_fname = f"{workspace.base_name}_rollout_{workspace.base_name}_rollout_agent" 298 | if not Path(f"{workspace.rollout_comparison_data}/{agent_rollout_fname}.npy").is_file() or workspace.is_meta_learning: 299 | rl_rollout_data = workspace.rollout(n_episodes=args.n_episodes, use_approximator=False) 300 | # workspace.save_rollout(rl_rollout_data, 301 | # dir=workspace.rollout_comparison_data, 302 | # name=agent_rollout_fname) 303 | else: 304 | print("Skipping rolling out the RL agent, because the data is already generated.") 305 | 306 | # Rollout the approximator if not saved 307 | approx_rollout_fname = f"{workspace.base_name}_rollout_approx-{workspace.rl_regressor_name}_approxseed-{workspace.rl_regressor_seed}" 308 | if not Path(f"{workspace.rollout_comparison_data}/{approx_rollout_fname}.npy").is_file(): 309 | approximator_rollout_data = workspace.rollout(n_episodes=args.n_episodes, use_approximator=True) 310 | workspace.save_rollout(approximator_rollout_data, 311 | dir=workspace.rollout_comparison_data, 312 | name=approx_rollout_fname) 313 | else: 314 | print(f"Skipping rolling out the {workspace.rl_regressor_name}, because the data is already generated.") 315 | 316 | # Finetune the meta policy with RL rollout and then evaluate it 317 | if workspace.is_meta_learning: 318 | finetuned_approx_rollout_fname = f"{workspace.base_name}_rollout_approx-finetuned_{workspace.rl_regressor_name}_approxseed-{workspace.rl_regressor_seed}" 319 | workspace.finetune_meta_policy(rl_rollout_data) 320 | finetuned_approximator_rollout_data = workspace.rollout(n_episodes=args.n_episodes, use_approximator=True) 321 | workspace.save_rollout(finetuned_approximator_rollout_data, 322 | dir=workspace.rollout_comparison_data, 323 | name=finetuned_approx_rollout_fname) 324 | 325 | # Visualization 326 | if args.vis: 327 | # phase-space plots of the RL agent's rollout 328 | plot_type = 'scatter' 329 | for z_data, label in zip([rl_rollout_data['value'], rl_rollout_data['reward']], ['V(s)', 'R(s)']): 330 | # visualization of the rollout, values/rewards in the actual MDP 331 | plots.visualize_phase_space( 332 | rl_rollout_data['physics_qpos'], 333 | rl_rollout_data['physics_qvel'], 334 | z_data, workspace.plot_dir, 335 | f"{workspace.base_name}_phase_{label}_{plot_type}", 336 | goal_coord=None, 337 | plot_type='scatter', label=label 338 | ) 339 | 340 | # visualization of the predicted rollout, values/rewards in the actual MDP 341 | # TODO 342 | 343 | 344 | if __name__ == '__main__': 345 | 
main() 346 | -------------------------------------------------------------------------------- /eval_many_agents.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for evaluating and rolling out several RL agents. 3 | The rollout dataset is saved for training RL approximators 4 | using supervised learning. 5 | """ 6 | import warnings 7 | warnings.filterwarnings('ignore', category=DeprecationWarning) 8 | 9 | import os 10 | import subprocess 11 | import argparse 12 | from pathlib import Path 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--rootdir', type=str, default="results") 17 | parser.add_argument('--domain_task', type=str, default='cheetah_run') 18 | parser.add_argument('--step_to_load', type=int, default=1000000) 19 | parser.add_argument('--n_episodes', type=int, default=10) 20 | parser.add_argument('--vis', action='store_true', default=False) 21 | parser.add_argument('--rollout_dir', type=str, default='rollout_data') 22 | parser.add_argument('--video_dir', type=str, default='video_logs') 23 | args = parser.parse_args() 24 | 25 | 26 | root_dir = Path(args.rootdir) 27 | paths = sorted(root_dir.glob(f'**/*{args.domain_task}*/**/step_*{args.step_to_load}')) 28 | 29 | seeds_list = [] 30 | 31 | for p in paths: 32 | workdir = p.parents[1] 33 | seed = int(p.parents[2].name.split('_')[4]) 34 | command = [ 35 | 'python', 36 | 'eval.py', 37 | '--workdir', 38 | str(workdir), 39 | '--step_to_load', 40 | str(args.step_to_load), 41 | '--n_episodes', 42 | str(args.n_episodes), 43 | '--rollout_dir', 44 | str(args.rollout_dir), 45 | '--video_dir', 46 | str(args.video_dir), 47 | '--rl_regressor_workdir', 48 | str(None), 49 | '--eval_mode', 50 | 'sl_data', 51 | ] 52 | if args.vis: 53 | command += ['--vis'] 54 | 55 | print(f"Running {command}") 56 | process = subprocess.run(command, capture_output=True) 57 | print(f"Returncode of the process: {process.returncode}") 58 | if process.returncode == 0: 59 | seeds_list.append(seed) 60 | else: 61 | print(process.stderr) 62 | -------------------------------------------------------------------------------- /eval_many_approximators.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for evaluating and rolling out several RL agents. 3 | The rollout dataset is saved for training RL approximators. 
4 | """ 5 | import warnings 6 | warnings.filterwarnings('ignore', category=DeprecationWarning) 7 | 8 | import os 9 | import subprocess 10 | import argparse 11 | from pathlib import Path 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--rootdir', type=str, default="results") 16 | parser.add_argument('--domain_task', type=str, default='cheetah_run') 17 | parser.add_argument('--approximator_rootdir', type=str, default="results_approximator") 18 | parser.add_argument('--step_to_load', type=int, default=1000000) 19 | parser.add_argument('--n_episodes', type=int, default=10) 20 | parser.add_argument('--random_rollout', action='store_true', default=False) 21 | parser.add_argument('--vis', action='store_true', default=False) 22 | parser.add_argument('--rollout_dir', type=str, default='rollout_data') 23 | parser.add_argument('--video_dir', type=str, default='video_logs') 24 | args = parser.parse_args() 25 | 26 | 27 | root_dir = Path(args.rootdir) 28 | approx_root_dir = Path(args.approximator_rootdir) 29 | paths = sorted(root_dir.glob(f'**/*{args.domain_task}*/**/step_*{args.step_to_load}')) 30 | approx_paths = sorted(approx_root_dir.glob(f'**/*{args.domain_task}*/**/*best_total')) 31 | 32 | seeds_list = [] 33 | 34 | for p in paths: 35 | # Skip evaluating zero margins 36 | if '-0.0' not in str(p): 37 | for approx_p in approx_paths: 38 | workdir = p.parents[1] 39 | approx_workdir = approx_p.parents[1] 40 | seed = int(p.parents[2].name.split('_')[4]) 41 | command = [ 42 | 'python', 43 | 'eval.py', 44 | '--workdir', 45 | str(workdir), 46 | '--step_to_load', 47 | str(args.step_to_load), 48 | '--rl_regressor_workdir', 49 | str(approx_workdir), 50 | '--n_episodes', 51 | str(args.n_episodes), 52 | '--rollout_dir', 53 | str(args.rollout_dir), 54 | '--video_dir', 55 | str(args.video_dir), 56 | '--eval_mode', 57 | 'comparison_data', 58 | ] 59 | if args.vis: 60 | command += ['--vis'] 61 | 62 | print(f"Running {command}") 63 | process = subprocess.run(command, capture_output=True) 64 | print(f"Returncode of the process: {process.returncode}") 65 | if process.returncode == 0: 66 | seeds_list.append(seed) 67 | else: 68 | print(process.stderr) 69 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAIC-MONTREAL/hyperzero/ab0508a73c09940d8c98267af8ae021d834915d0/models/__init__.py -------------------------------------------------------------------------------- /models/core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import utils.utils as utils 7 | 8 | 9 | def gaussian_logprob(noise, log_std): 10 | """ 11 | Compute Gaussian log probability. 12 | """ 13 | residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True) 14 | return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1) 15 | 16 | 17 | def squash(mu, pi, log_pi): 18 | """ 19 | Apply squashing function. 20 | """ 21 | mu = torch.tanh(mu) 22 | if pi is not None: 23 | pi = torch.tanh(pi) 24 | if log_pi is not None: 25 | log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) 26 | return mu, pi, log_pi 27 | 28 | 29 | class DeterministicActor(nn.Module): 30 | """ 31 | Original TD3 actor. 
32 | """ 33 | def __init__(self, feature_dim, action_dim, hidden_dim): 34 | super(DeterministicActor, self).__init__() 35 | 36 | self.policy = nn.Sequential( 37 | nn.Linear(feature_dim, hidden_dim), 38 | nn.ReLU(inplace=True), 39 | nn.Linear(hidden_dim, hidden_dim), 40 | nn.ReLU(inplace=True), 41 | nn.Linear(hidden_dim, action_dim) 42 | ) 43 | 44 | self.apply(utils.weight_init) 45 | 46 | def forward(self, state): 47 | a = self.policy(state) 48 | return torch.tanh(a) 49 | 50 | 51 | class Critic(nn.Module): 52 | """ 53 | Original TD3 critic. 54 | """ 55 | def __init__(self, feature_dim, action_dim, hidden_dim): 56 | super().__init__() 57 | 58 | # Q1 architecture 59 | self.Q1_net = nn.Sequential( 60 | nn.Linear(feature_dim + action_dim, hidden_dim), 61 | nn.ReLU(inplace=True), 62 | nn.Linear(hidden_dim, hidden_dim), 63 | nn.ReLU(inplace=True), 64 | nn.Linear(hidden_dim, 1) 65 | ) 66 | 67 | # Q2 architecture 68 | self.Q2_net = nn.Sequential( 69 | nn.Linear(feature_dim + action_dim, hidden_dim), 70 | nn.ReLU(inplace=True), 71 | nn.Linear(hidden_dim, hidden_dim), 72 | nn.ReLU(inplace=True), 73 | nn.Linear(hidden_dim, 1) 74 | ) 75 | 76 | self.apply(utils.weight_init) 77 | 78 | def forward(self, state, action): 79 | sa = torch.cat([state, action], 1) 80 | 81 | q1 = self.Q1_net(sa) 82 | q2 = self.Q2_net(sa) 83 | return q1, q2 84 | 85 | def Q1(self, state, action): 86 | sa = torch.cat([state, action], 1) 87 | 88 | q1 = self.Q1_net(sa) 89 | return q1 90 | -------------------------------------------------------------------------------- /models/hypenet_core.py: -------------------------------------------------------------------------------- 1 | """ 2 | HyperNetwork implementation is based on 3 | 4 | https://arxiv.org/abs/2106.06842 5 | https://github.com/keynans/HypeRL 6 | """ 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class ResBlock(nn.Module): 14 | """ 15 | Residual block used for learnable task embeddings. 16 | """ 17 | def __init__(self, in_size, out_size): 18 | super().__init__() 19 | self.fc = nn.Sequential( 20 | nn.ReLU(), 21 | nn.Linear(in_size, out_size), 22 | nn.ReLU(), 23 | nn.Linear(out_size, out_size), 24 | ) 25 | 26 | def forward(self, x): 27 | h = self.fc(x) 28 | return x + h 29 | 30 | 31 | class Head(nn.Module): 32 | """ 33 | Hypernetwork head for generating weights of a single layer of an MLP. 34 | """ 35 | def __init__(self, latent_dim, output_dim_in, output_dim_out, sttdev): 36 | super().__init__() 37 | 38 | self.output_dim_in = output_dim_in 39 | self.output_dim_out = output_dim_out 40 | 41 | self.W1 = nn.Linear(latent_dim, output_dim_in * output_dim_out) 42 | self.b1 = nn.Linear(latent_dim, output_dim_out) 43 | 44 | self.init_layers(sttdev) 45 | 46 | def forward(self, x): 47 | # weights, bias and scale for dynamic layer 48 | w = self.W1(x).view(-1, self.output_dim_out, self.output_dim_in) 49 | b = self.b1(x).view(-1, self.output_dim_out, 1) 50 | 51 | return w, b 52 | 53 | def init_layers(self, stddev): 54 | torch.nn.init.uniform_(self.W1.weight, -stddev, stddev) 55 | torch.nn.init.uniform_(self.b1.weight, -stddev, stddev) 56 | 57 | torch.nn.init.zeros_(self.W1.bias) 58 | torch.nn.init.zeros_(self.b1.bias) 59 | 60 | 61 | class Meta_Embadding(nn.Module): 62 | """ 63 | Hypernetwork meta embedding. 
64 | """ 65 | def __init__(self, meta_dim, z_dim): 66 | super().__init__() 67 | 68 | self.z_dim = z_dim 69 | self.hyper = nn.Sequential( 70 | nn.Linear(meta_dim, z_dim // 4), 71 | ResBlock(z_dim // 4, z_dim // 4), 72 | ResBlock(z_dim // 4, z_dim // 4), 73 | 74 | nn.Linear(z_dim // 4, z_dim // 2), 75 | ResBlock(z_dim // 2, z_dim // 2), 76 | ResBlock(z_dim // 2, z_dim // 2), 77 | 78 | nn.Linear(z_dim // 2, z_dim), 79 | ResBlock(z_dim, z_dim), 80 | ResBlock(z_dim, z_dim), 81 | ) 82 | 83 | self.init_layers() 84 | 85 | def forward(self, meta_v): 86 | z = self.hyper(meta_v).view(-1, self.z_dim) 87 | return z 88 | 89 | def init_layers(self): 90 | for module in self.hyper.modules(): 91 | if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)): 92 | fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(module.weight) 93 | bound = 1. / (2. * math.sqrt(fan_in)) 94 | torch.nn.init.uniform_(module.weight, -bound, bound) 95 | 96 | 97 | class HyperNetwork(nn.Module): 98 | """ 99 | A hypernetwork that creates another neural network of 100 | base_v_input_dim -> base_v_output_dim using z_dim. 101 | """ 102 | def __init__(self, meta_v_dim, z_dim, base_v_input_dim, base_v_output_dim, 103 | dynamic_layer_dim, base_output_activation=None): 104 | super().__init__() 105 | 106 | self.base_output_activation = base_output_activation 107 | self.hyper = Meta_Embadding(meta_v_dim, z_dim) 108 | 109 | # main network 110 | self.layer1 = Head(z_dim, base_v_input_dim, dynamic_layer_dim, sttdev=0.05) 111 | self.last_layer = Head(z_dim, dynamic_layer_dim, base_v_output_dim, sttdev=0.008) 112 | 113 | def forward(self, meta_v, base_v): 114 | # produce dynamic weights 115 | z = self.hyper(meta_v) 116 | w1, b1 = self.layer1(z) 117 | w2, b2 = self.last_layer(z) 118 | 119 | # dynamic network pass 120 | out = F.relu(torch.bmm(w1, base_v.unsqueeze(2)) + b1) 121 | out = torch.bmm(w2, out) + b2 122 | if self.base_output_activation is not None: 123 | out = self.base_output_activation(out) 124 | 125 | batch_size = out.shape[0] 126 | return z, out.view(batch_size, -1) 127 | 128 | 129 | class DoubleHeadedHyperNetwork(nn.Module): 130 | """ 131 | A hypernetwork that creates two neural networks of 132 | base_v_input_dim[i] -> base_v_output_dim[i] using z_dim. 
133 | """ 134 | def __init__(self, meta_v_dim, z_dim, base_v_input_dim, base_v_output_dim, 135 | dynamic_layer_dim, base_output_activation=None): 136 | super().__init__() 137 | assert isinstance(base_v_input_dim, list) 138 | assert isinstance(base_v_output_dim, list) 139 | assert isinstance(base_output_activation, list) or base_output_activation is None 140 | 141 | self.base_output_activation = base_output_activation 142 | self.hyper = Meta_Embadding(meta_v_dim, z_dim) 143 | 144 | # main networks 145 | self.layer1_1 = Head(z_dim, base_v_input_dim[0], dynamic_layer_dim, sttdev=0.05) 146 | self.last_layer_1 = Head(z_dim, dynamic_layer_dim, base_v_output_dim[0], sttdev=0.008) 147 | 148 | self.layer1_2 = Head(z_dim, base_v_input_dim[1], dynamic_layer_dim, sttdev=0.05) 149 | self.last_layer_2 = Head(z_dim, dynamic_layer_dim, base_v_output_dim[1], sttdev=0.008) 150 | 151 | def forward(self, meta_v, base_v_1, base_v_2): 152 | z = self.hyper(meta_v) 153 | out_1 = self.forward_net_1(z, base_v_1) 154 | out_2 = self.forward_net_2(z, base_v_2) 155 | return z, out_1, out_2 156 | 157 | def embed(self, meta_v): 158 | z = self.hyper(meta_v) 159 | return z 160 | 161 | def forward_net_1(self, z, base_v_1): 162 | # produce dynamic weights for network #1 163 | w1_1, b1_1 = self.layer1_1(z) 164 | w2_1, b2_1 = self.last_layer_1(z) 165 | 166 | # dynamic network 1 pass 167 | out_1 = F.relu(torch.bmm(w1_1, base_v_1.unsqueeze(2)) + b1_1) 168 | out_1 = torch.bmm(w2_1, out_1) + b2_1 169 | if self.base_output_activation[0] is not None: 170 | out_1 = self.base_output_activation[0](out_1) 171 | 172 | batch_size = out_1.shape[0] 173 | return out_1.view(batch_size, -1) 174 | 175 | def forward_net_2(self, z, base_v_2): 176 | # produce dynamic weights for network #2 177 | w1_2, b1_2 = self.layer1_2(z) 178 | w2_2, b2_2 = self.last_layer_2(z) 179 | 180 | # dynamic network 2 pass 181 | out_2 = F.relu(torch.bmm(w1_2, base_v_2.unsqueeze(2)) + b1_2) 182 | out_2 = torch.bmm(w2_2, out_2) + b2_2 183 | if self.base_output_activation[1] is not None: 184 | out_2 = self.base_output_activation[1](out_2) 185 | 186 | batch_size = out_2.shape[0] 187 | return out_2.view(batch_size, -1) 188 | -------------------------------------------------------------------------------- /models/rl_regressor.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.nn as nn 4 | 5 | import utils.utils as utils 6 | from models.hypenet_core import HyperNetwork, DoubleHeadedHyperNetwork, Meta_Embadding 7 | 8 | 9 | class HyperPolicy(nn.Module): 10 | """ 11 | Approximates the mapping R(\phi) -> \pi^* (a|s) 12 | """ 13 | def __init__(self, input_param_dim, state_dim, action_dim, embed_dim, hidden_dim): 14 | super().__init__() 15 | 16 | self.hyper_policy = HyperNetwork( 17 | meta_v_dim=input_param_dim, 18 | z_dim=embed_dim, 19 | base_v_input_dim=state_dim, 20 | base_v_output_dim=action_dim, 21 | dynamic_layer_dim=hidden_dim, 22 | base_output_activation=torch.tanh 23 | ) 24 | 25 | def forward(self, input_param, state): 26 | z, action = self.hyper_policy(input_param, state) 27 | return action 28 | 29 | 30 | class HyperRLSolution(nn.Module): 31 | """ 32 | Baseline. 
Approximates the mapping R(\phi) -> Q^*(s, a), \pi^*(s) 33 | """ 34 | def __init__(self, input_param_dim, state_dim, action_dim, embed_dim, hidden_dim): 35 | super().__init__() 36 | 37 | self.hyper_rl_net = DoubleHeadedHyperNetwork( 38 | meta_v_dim=input_param_dim, 39 | z_dim=embed_dim, 40 | base_v_input_dim=[state_dim, state_dim + action_dim], 41 | base_v_output_dim=[action_dim, 1], 42 | dynamic_layer_dim=hidden_dim, 43 | base_output_activation=[torch.tanh, None] 44 | ) 45 | 46 | def forward(self, input_param, state, action): 47 | state_action = torch.cat([state, action], dim=-1) 48 | z, pred_action, q_value = self.hyper_rl_net(input_param, state, state_action) 49 | return z, pred_action, q_value 50 | 51 | def embed_task(self, input_param): 52 | z = self.hyper_rl_net.embed(input_param) 53 | return z 54 | 55 | def predict_action(self, z, state): 56 | pred_action = self.hyper_rl_net.forward_net_1(z, state) 57 | return pred_action 58 | 59 | def predict_q_value(self, z, state, action): 60 | state_action = torch.cat([state, action], dim=-1) 61 | q_value = self.hyper_rl_net.forward_net_2(z, state_action) 62 | return q_value 63 | 64 | 65 | class MLPActionPredictor(nn.Module): 66 | """ 67 | Baseline. Approximates the mapping R(\phi) -> \pi^*(s) 68 | """ 69 | def __init__(self, input_param_dim, state_dim, action_dim, embed_dim, hidden_dim): 70 | super().__init__() 71 | 72 | self.embedding = Meta_Embadding(input_param_dim, embed_dim) 73 | 74 | self.policy_net = nn.Sequential( 75 | nn.Linear(embed_dim + state_dim, hidden_dim), 76 | nn.ReLU(inplace=True), 77 | nn.Linear(hidden_dim, action_dim) 78 | ) 79 | self.apply(utils.weight_init) 80 | 81 | def forward(self, input_param, state): 82 | emb = self.embedding(input_param) 83 | emb_states = torch.cat([emb, state], dim=-1) 84 | action = self.policy_net(emb_states) 85 | return torch.tanh(action) 86 | 87 | 88 | class MLPRLSolution(nn.Module): 89 | """ 90 | Baseline. Approximates the mapping R(\phi) -> Q^*(s, a), \pi^*(s) 91 | """ 92 | def __init__(self, input_param_dim, state_dim, action_dim, embed_dim, hidden_dim): 93 | super().__init__() 94 | 95 | self.embedding = Meta_Embadding(input_param_dim, embed_dim) 96 | 97 | self.q_net = nn.Sequential( 98 | nn.Linear(embed_dim + state_dim + action_dim, hidden_dim), 99 | nn.ReLU(inplace=True), 100 | nn.Linear(hidden_dim, 1) 101 | ) 102 | self.policy_net = nn.Sequential( 103 | nn.Linear(embed_dim + state_dim, hidden_dim), 104 | nn.ReLU(inplace=True), 105 | nn.Linear(hidden_dim, action_dim), 106 | ) 107 | self.apply(utils.weight_init) 108 | 109 | def forward(self, input_param, state, action): 110 | task = self.embed_task(input_param) 111 | pred_action = self.predict_action(task, state) 112 | q_value = self.predict_q_value(task, state, action) 113 | return task, pred_action, q_value 114 | 115 | def embed_task(self, input_param): 116 | task = self.embedding(input_param) 117 | return task 118 | 119 | def predict_action(self, task, state): 120 | task_state = torch.cat([task, state], dim=-1) 121 | pred_action = self.policy_net(task_state) 122 | return torch.tanh(pred_action) 123 | 124 | def predict_q_value(self, task, state, action): 125 | task_state_action = torch.cat([task, state, action], dim=-1) 126 | q_value = self.q_net(task_state_action) 127 | return q_value 128 | 129 | 130 | class MLPContextEncoder(nn.Module): 131 | """ 132 | Context encoder of PEARL. 
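    Encodes (state, action) pairs into a Gaussian latent: encode() produces mu
    and log_var, and forward() samples z = mu + exp(0.5 * log_var) * eps with
    eps ~ N(0, I) via the reparameterization trick. In the PEARL-style configs
    (cfgs/approximator/pearl_policy.yaml) this latent is presumably regularized
    toward a prior with the kl_lambda weight.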
133 | """ 134 | def __init__(self, state_dim, action_dim, embed_dim, hidden_dim): 135 | super().__init__() 136 | 137 | self.fc = nn.Sequential( 138 | nn.Linear(state_dim + action_dim, hidden_dim), 139 | nn.ReLU(inplace=True), 140 | nn.Linear(hidden_dim, hidden_dim), 141 | nn.ReLU(inplace=True), 142 | ) 143 | self.mu = nn.Linear(hidden_dim, embed_dim) 144 | self.log_var = nn.Linear(hidden_dim, embed_dim) 145 | 146 | self.apply(utils.weight_init) 147 | 148 | def encode(self, state, action): 149 | state_action = torch.cat((state, action), dim=1) 150 | z = self.fc(state_action) 151 | return self.mu(z), self.log_var(z) 152 | 153 | def reparameterize(self, mu, log_var): 154 | std = torch.exp(0.5 * log_var) 155 | eps = torch.randn_like(std) 156 | return eps * std + mu 157 | 158 | def forward(self, state, action): 159 | mu, log_var = self.encode(state, action) 160 | z = self.reparameterize(mu, log_var) 161 | return z, mu, log_var 162 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | torch==1.12.0 3 | torchvision==0.13.0 4 | learn2learn==0.1.7 5 | termcolor==1.1.0 6 | dm_control 7 | dm_env 8 | imageio==2.9.0 9 | hydra-core==1.2.0 10 | hydra-submitit-launcher==1.1.5 11 | pandas==1.3.0 12 | ipdb==0.13.9 13 | yapf==0.31.0 14 | sklearn==0.0 15 | matplotlib 16 | opencv-python==4.5.3.56 17 | gitpython 18 | matplotlib==3.5.2 19 | protobuf==3.19.4 20 | imageio==2.9.0 21 | imageio-ffmpeg==0.4.4 22 | pandas==1.3.0 23 | numpy==1.24.1 24 | seaborn==0.11.2 25 | hypnettorch==0.0.4 26 | gym==0.23.1 27 | gym-notices==0.0.06 28 | setuptools==59.5.0 29 | tensorboard -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Standard RL training loop. 3 | Works on DMC with states and pixel observations. 
4 | """ 5 | 6 | import warnings 7 | 8 | warnings.filterwarnings('ignore', category=DeprecationWarning) 9 | 10 | import os 11 | import platform 12 | import logging 13 | 14 | if platform.system() == 'Linux': 15 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 16 | os.environ['MUJOCO_GL'] = 'egl' 17 | 18 | from pathlib import Path 19 | 20 | import hydra 21 | import omegaconf 22 | import numpy as np 23 | import torch 24 | from dm_env import specs 25 | from hydra.core.hydra_config import HydraConfig 26 | from omegaconf import OmegaConf 27 | 28 | import utils.dmc as dmc 29 | import utils.utils as utils 30 | from utils.logger import Logger 31 | from utils.replay_buffer import ReplayBufferStorage, make_replay_loader 32 | from utils.video import TrainVideoRecorder, VideoRecorder 33 | 34 | torch.backends.cudnn.benchmark = True 35 | 36 | # If using multirun, set the GPUs here: 37 | AVAILABLE_GPUS = [1, 2, 3, 4, 0] 38 | 39 | 40 | def make_agent(obs_spec, action_spec, cfg, device=None): 41 | cfg.obs_shape = obs_spec.shape 42 | cfg.action_shape = action_spec.shape 43 | if device is not None: 44 | cfg.device = device 45 | return hydra.utils.instantiate(cfg) 46 | 47 | 48 | class Workspace: 49 | def __init__(self, cfg): 50 | self.work_dir = Path.cwd() 51 | 52 | self.cfg = cfg 53 | utils.set_seed_everywhere(cfg.seed) 54 | self.device = torch.device(cfg.device) 55 | self.setup() 56 | 57 | self.agent = make_agent(self.train_env.observation_spec(), 58 | self.train_env.action_spec(), 59 | self.cfg.agent) 60 | self.timer = utils.Timer() 61 | self._global_step = 0 62 | self._global_episode = 0 63 | 64 | def setup(self): 65 | # some assertions 66 | utils.assert_agent(self.cfg['agent_name'], self.cfg['pixel_obs']) 67 | 68 | # create logger 69 | self.logger = Logger(self.work_dir) 70 | 71 | # get the reward parameters 72 | reward_parameters = OmegaConf.to_container(self.cfg.reward_parameters) 73 | 74 | # get the dynamics parameters 75 | dynamics_parameters = OmegaConf.to_container(self.cfg.dynamics_parameters) 76 | 77 | # create envs 78 | self.train_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack, 79 | self.cfg.action_repeat, reward_parameters, 80 | dynamics_parameters, self.cfg.seed, self.cfg.pixel_obs) 81 | self.eval_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack, 82 | self.cfg.action_repeat, reward_parameters, 83 | dynamics_parameters, self.cfg.seed, self.cfg.pixel_obs) 84 | # create replay buffer 85 | data_specs = (self.train_env.observation_spec(), 86 | self.train_env.action_spec(), 87 | specs.Array((1,), np.float32, 'reward'), 88 | specs.Array((1,), np.float32, 'discount')) 89 | 90 | self.replay_storage = ReplayBufferStorage(data_specs, 91 | self.work_dir / 'buffer') 92 | 93 | self.replay_loader = make_replay_loader( 94 | self.work_dir / 'buffer', self.cfg.replay_buffer_size, 95 | self.cfg.batch_size, self.cfg.replay_buffer_num_workers, 96 | self.cfg.save_snapshot, self.cfg.nstep, self.cfg.discount) 97 | self._replay_iter = None 98 | 99 | self.video_recorder = VideoRecorder( 100 | self.work_dir if self.cfg.save_video else None, 101 | fps=60 // self.cfg.action_repeat 102 | ) 103 | self.train_video_recorder = TrainVideoRecorder( 104 | self.work_dir if self.cfg.save_train_video else None, 105 | fps=60 // self.cfg.action_repeat 106 | ) 107 | 108 | self.plot_dir = self.work_dir / 'plots' 109 | self.plot_dir.mkdir(exist_ok=True) 110 | self.model_dir = self.work_dir / 'models' 111 | self.model_dir.mkdir(exist_ok=True) 112 | 113 | # save cfg and git sha 114 | utils.save_cfg(self.cfg, self.work_dir) 115 | 
utils.save_git_sha(self.work_dir) 116 | 117 | @property 118 | def global_step(self): 119 | return self._global_step 120 | 121 | @property 122 | def global_episode(self): 123 | return self._global_episode 124 | 125 | @property 126 | def global_frame(self): 127 | return self.global_step * self.cfg.action_repeat 128 | 129 | @property 130 | def replay_iter(self): 131 | if self._replay_iter is None: 132 | self._replay_iter = iter(self.replay_loader) 133 | return self._replay_iter 134 | 135 | def eval(self): 136 | step, episode, total_reward = 0, 0, 0 137 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes) 138 | 139 | while eval_until_episode(episode): 140 | time_step = self.eval_env.reset() 141 | self.video_recorder.init(self.eval_env, enabled=episode == 0) 142 | while not time_step.last(): 143 | with torch.no_grad(), utils.eval_mode(self.agent): 144 | action = self.agent.act(time_step.observation, 145 | self.global_step, 146 | eval_mode=True) 147 | time_step = self.eval_env.step(action) 148 | self.video_recorder.record(self.eval_env) 149 | total_reward += time_step.reward 150 | step += 1 151 | 152 | episode += 1 153 | self.video_recorder.save(f'{self.global_frame}_{episode}.mp4') 154 | 155 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log: 156 | log('episode_reward', total_reward / episode) 157 | log('episode_length', step * self.cfg.action_repeat / episode) 158 | log('episode', self.global_episode) 159 | log('step', self.global_step) 160 | 161 | def train(self, task_id=1): 162 | # predicates 163 | train_until_step = utils.Until(self.cfg.num_train_frames * task_id, 164 | self.cfg.action_repeat) 165 | seed_until_step = utils.Until(self.cfg.num_seed_frames + self.cfg.num_train_frames * (task_id - 1), 166 | self.cfg.action_repeat) 167 | eval_every_step = utils.Every(self.cfg.eval_every_frames, 168 | self.cfg.action_repeat) 169 | plot_every_step = utils.Every(self.cfg.plot_every_frames, 170 | self.cfg.action_repeat) 171 | save_every_step = utils.Every(self.cfg.save_every_frames, 172 | self.cfg.action_repeat) 173 | 174 | episode_step, episode_reward = 0, 0 175 | time_step = self.train_env.reset() 176 | self.replay_storage.add(time_step) 177 | self.train_video_recorder.init(time_step.observation) 178 | metrics = None 179 | while train_until_step(self.global_step): 180 | if time_step.last(): 181 | self._global_episode += 1 182 | self.train_video_recorder.save(f'{self.global_frame}.mp4') 183 | # wait until all the metrics schema is populated 184 | if metrics is not None: 185 | # log stats 186 | elapsed_time, total_time = self.timer.reset() 187 | episode_frame = episode_step * self.cfg.action_repeat 188 | with self.logger.log_and_dump_ctx(self.global_frame, 189 | ty='train') as log: 190 | log('fps', episode_frame / elapsed_time) 191 | log('total_time', total_time) 192 | log('episode_reward', episode_reward) 193 | log('episode_length', episode_frame) 194 | log('episode', self.global_episode) 195 | log('buffer_size', len(self.replay_storage)) 196 | log('step', self.global_step) 197 | 198 | # reset env 199 | time_step = self.train_env.reset() 200 | self.replay_storage.add(time_step) 201 | self.train_video_recorder.init(time_step.observation) 202 | episode_step = 0 203 | episode_reward = 0 204 | 205 | # try to evaluate 206 | if eval_every_step(self.global_step): 207 | self.logger.log('eval_total_time', self.timer.total_time(), self.global_frame) 208 | self.eval() 209 | 210 | if save_every_step(self.global_step): 211 | self.agent.save(self.model_dir, self.global_frame) 212 
| 213 | # try to save snapshot 214 | if self.cfg.save_snapshot: 215 | self.save_snapshot() 216 | 217 | # sample action 218 | with torch.no_grad(), utils.eval_mode(self.agent): 219 | action = self.agent.act(time_step.observation, 220 | self.global_step, 221 | eval_mode=False) 222 | 223 | # try to update the agent 224 | if not seed_until_step(self.global_step): 225 | metrics = self.agent.update(self.replay_iter, self.global_step) 226 | self.logger.log_metrics(metrics, self.global_frame, ty='train') 227 | 228 | # take env step 229 | time_step = self.train_env.step(action) 230 | episode_reward += time_step.reward 231 | self.replay_storage.add(time_step) 232 | self.train_video_recorder.record(time_step.observation) 233 | episode_step += 1 234 | self._global_step += 1 235 | 236 | def save_snapshot(self): 237 | snapshot = self.work_dir / 'snapshot.pt' 238 | keys_to_save = ['agent', 'timer', '_global_step', '_global_episode'] 239 | payload = {k: self.__dict__[k] for k in keys_to_save} 240 | with snapshot.open('wb') as f: 241 | torch.save(payload, f) 242 | 243 | def load_snapshot(self): 244 | snapshot = self.work_dir / 'snapshot.pt' 245 | with snapshot.open('rb') as f: 246 | payload = torch.load(f) 247 | for k, v in payload.items(): 248 | self.__dict__[k] = v 249 | 250 | 251 | @hydra.main(version_base=None, config_path='cfgs', config_name='config') 252 | def main(cfg): 253 | log = logging.getLogger(__name__) 254 | try: 255 | device_id = AVAILABLE_GPUS[HydraConfig.get().job.num % len(AVAILABLE_GPUS)] 256 | cfg.device = f"{cfg.device}:{device_id}" 257 | log.info(f"Total number of GPUs is {AVAILABLE_GPUS}, running on {cfg.device}.") 258 | except omegaconf.errors.MissingMandatoryValue: 259 | pass 260 | 261 | root_dir = Path.cwd() 262 | workspace = Workspace(cfg) 263 | snapshot = root_dir / 'snapshot.pt' 264 | if snapshot.exists(): 265 | print(f'resuming: {snapshot}') 266 | workspace.load_snapshot() 267 | workspace.train() 268 | 269 | 270 | if __name__ == '__main__': 271 | main() 272 | -------------------------------------------------------------------------------- /train_rl_regressor.py: -------------------------------------------------------------------------------- 1 | """ 2 | RL approximator training loop. 3 | Works on DMC with states and pixel observations. 
4 | """ 5 | 6 | import warnings 7 | 8 | warnings.filterwarnings('ignore', category=DeprecationWarning) 9 | 10 | import os 11 | import platform 12 | import logging 13 | import math 14 | 15 | if platform.system() == 'Linux': 16 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 17 | os.environ['MUJOCO_GL'] = 'egl' 18 | 19 | from pathlib import Path 20 | 21 | import hydra 22 | import omegaconf 23 | import torch 24 | from hydra.core.hydra_config import HydraConfig 25 | from torch.utils.tensorboard import SummaryWriter 26 | 27 | import utils.utils as utils 28 | from utils.dataset import RLSolutionDataset, RLSolutionMetaDataset 29 | from utils.dataloader import FastTensorDataLoader, FastTensorMetaDataLoader 30 | 31 | torch.backends.cudnn.benchmark = True 32 | 33 | # If using multirun, set the GPUs here: 34 | AVAILABLE_GPUS = [1, 2, 3, 4, 0] 35 | 36 | 37 | def make_approximator(input_dim, state_dim, action_dim, cfg, device=None): 38 | cfg.input_dim = input_dim 39 | cfg.state_dim = state_dim 40 | cfg.action_dim = action_dim 41 | if device is not None: 42 | cfg.device = device 43 | return hydra.utils.instantiate(cfg) 44 | 45 | 46 | class Workspace: 47 | def __init__(self, cfg): 48 | self.work_dir = Path.cwd() 49 | 50 | self.cfg = cfg 51 | utils.set_seed_everywhere(cfg.seed) 52 | self.device = torch.device(cfg.device) 53 | 54 | # hacked up way to see if we are using MAML or not 55 | self.is_meta_learning = True if 'meta' in self.cfg.approximator_name else False 56 | 57 | self.setup() 58 | 59 | if cfg.input_to_model == 'rew': 60 | input_dim = self.dataset.reward_param_dim 61 | elif cfg.input_to_model == 'dyn': 62 | input_dim = self.dataset.dynamic_param_dim 63 | elif cfg.input_to_model == 'rew_dyn': 64 | input_dim = self.dataset.reward_dynamic_param_dim 65 | else: 66 | raise NotImplementedError 67 | 68 | self.approximator = make_approximator(input_dim, 69 | self.dataset.state_dim, 70 | self.dataset.action_dim, 71 | self.cfg.approximator) 72 | self.timer = utils.Timer() 73 | self._global_epoch = 0 74 | self._global_episode = 0 75 | 76 | def setup(self): 77 | # create logger 78 | self.logger = SummaryWriter(str(self.work_dir)) 79 | 80 | self.model_dir = self.work_dir / 'models' 81 | self.model_dir.mkdir(exist_ok=True) 82 | 83 | self.rollout_dir = Path(self.cfg.rollout_dir).expanduser().joinpath(self.cfg.domain_task) 84 | 85 | # load dataset 86 | self.load_dataset() 87 | 88 | # save cfg and git sha 89 | utils.save_cfg(self.cfg, self.work_dir) 90 | utils.save_git_sha(self.work_dir) 91 | 92 | def load_dataset(self): 93 | dataset_fn = RLSolutionMetaDataset if self.is_meta_learning else RLSolutionDataset 94 | dataloader_fn = FastTensorMetaDataLoader if self.is_meta_learning else FastTensorDataLoader 95 | 96 | self.dataset = dataset_fn( 97 | self.rollout_dir, 98 | self.cfg.domain_task, 99 | self.cfg.input_to_model, 100 | self.cfg.seed, 101 | self.device, 102 | ) 103 | 104 | if self.is_meta_learning: 105 | batch_size = int(self.dataset.n_tasks * self.cfg.k_shot * 2) 106 | else: 107 | batch_size = self.cfg.batch_size 108 | 109 | self.train_loader = dataloader_fn(*self.dataset.train_dataset[:], device=self.device, 110 | batch_size=batch_size, shuffle=True) 111 | self.test_loader = dataloader_fn(*self.dataset.test_dataset[:], device=self.device, 112 | batch_size=batch_size, shuffle=True) 113 | 114 | @property 115 | def global_epoch(self): 116 | return self._global_epoch 117 | 118 | def train(self): 119 | # predicates 120 | train_until_epoch = utils.Until(self.cfg.num_train_epochs) 121 | save_every_epoch = 
utils.Every(self.cfg.save_every_frames) 122 | 123 | metrics = dict() 124 | best_valid_total_loss = math.inf 125 | best_valid_value_loss = math.inf 126 | best_valid_action_loss = math.inf 127 | best_valid_td_loss = math.inf 128 | 129 | while train_until_epoch(self.global_epoch): 130 | metrics.update() 131 | 132 | if self.is_meta_learning: 133 | self.train_loader.shuffle_indices() 134 | self.test_loader.shuffle_indices() 135 | 136 | metrics.update(self.approximator.update(self.train_loader)) 137 | metrics.update(self.approximator.eval(self.test_loader)) 138 | 139 | # Log metrics 140 | print(f"Epoch {self.global_epoch + 1} " 141 | f"\t Train loss {metrics['train/loss_total']:.3f} " 142 | f"\t Valid loss {metrics['valid/loss_total']:.3f}") 143 | for k, v in metrics.items(): 144 | self.logger.add_scalar(k, v, self.global_epoch + 1) 145 | utils.dump_dict(f"{self.work_dir}/train_valid.csv", metrics) 146 | 147 | # Save the model 148 | if metrics['valid/loss_total'] <= best_valid_total_loss: 149 | best_valid_total_loss = metrics['valid/loss_total'] 150 | self.approximator.save(self.model_dir, 'best_total') 151 | if metrics['valid/loss_action_pred'] <= best_valid_action_loss: 152 | best_valid_action_loss = metrics['valid/loss_action_pred'] 153 | self.approximator.save(self.model_dir, 'best_action') 154 | if 'valid/loss_value_pred' in metrics: 155 | if metrics['valid/loss_value_pred'] <= best_valid_value_loss: 156 | best_valid_value_loss = metrics['valid/loss_value_pred'] 157 | self.approximator.save(self.model_dir, 'best_value') 158 | if 'valid/loss_td' in metrics: 159 | if metrics['valid/loss_td'] <= best_valid_td_loss: 160 | best_valid_td_loss = metrics['valid/loss_td'] 161 | self.approximator.save(self.model_dir, 'best_td') 162 | 163 | if save_every_epoch(self.global_epoch + 1): 164 | self.approximator.save(self.model_dir, self.global_epoch + 1) 165 | 166 | self._global_epoch += 1 167 | 168 | def save_snapshot(self): 169 | snapshot = self.work_dir / 'snapshot.pt' 170 | keys_to_save = ['agent', 'timer', '_global_step', '_global_episode'] 171 | payload = {k: self.__dict__[k] for k in keys_to_save} 172 | with snapshot.open('wb') as f: 173 | torch.save(payload, f) 174 | 175 | def load_snapshot(self): 176 | snapshot = self.work_dir / 'snapshot.pt' 177 | with snapshot.open('rb') as f: 178 | payload = torch.load(f) 179 | for k, v in payload.items(): 180 | self.__dict__[k] = v 181 | 182 | 183 | @hydra.main(version_base=None, config_path='cfgs', config_name='config_rl_approximator') 184 | def main(cfg): 185 | log = logging.getLogger(__name__) 186 | try: 187 | device_id = AVAILABLE_GPUS[HydraConfig.get().job.num % len(AVAILABLE_GPUS)] 188 | cfg.device = f"{cfg.device}:{device_id}" 189 | log.info(f"Total number of GPUs is {AVAILABLE_GPUS}, running on {cfg.device}.") 190 | except omegaconf.errors.MissingMandatoryValue: 191 | pass 192 | 193 | root_dir = Path.cwd() 194 | workspace = Workspace(cfg) 195 | snapshot = root_dir / 'snapshot.pt' 196 | if snapshot.exists(): 197 | print(f'resuming: {snapshot}') 198 | workspace.load_snapshot() 199 | workspace.train() 200 | 201 | 202 | if __name__ == '__main__': 203 | main() 204 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAIC-MONTREAL/hyperzero/ab0508a73c09940d8c98267af8ae021d834915d0/utils/__init__.py -------------------------------------------------------------------------------- 
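A note on the snapshot helpers in train_rl_regressor.py above: save_snapshot() / load_snapshot() appear to have been carried over verbatim from train.py and reference 'agent' and '_global_step', attributes this Workspace never defines (it tracks 'approximator' and '_global_epoch' instead), so calling save_snapshot() would raise a KeyError. Below is a minimal sketch of a variant consistent with this class's attributes; it reuses the file's existing torch import and is an illustration, not code from the repository.

def save_snapshot(self):
    # Save only attributes that actually exist on this Workspace:
    # the trained approximator, the timer, and the epoch/episode counters.
    snapshot = self.work_dir / 'snapshot.pt'
    keys_to_save = ['approximator', 'timer', '_global_epoch', '_global_episode']
    payload = {k: self.__dict__[k] for k in keys_to_save}
    with snapshot.open('wb') as f:
        torch.save(payload, f)

The existing load_snapshot() would need no change, since it simply writes whatever keys the payload contains back onto the instance.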
/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FastTensorDataLoader: 5 | """ 6 | A DataLoader-like object for a set of tensors that can be much faster than 7 | TensorDataset + DataLoader because dataloader grabs individual indices of 8 | the dataset and calls cat (slow). 9 | Code based on: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/5 10 | """ 11 | def __init__(self, *tensors, device='cpu', batch_size=256, shuffle=False): 12 | """ 13 | Initialize a FastTensorDataLoader. 14 | Each tensor is in the form of (N, D) with N the number 15 | of datapoints and D the dimension of the data. 16 | 17 | :param *tensors: tensors to store. Must have the same length @ dim 0. 18 | :param batch_size: batch size to load. 19 | :param shuffle: if True, shuffle the data *in-place* whenever an 20 | iterator is created out of this object. 21 | 22 | :returns: A FastTensorDataLoader. 23 | """ 24 | assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) 25 | self.tensors = tensors 26 | self.device = device 27 | 28 | self.dataset_len = self.tensors[0].shape[0] 29 | self.batch_size = batch_size 30 | self.shuffle = shuffle 31 | 32 | # Calculate # batches 33 | n_batches, remainder = divmod(self.dataset_len, self.batch_size) 34 | if remainder > 0: 35 | n_batches += 1 36 | self.n_batches = n_batches 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | self.indices = torch.randperm(self.dataset_len).to(self.device) 41 | else: 42 | self.indices = None 43 | self.i = 0 44 | return self 45 | 46 | def __next__(self): 47 | if self.i >= self.dataset_len: 48 | raise StopIteration 49 | if self.indices is not None: 50 | indices = self.indices[self.i:self.i+self.batch_size] 51 | batch = tuple(torch.index_select(t, 0, indices) for t in self.tensors) 52 | else: 53 | batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors) 54 | self.i += self.batch_size 55 | return batch 56 | 57 | def __len__(self): 58 | return self.n_batches 59 | 60 | 61 | class FastTensorMetaDataLoader: 62 | """ 63 | Fast tensor dataloader for meta learning. 64 | """ 65 | def __init__(self, *tensors, device='cpu', batch_size=256, shuffle=False): 66 | """ 67 | Initialize a FastTensorDataLoader. 68 | Each tensor is in the form of (T, N, D) with T the number of tasks, 69 | N the number of datapoints and D the dimension of the data. 70 | 71 | :param *tensors: tensors to store. Must have the same length @ dim 0 and @ dim 1. 72 | :param batch_size: batch size to load. 73 | :param shuffle: if True, shuffle the data *in-place* whenever an 74 | iterator is created out of this object. 75 | 76 | :returns: A FastTensorDataLoader. 
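        Note: unlike FastTensorDataLoader, this class is not an iterator.
        Callers must invoke shuffle_indices() once before the first call to
        sample() (it initializes the batch cursor and the index permutation)
        and then draw per-task batches via sample(task_id).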
77 | """ 78 | assert all(t.shape[0] == tensors[0].shape[0] for t in tensors) 79 | assert all(t.shape[1] == tensors[0].shape[1] for t in tensors) 80 | self.tensors = tensors 81 | self.device = device 82 | 83 | self.dataset_len = self.tensors[0].shape[1] 84 | self._n_tasks = self.tensors[0].shape[0] 85 | self.batch_size = batch_size 86 | self.shuffle = shuffle 87 | 88 | # Calculate # batches 89 | n_batches, remainder = divmod(self.dataset_len, self.batch_size) 90 | if remainder > 0: 91 | n_batches += 1 92 | self.n_batches = n_batches 93 | 94 | def shuffle_indices(self): 95 | self._shuffle_indices() 96 | self.i = 0 97 | 98 | def _shuffle_indices(self): 99 | if self.shuffle: 100 | self.indices = torch.randperm(self.dataset_len).to(self.device) 101 | else: 102 | self.indices = None 103 | 104 | def sample(self, task_id): 105 | assert task_id < self._n_tasks 106 | 107 | if self.i >= self.dataset_len: 108 | self._shuffle_indices() 109 | self.i = 0 110 | 111 | task_tensors = tuple(t[task_id] for t in self.tensors) 112 | 113 | if self.indices is not None: 114 | indices = self.indices[self.i:self.i + self.batch_size] 115 | batch = tuple(torch.index_select(t, 0, indices) for t in task_tensors) 116 | else: 117 | batch = tuple(t[self.i:self.i + self.batch_size] for t in task_tensors) 118 | self.i += self.batch_size 119 | return batch 120 | 121 | def __len__(self): 122 | return self.n_batches 123 | 124 | @property 125 | def n_tasks(self): 126 | return self._n_tasks 127 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import numpy as np 4 | from pathlib import Path 5 | from collections import defaultdict 6 | from numpy.random import default_rng 7 | 8 | import torch 9 | from torch.utils.data import TensorDataset 10 | 11 | 12 | def _get_reward_param(data_dir, domain_task, seed, test_fraction): 13 | """ 14 | Helper function to get the reward parameter from an .npy rollout file. 15 | """ 16 | datadir = Path(data_dir) 17 | paths_to_load = sorted(datadir.glob(f'**/{domain_task}*_seed_*.npy')) 18 | 19 | # get reward parameters in a super hacked up way 20 | all_reward_params = sorted(list(set([re.findall('\-(.*?)\_', str(p))[0] for p in paths_to_load]))) 21 | # remove margin=0.0 from reward params 22 | reward_params = [x for x in all_reward_params if '-0.0' not in x] 23 | 24 | # Random picking based on the seed 25 | rng = default_rng(seed) 26 | rng.shuffle(reward_params) 27 | test_size = int(test_fraction * len(reward_params)) 28 | train_reward_params = reward_params[test_size:] 29 | test_reward_params = reward_params[:test_size] 30 | 31 | return train_reward_params, test_reward_params 32 | 33 | 34 | def _get_dynamic_param(data_dir, domain_task, seed, test_fraction): 35 | """ 36 | Helper function to get the dynamics parameter from an .npy rollout file. 
37 | """ 38 | datadir = Path(data_dir) 39 | paths_to_load = sorted(datadir.glob(f'**/{domain_task}*_seed_*.npy')) 40 | 41 | # get dynamic parameters in a super hacked up way 42 | dynamic_params = sorted(list(set([re.findall('\_dyn_(.*?)\__', str(p))[0] for p in paths_to_load]))) 43 | 44 | # Random picking based on the seed 45 | rng = default_rng(seed) 46 | rng.shuffle(dynamic_params) 47 | test_size = int(test_fraction * len(dynamic_params)) 48 | train_dynamic_params = dynamic_params[test_size:] 49 | test_dynamic_params = dynamic_params[:test_size] 50 | 51 | return train_dynamic_params, test_dynamic_params 52 | 53 | 54 | def _get_reward_dynamic_param(data_dir, domain_task, seed, test_fraction): 55 | """ 56 | Helper function to get the reward and dynamics parameters from an .npy rollout file. 57 | """ 58 | datadir = Path(data_dir) 59 | paths_to_load = sorted(datadir.glob(f'**/{domain_task}*_seed_*.npy')) 60 | 61 | # get reward-dynamic parameters in a super hacked up way 62 | reward_dynamic_params = sorted(list(set([re.findall('\-(.*?)\__', str(p))[0] for p in paths_to_load]))) 63 | 64 | # Random picking based on the seed 65 | rng = default_rng(seed) 66 | rng.shuffle(reward_dynamic_params) 67 | test_size = int(test_fraction * len(reward_dynamic_params)) 68 | train_reward_dynamic_params = reward_dynamic_params[test_size:] 69 | test_reward_dynamic_params = reward_dynamic_params[:test_size] 70 | 71 | return train_reward_dynamic_params, test_reward_dynamic_params 72 | 73 | 74 | class RLSolutionDataset: 75 | """ 76 | Dataset of near-optimal trajectories on a family of MDPs. 77 | Used for training HyperZero and MLP baselines. 78 | """ 79 | def __init__(self, data_dir, domain_task, input_to_model, seed, device): 80 | assert input_to_model in ['rew', 'dyn', 'rew_dyn'] 81 | self.data_dir = data_dir 82 | self.domain_task = domain_task 83 | self.input_to_model = input_to_model 84 | self.device = device 85 | self.test_fraction = 0.15 86 | self.seed = seed 87 | 88 | # set the data keys 89 | if input_to_model == 'rew': 90 | self.data_keys = ['reward_param', 'state', 'action', 'next_state', 'reward', 'discount', 'value'] 91 | self.train_input_params, self.test_input_params = _get_reward_param(data_dir, domain_task, self.seed, self.test_fraction) 92 | elif input_to_model == 'dyn': 93 | self.data_keys = ['dynamics_param', 'state', 'action', 'next_state', 'reward', 'discount', 'value'] 94 | self.train_input_params, self.test_input_params = _get_dynamic_param(data_dir, domain_task, self.seed, self.test_fraction) 95 | elif input_to_model == 'rew_dyn': 96 | self.data_keys = ['reward_dynamics_param', 'state', 'action', 'next_state', 'reward', 'discount', 'value'] 97 | self.train_input_params, self.test_input_params = _get_reward_dynamic_param(data_dir, domain_task, self.seed, self.test_fraction) 98 | 99 | self.setup() 100 | 101 | def setup(self): 102 | train_tensors, test_tensors = [], [] 103 | 104 | # load the dataset 105 | self.train_data_np, self.test_data_np = self._load_dataset(flatten=True) 106 | 107 | # concatenate reward and dynamic parameters 108 | self.train_data_np['reward_dynamics_param'] = np.concatenate((self.train_data_np['reward_param'], 109 | self.train_data_np['dynamics_param']), 110 | axis=-1) 111 | self.test_data_np['reward_dynamics_param'] = np.concatenate((self.test_data_np['reward_param'], 112 | self.test_data_np['dynamics_param']), 113 | axis=-1) 114 | 115 | for k in self.data_keys: 116 | train_tensors.append( 117 | torch.tensor(self.train_data_np[k], dtype=torch.float, device=self.device) 
118 | ) 119 | test_tensors.append( 120 | torch.tensor(self.test_data_np[k], dtype=torch.float, device=self.device) 121 | ) 122 | 123 | self.train_dataset = TensorDataset(*train_tensors) 124 | self.test_dataset = TensorDataset(*test_tensors) 125 | 126 | def _load_dataset(self, flatten=False): 127 | train_data, test_data = self._generate_data(flatten) 128 | return train_data, test_data 129 | 130 | def _generate_data(self, flatten=False): 131 | datadir = Path(self.data_dir) 132 | train_data_np, test_data_np = defaultdict(list), defaultdict(list) 133 | data = { 134 | 'train': train_data_np, 135 | 'test': test_data_np 136 | } 137 | 138 | for stage, input_params in zip(['train', 'test'], 139 | [self.train_input_params, self.test_input_params]): 140 | for r in input_params: 141 | paths_to_load = sorted(datadir.glob(f'**/{self.domain_task}*_seed_*{r}_*.npy')) 142 | 143 | for p in paths_to_load: 144 | print(f"Loading data from {str(p)}") 145 | d = np.load(str(p), allow_pickle=True).item() 146 | for k, v in d.items(): 147 | if flatten: 148 | # save the data as (n_episodes * n_steps, ?) 149 | n_episodes, n_steps = v.shape[0], v.shape[1] 150 | data[stage][k].append(v.reshape(n_episodes * n_steps, -1)) 151 | else: 152 | # save the data as (n_episodes, n_steps, ?) 153 | data[stage][k].append(v) 154 | 155 | # concatenate the loaded data 156 | for k, v in data[stage].items(): 157 | data[stage][k] = np.concatenate(v, axis=0) 158 | 159 | with open(os.path.join(self.data_dir, f'{stage}-{self.input_to_model}-params-seed-{self.seed}.txt'), 'w') as f: 160 | f.write(str(input_params)) 161 | 162 | return data['train'], data['test'] 163 | 164 | @property 165 | def reward_param_dim(self): 166 | return self.train_data_np['reward_param'].shape[-1] 167 | 168 | @property 169 | def dynamic_param_dim(self): 170 | return self.train_data_np['dynamics_param'].shape[-1] 171 | 172 | @property 173 | def reward_dynamic_param_dim(self): 174 | return self.train_data_np['reward_dynamics_param'].shape[-1] 175 | 176 | @property 177 | def state_dim(self): 178 | return self.train_data_np['state'].shape[-1] 179 | 180 | @property 181 | def action_dim(self): 182 | return self.train_data_np['action'].shape[-1] 183 | 184 | 185 | class RLSolutionMetaDataset(RLSolutionDataset): 186 | """ 187 | Dataset of near-optimal trajectories on a family of MDPs. 188 | Used for training meta learning (MAML and PEARL) baselines. 189 | """ 190 | def __init__(self, data_dir, domain_task, input_to_model, seed, device): 191 | super().__init__(data_dir, domain_task, input_to_model, seed, device) 192 | 193 | def _load_dataset(self, flatten=False): 194 | meta_train_data, meta_test_data = self._generate_data(flatten) 195 | return meta_train_data, meta_test_data 196 | 197 | def _generate_data(self, flatten=False): 198 | datadir = Path(self.data_dir) 199 | train_data_np, test_data_np = defaultdict(list), defaultdict(list) 200 | data = { 201 | 'train': train_data_np, 202 | 'test': test_data_np 203 | } 204 | 205 | for stage, input_params in zip(['train', 'test'], 206 | [self.train_input_params, self.test_input_params]): 207 | for r in input_params: 208 | paths_to_load = sorted(datadir.glob(f'**/{self.domain_task}*_seed_*{r}_*.npy')) 209 | 210 | for p in paths_to_load: 211 | print(f"Loading data from {str(p)}") 212 | d = np.load(str(p), allow_pickle=True).item() 213 | for k, v in d.items(): 214 | if flatten: 215 | # save the data as (n_episodes * n_steps, ?) 
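                            # Each loaded file contributes one flattened array per key; these are
                            # stacked (np.stack below) along a new leading task axis, so the final
                            # tensors have shape (n_tasks, n_episodes * n_steps, D) -- the layout
                            # FastTensorMetaDataLoader expects.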
216 | n_episodes, n_steps = v.shape[0], v.shape[1] 217 | data[stage][k].append(v.reshape(n_episodes * n_steps, -1)) 218 | else: 219 | # save the data as (n_episodes, n_steps, ?) 220 | data[stage][k].append(v) 221 | 222 | # concatenate the loaded data 223 | for k, v in data[stage].items(): 224 | data[stage][k] = np.stack(v, axis=0) # note the difference from the standard dataset 225 | 226 | with open(os.path.join(self.data_dir, f'meta-{stage}-{self.input_to_model}-params-seed-{self.seed}.txt'), 'w') as f: 227 | f.write(str(input_params)) 228 | return data['train'], data['test'] 229 | 230 | @property 231 | def n_tasks(self): 232 | return self.train_data_np['state'].shape[0] 233 | -------------------------------------------------------------------------------- /utils/dmc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import gym 4 | from collections import deque, OrderedDict 5 | from typing import Any, NamedTuple 6 | 7 | import dm_env 8 | import numpy as np 9 | 10 | from contextual_control_suite import suite 11 | 12 | from dm_control import manipulation 13 | from dm_control.suite.wrappers import action_scale, pixels 14 | from dm_control.rl.control import FLAT_OBSERVATION_KEY 15 | from dm_env import StepType, specs, TimeStep 16 | 17 | 18 | class ExtendedTimeStep(NamedTuple): 19 | step_type: Any 20 | reward: Any 21 | discount: Any 22 | observation: Any 23 | action: Any 24 | 25 | def first(self): 26 | return self.step_type == StepType.FIRST 27 | 28 | def mid(self): 29 | return self.step_type == StepType.MID 30 | 31 | def last(self): 32 | return self.step_type == StepType.LAST 33 | 34 | def __getitem__(self, attr): 35 | return getattr(self, attr) 36 | 37 | 38 | class ActionRepeatWrapper(dm_env.Environment): 39 | def __init__(self, env, num_repeats): 40 | self._env = env 41 | self._num_repeats = num_repeats 42 | 43 | def step(self, action): 44 | reward = 0.0 45 | discount = 1.0 46 | for i in range(self._num_repeats): 47 | time_step = self._env.step(action) 48 | reward += (time_step.reward or 0.0) * discount 49 | discount *= time_step.discount 50 | if time_step.last(): 51 | break 52 | 53 | return time_step._replace(reward=reward, discount=discount) 54 | 55 | def observation_spec(self): 56 | return self._env.observation_spec() 57 | 58 | def action_spec(self): 59 | return self._env.action_spec() 60 | 61 | def reset(self): 62 | return self._env.reset() 63 | 64 | def __getattr__(self, name): 65 | return getattr(self._env, name) 66 | 67 | 68 | class FrameStackWrapper(dm_env.Environment): 69 | def __init__(self, env, num_frames, pixels_key='pixels'): 70 | self._env = env 71 | self._num_frames = num_frames 72 | self._frames = deque([], maxlen=num_frames) 73 | self._pixels_key = pixels_key 74 | 75 | wrapped_obs_spec = env.observation_spec() 76 | assert pixels_key in wrapped_obs_spec 77 | 78 | pixels_shape = wrapped_obs_spec[pixels_key].shape 79 | # remove batch dim 80 | if len(pixels_shape) == 4: 81 | pixels_shape = pixels_shape[1:] 82 | self._obs_spec = specs.BoundedArray(shape=np.concatenate( 83 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0), 84 | dtype=np.uint8, 85 | minimum=0, 86 | maximum=255, 87 | name='observation') 88 | 89 | def _transform_observation(self, time_step): 90 | assert len(self._frames) == self._num_frames 91 | obs = np.concatenate(list(self._frames), axis=0) 92 | return time_step._replace(observation=obs) 93 | 94 | def _extract_pixels(self, time_step): 95 | pixels = time_step.observation[self._pixels_key] 96 | 
# remove batch dim 97 | if len(pixels.shape) == 4: 98 | pixels = pixels[0] 99 | return pixels.transpose(2, 0, 1).copy() 100 | 101 | def reset(self): 102 | time_step = self._env.reset() 103 | pixels = self._extract_pixels(time_step) 104 | for _ in range(self._num_frames): 105 | self._frames.append(pixels) 106 | return self._transform_observation(time_step) 107 | 108 | def step(self, action): 109 | time_step = self._env.step(action) 110 | pixels = self._extract_pixels(time_step) 111 | self._frames.append(pixels) 112 | return self._transform_observation(time_step) 113 | 114 | def observation_spec(self): 115 | return self._obs_spec 116 | 117 | def action_spec(self): 118 | return self._env.action_spec() 119 | 120 | def __getattr__(self, name): 121 | return getattr(self._env, name) 122 | 123 | 124 | class ActionDTypeWrapper(dm_env.Environment): 125 | def __init__(self, env, dtype): 126 | self._env = env 127 | wrapped_action_spec = env.action_spec() 128 | self._action_spec = specs.BoundedArray(wrapped_action_spec.shape, 129 | dtype, 130 | wrapped_action_spec.minimum, 131 | wrapped_action_spec.maximum, 132 | 'action') 133 | 134 | def step(self, action): 135 | action = action.astype(self._env.action_spec().dtype) 136 | return self._env.step(action) 137 | 138 | def observation_spec(self): 139 | return self._env.observation_spec() 140 | 141 | def action_spec(self): 142 | return self._action_spec 143 | 144 | def reset(self): 145 | return self._env.reset() 146 | 147 | def __getattr__(self, name): 148 | return getattr(self._env, name) 149 | 150 | 151 | class ObservationSpecWrapper(dm_env.Environment): 152 | def __init__(self, env): 153 | self._env = env 154 | wrapped_observation_spec = env.observation_spec()[FLAT_OBSERVATION_KEY] 155 | self._observation_spec = specs.Array(wrapped_observation_spec.shape, 156 | wrapped_observation_spec.dtype, 157 | 'observation') 158 | 159 | def _transform_observation(self, time_step): 160 | obs = time_step.observation[FLAT_OBSERVATION_KEY] 161 | return time_step._replace(observation=obs) 162 | 163 | def step(self, action): 164 | time_step = self._env.step(action) 165 | return self._transform_observation(time_step) 166 | 167 | def observation_spec(self): 168 | return self._observation_spec 169 | 170 | def action_spec(self): 171 | return self._env.action_spec() 172 | 173 | def reset(self): 174 | time_step = self._env.reset() 175 | return self._transform_observation(time_step) 176 | 177 | def __getattr__(self, name): 178 | return getattr(self._env, name) 179 | 180 | 181 | class ExtendedTimeStepWrapper(dm_env.Environment): 182 | def __init__(self, env): 183 | self._env = env 184 | 185 | def reset(self): 186 | time_step = self._env.reset() 187 | return self._augment_time_step(time_step) 188 | 189 | def step(self, action): 190 | time_step = self._env.step(action) 191 | return self._augment_time_step(time_step, action) 192 | 193 | def _augment_time_step(self, time_step, action=None): 194 | if action is None: 195 | action_spec = self.action_spec() 196 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 197 | return ExtendedTimeStep(observation=time_step.observation, 198 | step_type=time_step.step_type, 199 | action=action, 200 | reward=time_step.reward or 0.0, 201 | discount=time_step.discount or 1.0) 202 | 203 | def observation_spec(self): 204 | return self._env.observation_spec() 205 | 206 | def action_spec(self): 207 | return self._env.action_spec() 208 | 209 | def __getattr__(self, name): 210 | return getattr(self._env, name) 211 | 212 | 213 | class 
GymWrapper(dm_env.Environment): 214 | """Only works with Gym envs with continuous actions and states, 215 | also works with the fork of gym-miniworld with continuous actions 216 | https://github.com/sahandrez/gym-miniworld/tree/continuous_actions""" 217 | def __init__(self, env): 218 | self._env = env 219 | assert isinstance(env, gym.core.Env) 220 | assert isinstance(env.action_space, gym.spaces.Box) 221 | assert isinstance(env.observation_space, gym.spaces.Box) 222 | 223 | if len(env.observation_space.shape) == 3: 224 | # Pixel observations 225 | self.pixel_obs = True 226 | pixel_spec = specs.Array(shape=env.observation_space.shape, 227 | dtype=np.uint8, 228 | name='pixels') 229 | self._obs_spec = OrderedDict(pixels=pixel_spec) 230 | else: 231 | # State observations 232 | self.pixel_obs = False 233 | state_spec = specs.Array(shape=env.observation_space.shape, 234 | dtype=np.float64, 235 | name='observations') 236 | self._obs_spec = OrderedDict(observations=state_spec) 237 | 238 | self._action_spec = specs.BoundedArray(shape=env.action_space.shape, 239 | dtype=np.float64, 240 | minimum=env.action_space.low, 241 | maximum=env.action_space.high) 242 | 243 | def step(self, action): 244 | first = False 245 | obs, reward, done, _ = self._env.step(self._convert_action(action)) 246 | return self._to_time_step(obs, reward, done, first) 247 | 248 | def reset(self): 249 | obs, reward, done, first = self._env.reset(), 0.0, False, True 250 | return self._to_time_step(obs, reward, done, first) 251 | 252 | def observation_spec(self): 253 | return self._obs_spec 254 | 255 | def action_spec(self): 256 | return self._action_spec 257 | 258 | def _convert_action(self, action): 259 | return np.array(action, dtype=np.float64) 260 | 261 | def _convert_obs(self, obs): 262 | if self.pixel_obs: 263 | return OrderedDict(pixels=np.array(obs, dtype=np.uint8)) 264 | return OrderedDict(observations=np.array(obs, dtype=np.float64)) 265 | 266 | def _to_time_step(self, obs, reward, done, first): 267 | if first: 268 | step_type = StepType(0) 269 | elif done: 270 | step_type = StepType(2) 271 | else: 272 | step_type = StepType(1) 273 | 274 | discount = 1.0 - float(done) 275 | 276 | time_step = TimeStep(observation=self._convert_obs(obs), 277 | reward=reward, 278 | step_type=step_type, 279 | discount=discount) 280 | return time_step 281 | 282 | def __getattr__(self, item): 283 | return getattr(self._env, item) 284 | 285 | 286 | def make_dmc(name, frame_stack, action_repeat, reward_kwargs, dynamics_kwargs, seed, pixel_obs): 287 | environment_kwargs = dict() 288 | if not pixel_obs: 289 | environment_kwargs['flat_observation'] = True 290 | 291 | task_kwargs = {'random': seed, 292 | 'reward_kwargs': reward_kwargs} 293 | if not dynamics_kwargs['use_default']: 294 | task_kwargs['dynamics_kwargs'] = dynamics_kwargs 295 | 296 | domain, task = name.split('_', 1) 297 | # overwrite cup to ball_in_cup 298 | domain = dict(cup='ball_in_cup').get(domain, domain) 299 | # make sure reward is not visualized 300 | if (domain, task) in suite.ALL_TASKS: 301 | env = suite.load(domain, 302 | task, 303 | task_kwargs=task_kwargs, 304 | environment_kwargs=environment_kwargs, 305 | visualize_reward=False) 306 | pixels_key = 'pixels' 307 | else: 308 | name = f'{domain}_{task}_vision' 309 | env = manipulation.load(name, seed=seed) 310 | pixels_key = 'front_close' 311 | 312 | env = ActionDTypeWrapper(env, np.float32) 313 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0) 314 | 315 | # add wrappers for pixel or state obs 316 | if pixel_obs: 317 
| env = ActionRepeatWrapper(env, action_repeat) 318 | # add renderings for classical tasks 319 | if (domain, task) in suite.ALL_TASKS: 320 | # zoom in camera for quadruped 321 | camera_id = dict(quadruped=2).get(domain, 0) 322 | render_kwargs = dict(height=84, width=84, camera_id=camera_id) 323 | env = pixels.Wrapper(env, 324 | pixels_only=True, 325 | render_kwargs=render_kwargs) 326 | # stack several frames 327 | env = FrameStackWrapper(env, frame_stack, pixels_key) 328 | env = ExtendedTimeStepWrapper(env) 329 | else: 330 | env = ObservationSpecWrapper(env) 331 | env = ExtendedTimeStepWrapper(env) 332 | 333 | return env 334 | 335 | 336 | def make(name, frame_stack, action_repeat, reward_kwargs, dynamics_kwargs, seed, pixel_obs): 337 | return make_dmc(name, frame_stack, action_repeat, reward_kwargs, dynamics_kwargs, seed, pixel_obs) 338 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from termcolor import colored 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | TB_LOG_FREQ = 10 12 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 13 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 14 | ('episode_reward', 'R', 'float'), 15 | ('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'), 16 | ('total_time', 'T', 'time')] 17 | 18 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 19 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 20 | ('episode_reward', 'R', 'float'), 21 | ('total_time', 'T', 'time')] 22 | 23 | 24 | class AverageMeter(object): 25 | def __init__(self): 26 | self._sum = 0 27 | self._count = 0 28 | 29 | def update(self, value, n=1): 30 | self._sum += value 31 | self._count += n 32 | 33 | def value(self): 34 | return self._sum / max(1, self._count) 35 | 36 | 37 | class MetersGroup(object): 38 | def __init__(self, csv_file_name, formating): 39 | self._csv_file_name = csv_file_name 40 | self._formating = formating 41 | self._meters = defaultdict(AverageMeter) 42 | self._csv_file = None 43 | self._csv_writer = None 44 | 45 | def log(self, key, value, n=1): 46 | self._meters[key].update(value, n) 47 | 48 | def _prime_meters(self): 49 | data = dict() 50 | for key, meter in self._meters.items(): 51 | if key.startswith('train'): 52 | key = key[len('train') + 1:] 53 | else: 54 | key = key[len('eval') + 1:] 55 | key = key.replace('/', '_') 56 | data[key] = meter.value() 57 | return data 58 | 59 | def _remove_old_entries(self, data): 60 | rows = [] 61 | with self._csv_file_name.open('r') as f: 62 | reader = csv.DictReader(f) 63 | for row in reader: 64 | if float(row['episode']) >= data['episode']: 65 | break 66 | rows.append(row) 67 | with self._csv_file_name.open('w') as f: 68 | writer = csv.DictWriter(f, 69 | fieldnames=sorted(data.keys()), 70 | restval=0.0) 71 | writer.writeheader() 72 | for row in rows: 73 | writer.writerow(row) 74 | 75 | def _dump_to_csv(self, data): 76 | if self._csv_writer is None: 77 | should_write_header = True 78 | if self._csv_file_name.exists(): 79 | self._remove_old_entries(data) 80 | should_write_header = False 81 | 82 | self._csv_file = self._csv_file_name.open('a') 83 | self._csv_writer = csv.DictWriter(self._csv_file, 84 | fieldnames=sorted(data.keys()), 85 | restval=0.0) 86 | if should_write_header: 87 | 
self._csv_writer.writeheader() 88 | 89 | self._csv_writer.writerow(data) 90 | self._csv_file.flush() 91 | 92 | def _format(self, key, value, ty): 93 | if ty == 'int': 94 | value = int(value) 95 | return f'{key}: {value}' 96 | elif ty == 'float': 97 | return f'{key}: {value:.04f}' 98 | elif ty == 'time': 99 | value = str(datetime.timedelta(seconds=int(value))) 100 | return f'{key}: {value}' 101 | else: 102 | raise f'invalid format type: {ty}' 103 | 104 | def _dump_to_console(self, data, prefix): 105 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 106 | pieces = [f'| {prefix: <14}'] 107 | for key, disp_key, ty in self._formating: 108 | value = data.get(key, 0) 109 | pieces.append(self._format(disp_key, value, ty)) 110 | print(' | '.join(pieces)) 111 | 112 | def dump(self, step, prefix): 113 | if len(self._meters) == 0: 114 | return 115 | data = self._prime_meters() 116 | data['frame'] = step 117 | self._dump_to_csv(data) 118 | self._dump_to_console(data, prefix) 119 | self._meters.clear() 120 | 121 | 122 | class Logger(object): 123 | def __init__(self, log_dir): 124 | self._log_dir = log_dir 125 | self._train_mg = MetersGroup(log_dir / 'train.csv', 126 | formating=COMMON_TRAIN_FORMAT) 127 | self._eval_mg = MetersGroup(log_dir / 'eval.csv', 128 | formating=COMMON_EVAL_FORMAT) 129 | self._sw = SummaryWriter(str(log_dir / 'tb')) 130 | 131 | def _try_sw_log(self, key, value, step): 132 | if self._sw is not None and step % TB_LOG_FREQ == 0: 133 | self._sw.add_scalar(key, value, step) 134 | 135 | def log(self, key, value, step): 136 | assert key.startswith('train') or key.startswith('eval') 137 | if type(value) == torch.Tensor: 138 | value = value.item() 139 | self._try_sw_log(key, value, step) 140 | mg = self._train_mg if key.startswith('train') else self._eval_mg 141 | mg.log(key, value) 142 | 143 | def log_metrics(self, metrics, step, ty): 144 | for key, value in metrics.items(): 145 | self.log(f'{ty}/{key}', value, step) 146 | 147 | def dump(self, step, ty=None): 148 | if ty is None or ty == 'eval': 149 | self._eval_mg.dump(step, 'eval') 150 | if ty is None or ty == 'train': 151 | self._train_mg.dump(step, 'train') 152 | 153 | def log_and_dump_ctx(self, step, ty): 154 | return LogAndDumpCtx(self, step, ty) 155 | 156 | 157 | class LogAndDumpCtx: 158 | def __init__(self, logger, step, ty): 159 | self._logger = logger 160 | self._step = step 161 | self._ty = ty 162 | 163 | def __enter__(self): 164 | return self 165 | 166 | def __call__(self, key, value): 167 | self._logger.log(f'{self._ty}/{key}', value, self._step) 168 | 169 | def __exit__(self, *args): 170 | self._logger.dump(self._step, self._ty) 171 | -------------------------------------------------------------------------------- /utils/plots.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy 5 | import numpy as np 6 | from sklearn.decomposition import PCA 7 | import matplotlib.pyplot as plt 8 | from matplotlib import cm 9 | from scipy.interpolate import griddata 10 | import seaborn as sns 11 | 12 | plt.style.use('seaborn-white') 13 | color_pallete = sns.color_palette('tab20b') 14 | 15 | 16 | def smooth(scalars, weight): 17 | last = scalars[0] 18 | smoothed = list() 19 | for point in scalars: 20 | smoothed_val = last * weight + (1 - weight) * point 21 | smoothed.append(smoothed_val) 22 | last = smoothed_val 23 | 24 | return smoothed 25 | 26 | 27 | def save_fig(name): 28 | plt.savefig(f"{name}.pdf", format='pdf', 
bbox_inches='tight') 29 | plt.savefig(f"{name}.png", format='png', dpi=300, bbox_inches='tight') 30 | plt.close() 31 | 32 | 33 | def create_3d_plot(): 34 | fig = plt.figure() 35 | ax = fig.add_subplot(111, projection='3d') 36 | return fig, ax 37 | 38 | 39 | def plot_vertices(vertices, plot_type='scatter', ax=None, fig=None, 40 | colors=None, bar=False, bar_label="Values", pad=True): 41 | """ 42 | Plot states and their values. 43 | """ 44 | if colors is None: 45 | colors = ["black" for i in range(vertices.shape[0])] 46 | bar = False 47 | vmin= -1 48 | vmax= 1 49 | else: 50 | vmin = min(colors) 51 | vmax = max(colors) 52 | 53 | if plot_type == 'scatter': 54 | im = ax.scatter(vertices[:, 0], vertices[:, 1], vertices[:, 2], c=colors, 55 | s=20, zorder=10, edgecolors="black", linewidth=0.5, 56 | vmin=vmin, vmax=vmax, alpha=0.3, cmap="cool") 57 | if bar: 58 | if pad: 59 | cbar = fig.colorbar(im, shrink=0.5, pad=0.1) 60 | else: 61 | cbar = fig.colorbar(im, shrink=0.5, pad=0) 62 | im.set_clim(vmin=vmin, vmax=vmax) 63 | cbar.ax.get_yaxis().labelpad = 10 64 | cbar.ax.set_ylabel(bar_label, rotation=270) 65 | 66 | elif plot_type == 'plot': 67 | im = ax.plot(vertices[:, 0], vertices[:, 1], vertices[:, 2], 68 | color=colors, zorder=10, linewidth=1, alpha=0.5) 69 | # scatter plot start and end points 70 | im = ax.scatter(vertices[0, 0], vertices[0, 1], vertices[0, 2], 71 | color=colors, s=10, zorder=10, linewidth=0.1, alpha=0.5) 72 | im = ax.scatter(vertices[-1, 0], vertices[-1, 1], vertices[-1, 2], 73 | color=colors, s=10, zorder=10, linewidth=0.1, alpha=0.5) 74 | 75 | return ax, fig 76 | 77 | 78 | def visualize_phase_space(qpos, qvel, values, save_dir, fname, plot_type, 79 | goal_coord=None, label='V'): 80 | # qpos = np.abs(qpos) 81 | qpos_dim = qpos.shape[-1] 82 | qvel_dim = qvel.shape[-1] 83 | 84 | is_pendulum = qpos_dim == 1 85 | 86 | if is_pendulum: 87 | vertices = np.concatenate([qpos, qvel, values], axis=-1) 88 | else: 89 | vertices = np.concatenate([qpos, qvel], axis=-1) 90 | # reshape vertices 91 | vertices = vertices.reshape(-1, qpos_dim + qvel_dim) 92 | values = values.reshape(-1, 1) 93 | 94 | # dimensionality reduction if necessary 95 | if vertices.shape[-1] > 3: 96 | pca = PCA(n_components=3) 97 | vertices = pca.fit_transform(vertices) 98 | 99 | vis_colors = random.choice(color_pallete) if plot_type == 'plot' else values 100 | 101 | fig, ax = create_3d_plot() 102 | ax, fig = plot_vertices(vertices, plot_type=plot_type, fig=fig, ax=ax, 103 | colors=vis_colors, bar=False, bar_label='', pad=is_pendulum) 104 | 105 | if is_pendulum: 106 | ax.set_xlabel(r"$\theta$") 107 | ax.set_ylabel(r"$\omega$") 108 | ax.set_zlabel(fr"${label}(\theta, \omega)$") 109 | else: 110 | ax.set_xticklabels([]) 111 | ax.set_yticklabels([]) 112 | ax.set_zticklabels([]) 113 | 114 | if goal_coord is not None: 115 | assert goal_coord.ndim == 2 116 | if goal_coord.shape[-1] > 3: 117 | goal_coord = pca.transform(goal_coord) 118 | 119 | im = ax.scatter(goal_coord[:, 0], goal_coord[:, 1], goal_coord[:, 2], color='green', 120 | marker='D', s=100, zorder=1, edgecolors="black", linewidth=0.5, alpha=0.6) 121 | 122 | plt.title(f"{label} on the Phase Space") 123 | plt.savefig(f"{save_dir}/{fname}.png", format='png', dpi=300, bbox_inches='tight') 124 | plt.close('all') 125 | -------------------------------------------------------------------------------- /utils/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import random 4 | import traceback 5 | 
from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import IterableDataset 10 | 11 | 12 | def episode_len(episode): 13 | # subtract -1 because the dummy first transition 14 | return next(iter(episode.values())).shape[0] - 1 15 | 16 | 17 | def save_episode(episode, fn): 18 | with io.BytesIO() as bs: 19 | np.savez_compressed(bs, **episode) 20 | bs.seek(0) 21 | with fn.open('wb') as f: 22 | f.write(bs.read()) 23 | 24 | 25 | def load_episode(fn): 26 | with fn.open('rb') as f: 27 | episode = np.load(f) 28 | episode = {k: episode[k] for k in episode.keys()} 29 | return episode 30 | 31 | 32 | class ReplayBufferStorage: 33 | def __init__(self, data_specs, replay_dir): 34 | self._data_specs = data_specs 35 | self._replay_dir = replay_dir 36 | replay_dir.mkdir(exist_ok=True) 37 | self._current_episode = defaultdict(list) 38 | self._preload() 39 | 40 | def __len__(self): 41 | return self._num_transitions 42 | 43 | def add(self, time_step): 44 | for spec in self._data_specs: 45 | value = time_step[spec.name] 46 | if np.isscalar(value): 47 | value = np.full(spec.shape, value, spec.dtype) 48 | assert spec.shape == value.shape and spec.dtype == value.dtype 49 | self._current_episode[spec.name].append(value) 50 | if time_step.last(): 51 | episode = dict() 52 | for spec in self._data_specs: 53 | value = self._current_episode[spec.name] 54 | episode[spec.name] = np.array(value, spec.dtype) 55 | self._current_episode = defaultdict(list) 56 | self._store_episode(episode) 57 | 58 | def _preload(self): 59 | self._num_episodes = 0 60 | self._num_transitions = 0 61 | for fn in self._replay_dir.glob('*.npz'): 62 | _, _, eps_len = fn.stem.split('_') 63 | self._num_episodes += 1 64 | self._num_transitions += int(eps_len) 65 | 66 | def _store_episode(self, episode): 67 | eps_idx = self._num_episodes 68 | eps_len = episode_len(episode) 69 | self._num_episodes += 1 70 | self._num_transitions += eps_len 71 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 72 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz' 73 | save_episode(episode, self._replay_dir / eps_fn) 74 | 75 | 76 | class ReplayBuffer(IterableDataset): 77 | def __init__(self, replay_dir, max_size, num_workers, nstep, discount, 78 | fetch_every, save_snapshot): 79 | self._replay_dir = replay_dir 80 | self._size = 0 81 | self._max_size = max_size 82 | self._num_workers = max(1, num_workers) 83 | self._episode_fns = [] 84 | self._episodes = dict() 85 | self._nstep = nstep 86 | self._discount = discount 87 | self._fetch_every = fetch_every 88 | self._samples_since_last_fetch = fetch_every 89 | self._save_snapshot = save_snapshot 90 | 91 | def _sample_episode(self): 92 | eps_fn = random.choice(self._episode_fns) 93 | return self._episodes[eps_fn] 94 | 95 | def _store_episode(self, eps_fn): 96 | try: 97 | episode = load_episode(eps_fn) 98 | except: 99 | return False 100 | eps_len = episode_len(episode) 101 | while eps_len + self._size > self._max_size: 102 | early_eps_fn = self._episode_fns.pop(0) 103 | early_eps = self._episodes.pop(early_eps_fn) 104 | self._size -= episode_len(early_eps) 105 | early_eps_fn.unlink(missing_ok=True) 106 | self._episode_fns.append(eps_fn) 107 | self._episode_fns.sort() 108 | self._episodes[eps_fn] = episode 109 | self._size += eps_len 110 | 111 | if not self._save_snapshot: 112 | eps_fn.unlink(missing_ok=True) 113 | return True 114 | 115 | def _try_fetch(self): 116 | if self._samples_since_last_fetch < self._fetch_every: 117 | return 118 | self._samples_since_last_fetch = 0 119 | 
try: 120 | worker_id = torch.utils.data.get_worker_info().id 121 | except: 122 | worker_id = 0 123 | eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True) 124 | fetched_size = 0 125 | for eps_fn in eps_fns: 126 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]] 127 | if eps_idx % self._num_workers != worker_id: 128 | continue 129 | if eps_fn in self._episodes.keys(): 130 | break 131 | if fetched_size + eps_len > self._max_size: 132 | break 133 | fetched_size += eps_len 134 | if not self._store_episode(eps_fn): 135 | break 136 | 137 | def _sample(self): 138 | try: 139 | self._try_fetch() 140 | except: 141 | traceback.print_exc() 142 | self._samples_since_last_fetch += 1 143 | episode = self._sample_episode() 144 | # add +1 for the first dummy transition 145 | idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1 146 | # negative samples for the contrastive loss 147 | neg_idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1 148 | 149 | obs = episode['observation'][idx - 1] 150 | action = episode['action'][idx] 151 | next_obs = episode['observation'][idx + self._nstep - 1] 152 | neg_obs = episode['observation'][neg_idx] 153 | reward = np.zeros_like(episode['reward'][idx]) 154 | discount = np.ones_like(episode['discount'][idx]) 155 | for i in range(self._nstep): 156 | step_reward = episode['reward'][idx + i] 157 | reward += discount * step_reward 158 | discount *= episode['discount'][idx + i] * self._discount 159 | return (obs, action, reward, discount, next_obs, neg_obs) 160 | 161 | def __iter__(self): 162 | while True: 163 | yield self._sample() 164 | 165 | 166 | def _worker_init_fn(worker_id): 167 | seed = np.random.get_state()[1][0] + worker_id 168 | np.random.seed(seed) 169 | random.seed(seed) 170 | 171 | 172 | def make_replay_loader(replay_dir, max_size, batch_size, num_workers, 173 | save_snapshot, nstep, discount): 174 | max_size_per_worker = max_size // max(1, num_workers) 175 | 176 | iterable = ReplayBuffer(replay_dir, 177 | max_size_per_worker, 178 | num_workers, 179 | nstep, 180 | discount, 181 | fetch_every=1000, 182 | save_snapshot=save_snapshot) 183 | 184 | loader = torch.utils.data.DataLoader(iterable, 185 | batch_size=batch_size, 186 | num_workers=num_workers, 187 | pin_memory=True, 188 | worker_init_fn=_worker_init_fn) 189 | return loader 190 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import csv 4 | import time 5 | import os 6 | import git 7 | import json 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from omegaconf import OmegaConf 14 | from torch import distributions as pyd 15 | from torch.distributions.utils import _standard_normal 16 | 17 | 18 | _STATE_AGENTS = ['td3', 'random', 'lapleig'] 19 | _PIXEL_AGENTS = ['drqv2', 'random'] 20 | 21 | 22 | class eval_mode: 23 | def __init__(self, *models): 24 | self.models = models 25 | 26 | def __enter__(self): 27 | self.prev_states = [] 28 | for model in self.models: 29 | self.prev_states.append(model.training) 30 | model.train(False) 31 | 32 | def __exit__(self, *args): 33 | for model, state in zip(self.models, self.prev_states): 34 | model.train(state) 35 | return False 36 | 37 | 38 | def assert_agent(agent_name, pixel_obs): 39 | agent_name = agent_name.partition('_')[0] 40 | if pixel_obs: 41 | assert agent_name in _PIXEL_AGENTS, 
f"{agent_name} does not support pixel observations" 42 | else: 43 | assert agent_name in _STATE_AGENTS, f"{agent_name} does not support state observations" 44 | 45 | 46 | def set_seed_everywhere(seed): 47 | torch.manual_seed(seed) 48 | if torch.cuda.is_available(): 49 | torch.cuda.manual_seed(seed) 50 | np.random.seed(seed) 51 | random.seed(seed) 52 | 53 | 54 | def soft_update_params(net, target_net, tau): 55 | for param, target_param in zip(net.parameters(), target_net.parameters()): 56 | target_param.data.copy_(tau * param.data + 57 | (1 - tau) * target_param.data) 58 | 59 | 60 | def to_torch(xs, device): 61 | return tuple(torch.as_tensor(x, device=device) for x in xs) 62 | 63 | 64 | def to_device(xs, device): 65 | return tuple(x.to(device) for x in xs) 66 | 67 | 68 | def select_indices(xs, indices): 69 | return tuple(x[indices] for x in xs) 70 | 71 | 72 | def preprocess_obs(obs, bits=5): 73 | """Preprocessing image, see https://arxiv.org/abs/1807.03039.""" 74 | bins = 2**bits 75 | assert obs.dtype == torch.float32 76 | if bits < 8: 77 | obs = torch.floor(obs / 2**(8 - bits)) 78 | obs = obs / bins 79 | obs = obs + torch.rand_like(obs) / bins 80 | obs = obs - 0.5 81 | return obs 82 | 83 | 84 | def weight_init(m): 85 | if isinstance(m, nn.Linear): 86 | nn.init.orthogonal_(m.weight.data) 87 | if hasattr(m.bias, 'data'): 88 | m.bias.data.fill_(0.0) 89 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 90 | gain = nn.init.calculate_gain('relu') 91 | nn.init.orthogonal_(m.weight.data, gain) 92 | if hasattr(m.bias, 'data'): 93 | m.bias.data.fill_(0.0) 94 | 95 | 96 | def save_cfg(cfg, dir): 97 | with open(os.path.join(dir, 'cfg.yaml'), 'w') as f: 98 | OmegaConf.save(config=cfg, f=f.name) 99 | 100 | 101 | def save_args(args, dir): 102 | with open(os.path.join(dir, 'args.json'), 'w') as f: 103 | json.dump(args.__dict__, f, indent=4) 104 | 105 | 106 | def save_git_sha(dir): 107 | repo = git.Repo(search_parent_directories=True) 108 | sha = repo.head.object.hexsha 109 | with open(os.path.join(dir, 'git_sha.txt'), 'w') as f: 110 | f.write(sha) 111 | 112 | 113 | def get_last_model(model_dir): 114 | if not isinstance(model_dir, Path): 115 | model_dir = Path(model_dir) 116 | # return the step of the last saved model 117 | saved_models = [f for f in sorted(model_dir.glob(f'**/')) if not 'best' in str(f)] 118 | last_saved = saved_models[-1] 119 | last_step = str(last_saved.stem).partition('_')[-1] 120 | return int(last_step) 121 | 122 | 123 | def dump_dict(fname, logs): 124 | with open(fname, "a") as f: 125 | writer = csv.DictWriter(f, logs.keys()) 126 | if not os.path.getsize(fname): 127 | writer.writeheader() 128 | writer.writerow(logs) 129 | 130 | 131 | class Until: 132 | def __init__(self, until, action_repeat=1): 133 | self._until = until 134 | self._action_repeat = action_repeat 135 | 136 | def __call__(self, step): 137 | if self._until is None: 138 | return True 139 | until = self._until // self._action_repeat 140 | return step < until 141 | 142 | 143 | class Every: 144 | def __init__(self, every, action_repeat=1): 145 | self._every = every 146 | self._action_repeat = action_repeat 147 | 148 | def __call__(self, step): 149 | if self._every is None: 150 | return False 151 | every = self._every // self._action_repeat 152 | if step % every == 0: 153 | return True 154 | return False 155 | 156 | 157 | class Timer: 158 | def __init__(self): 159 | self._start_time = time.time() 160 | self._last_time = time.time() 161 | 162 | def reset(self): 163 | elapsed_time = time.time() - 
self._last_time 164 | self._last_time = time.time() 165 | total_time = time.time() - self._start_time 166 | return elapsed_time, total_time 167 | 168 | def total_time(self): 169 | return time.time() - self._start_time 170 | 171 | 172 | class TruncatedNormal(pyd.Normal): 173 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 174 | super().__init__(loc, scale, validate_args=False) 175 | self.low = low 176 | self.high = high 177 | self.eps = eps 178 | 179 | def _clamp(self, x): 180 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 181 | x = x - x.detach() + clamped_x.detach() 182 | return x 183 | 184 | def sample(self, clip=None, sample_shape=torch.Size()): 185 | shape = self._extended_shape(sample_shape) 186 | eps = _standard_normal(shape, 187 | dtype=self.loc.dtype, 188 | device=self.loc.device) 189 | eps *= self.scale 190 | if clip is not None: 191 | eps = torch.clamp(eps, -clip, clip) 192 | x = self.loc + eps 193 | return self._clamp(x) 194 | 195 | 196 | def schedule(schdl, step): 197 | try: 198 | return float(schdl) 199 | except ValueError: 200 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl) 201 | if match: 202 | init, final, duration = [float(g) for g in match.groups()] 203 | mix = np.clip(step / duration, 0.0, 1.0) 204 | return (1.0 - mix) * init + mix * final 205 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl) 206 | if match: 207 | init, final1, duration1, final2, duration2 = [ 208 | float(g) for g in match.groups() 209 | ] 210 | if step <= duration1: 211 | mix = np.clip(step / duration1, 0.0, 1.0) 212 | return (1.0 - mix) * init + mix * final1 213 | else: 214 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 215 | return (1.0 - mix) * final1 + mix * final2 216 | raise NotImplementedError(schdl) 217 | -------------------------------------------------------------------------------- /utils/video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | from PIL import Image, ImageDraw, ImageFont 5 | 6 | 7 | class VideoRecorder: 8 | def __init__(self, root_dir, render_size=256, fps=20): 9 | if root_dir is not None: 10 | self.save_dir = root_dir / 'eval_video' 11 | self.save_dir.mkdir(exist_ok=True) 12 | else: 13 | self.save_dir = None 14 | 15 | self.render_size = render_size 16 | self.fps = fps 17 | self.frames = [] 18 | 19 | def init(self, env, enabled=True): 20 | self.frames = [] 21 | self.enabled = self.save_dir is not None and enabled 22 | self.record(env) 23 | 24 | def record(self, env, text=None): 25 | if self.enabled: 26 | if hasattr(env, 'physics'): 27 | # DM control 28 | frame = env.physics.render(height=self.render_size, 29 | width=self.render_size, 30 | camera_id=0) 31 | elif hasattr(env, 'env'): 32 | # Meta World 33 | frame = env.env.render(offscreen=True, resolution=(self.render_size, self.render_size)) 34 | else: 35 | # OpenAI Gym 36 | frame = env.render() 37 | if text is not None: 38 | img = Image.fromarray(frame) 39 | draw = ImageDraw.Draw(img) 40 | font = ImageFont.truetype("SFCompact.ttf", 24) 41 | draw.text((50, 50), text, (255, 255, 255), font) 42 | frame = np.array(img) 43 | self.frames.append(frame) 44 | 45 | def save(self, file_name): 46 | if self.enabled: 47 | path = self.save_dir / file_name 48 | imageio.mimsave(str(path), self.frames, fps=self.fps) 49 | 50 | 51 | class TrainVideoRecorder: 52 | def __init__(self, root_dir, render_size=256, fps=20): 53 | if root_dir is not None: 54 | self.save_dir = 
root_dir / 'train_video' 55 | self.save_dir.mkdir(exist_ok=True) 56 | else: 57 | self.save_dir = None 58 | 59 | self.render_size = render_size 60 | self.fps = fps 61 | self.frames = [] 62 | 63 | def init(self, obs, enabled=True): 64 | self.frames = [] 65 | self.enabled = self.save_dir is not None and enabled 66 | self.record(obs) 67 | 68 | def record(self, obs): 69 | if self.enabled: 70 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0), 71 | dsize=(self.render_size, self.render_size), 72 | interpolation=cv2.INTER_CUBIC) 73 | self.frames.append(frame) 74 | 75 | def save(self, file_name): 76 | if self.enabled: 77 | path = self.save_dir / file_name 78 | imageio.mimsave(str(path), self.frames, fps=self.fps) 79 | --------------------------------------------------------------------------------
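As a closing usage illustration for utils/video.py, the sketch below drives TrainVideoRecorder with synthetic channel-stacked observations, mirroring how train.py feeds it time_step.observation; the random (9, 84, 84) arrays are placeholders standing in for stacked pixel frames and are not part of the repository.

import numpy as np
from pathlib import Path

from utils.video import TrainVideoRecorder

# Synthetic (C, H, W) uint8 "observations" stand in for stacked pixel frames;
# TrainVideoRecorder only uses the last three channels of each observation.
recorder = TrainVideoRecorder(root_dir=Path.cwd())   # frames are written under ./train_video
obs = np.random.randint(0, 255, size=(9, 84, 84), dtype=np.uint8)
recorder.init(obs)                                   # clears any old frames and records the first one
for _ in range(10):
    obs = np.random.randint(0, 255, size=(9, 84, 84), dtype=np.uint8)
    recorder.record(obs)
recorder.save('demo.mp4')                            # saved as ./train_video/demo.mp4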