├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── arguments.py ├── common.py ├── envs.py ├── eval.py ├── eval_modules.py ├── generate_dynamics_data.py ├── main.py ├── model.py ├── my_optim.py ├── test.py ├── train_dynamics_module.py ├── train_online.py └── train_reward_module.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.facebook.com/codeofconduct) so that you can understand what actions will and will not be tolerated. 3 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ddr 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | 1. Fork the repo and create your branch from `master`. 10 | 2. If you've added code that should be tested, add tests. 11 | 3. If you've changed APIs, update the documentation. 12 | 4. Ensure the test suite passes. 13 | 5. Make sure your code lints. 14 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Facebook's open source projects. 19 | 20 | Complete your CLA here: 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 27 | disclosure of security bugs. In those cases, please go through the process 28 | outlined on that page and do not file a public issue. 29 | 30 | ## Coding Style 31 | * 2 spaces for indentation rather than tabs 32 | * 80 character line length 33 | * ... 34 | 35 | ## License 36 | By contributing to ddr, you agree that your contributions will be licensed 37 | under the LICENSE file in the root directory of this source tree. 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ddr_for_tl 2 | Decoupling Dynamics and Reward for Transfer Learning 3 | 4 | Paper: https://arxiv.org/abs/1804.10689 5 | 6 | Generate data for Dynamics Module: (run twice in different locations for train and test sets) 7 | ``` 8 | python generate_dynamics_data.py --env-name HalfCheetahEnv --framework rllab --random-start --N 100000 --reset --out 9 | python generate_dynamics_data.py --env-name HalfCheetahEnv --framework rllab --random-start --N 10000 --reset --out 10 | ``` 11 | 12 | Continuous space + MuJoCo/gym: 13 | 14 | 15 | Example command to train the dynamics module: 16 | ``` 17 | python main.py --train-dynamics --train-set --test-set --train-batch 2500 --test-batch 250 --log-interval 10 --dim 200 --batch-size 512 --num-epochs 100 --env-name HalfCheetahEnv --framework rllab 18 | ``` 19 | 20 | 21 | Example command to train the rewards module: 22 | ``` 23 | python main.py --train-reward --env-name HalfCheetahEnv --framework rllab --dynamics-module --dim 200 --num-episodes 10000000 24 | ``` 25 | 26 | Transfer 27 | Dynamics: Include flag "--from-file {xml_file}" 28 | Reward: Include flag "--neg-reward" 29 | 30 | 31 | ## Tensorboard 32 | 33 | Make sure you have [tensorboardX](https://github.com/lanpa/tensorboard-pytorch) installed in your current conda installation. You can install it by executing following command: 34 | ``` 35 | pip install tensorboardX 36 | ``` 37 | 38 | Specify the directory where you want to log the tensorboard summaries (logs) with the ```--log-dir``` flag, eg: 39 | 40 | ``` 41 | python main.py --train-reward --dynamics-module /private/home/hsatija/ddr_for_tl/data/SwimmerMazeEnvmazeid0length1/_entropy_coef0.0_dec_loss_coef0.1_forward_loss_coef10_rollout3_train_size10000/dynamics_module_epoch10.pt --dim 10 --framework rllab --env-name SwimmerMazeEnv --out ./data/ --log-dir ./runs/ 42 | ``` 43 | 44 | The tensorboard logs (summaries/events) will be published in /tb_logs/ directory. Launch the tensorboard server on the devfair machine, 45 | ``` 46 | tensorboard --logdir /tb_logs/ --port 6006 47 | ``` 48 | 49 | 50 | You can then set up port forwarding to access the tensorboard on the local machine. 51 | 52 | ## License 53 | Attribution-NonCommercial 4.0 International as found in the LICENSE file. 54 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import argparse 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser(description='Train Modules') 12 | # Learning parameters 13 | parser.add_argument('--lr', type=float, default=0.0001, 14 | help='learning rate (default: 0.0001)') 15 | parser.add_argument('--gamma', type=float, default=0.99, 16 | help='discount factor for rewards (default: 0.99)') 17 | parser.add_argument('--tau', type=float, default=0.95, 18 | help='parameter for GAE (default: 0.95)') 19 | parser.add_argument('--eps', type=float, default=1e-5, 20 | help='RMSprop optimizer epsilon (default: 1e-5)') 21 | parser.add_argument('--alpha', type=float, default=0.99, 22 | help='RMSprop optimizer apha (default: 0.99)') 23 | parser.add_argument('--max-grad-norm', type=float, default=50, 24 | help='value loss coefficient (default: 50)') 25 | parser.add_argument('--no-shared', default=False, 26 | help='use an optimizer without shared momentum.') 27 | parser.add_argument('--dim', type=int, default=32, 28 | help='number of dimensions of representation space') 29 | parser.add_argument('--use-conv', action='store_true', help='Use conv layers') 30 | parser.add_argument('--discrete', action='store_true', help='discrete action space') 31 | parser.add_argument('--weight-decay', type=float, default=0.0001) 32 | # TODO:// finish implementation for discrete action spaces. 33 | 34 | # Environment settings 35 | parser.add_argument('--seed', type=int, default=1, 36 | help='random seed (default: 1)') 37 | parser.add_argument('--num-processes', type=int, default=40, 38 | help='how many training processes to use (default: 40)') 39 | parser.add_argument('--num-steps', type=int, default=200, 40 | help='number of forward steps in A3C (default: 20)') 41 | parser.add_argument('--framework', default='gym', 42 | help='framework of env (default: gym)') 43 | parser.add_argument('--env-name', default='InvertedPendulum-v1', 44 | help='environment to train on (default: InvertedPendulum-v1)') 45 | parser.add_argument('--maze-id', type=int, default=0) 46 | parser.add_argument('--maze-length', type=int, default=1) 47 | 48 | # Dynamics Module settings 49 | parser.add_argument('--rollout', type=int, default=20, help="rollout for goal") 50 | parser.add_argument('--train-set', type=str, default=None) 51 | parser.add_argument('--train-batch', type=int, default=2500) 52 | parser.add_argument('--test-set', type=str) 53 | parser.add_argument('--test-batch', type=int, default=2500) 54 | parser.add_argument('--train-size', type=int, default=100000) 55 | parser.add_argument('--dec-loss-coef', type=float, default=0.1, 56 | help='decoder loss coefficient (default: 0.1)') 57 | parser.add_argument('--forward-loss-coef', type=float, default=10, 58 | help='forward loss coefficient (default: 10)') 59 | parser.add_argument('--inv-loss-coef', type=float, default=100, 60 | help='inverse loss coefficient (default: 10)') 61 | parser.add_argument('--num-epochs', type=int, default=1000) 62 | parser.add_argument('--batch-size', type=int, default=128) 63 | parser.add_argument('--num-workers', type=int, default=20) 64 | parser.add_argument('--out', type=str, default='/checkpoint/amyzhang/ddr/models') 65 | parser.add_argument('--dec-mask', type=float, default = None, 66 | help="to use masking while calculating the decoder reconstruction loss ") 67 | 68 | # Rewards Module settings 69 | parser.add_argument('--coef-inner-rew', type=float, default=1.) 70 | parser.add_argument('--checkpoint-interval', type=int, default=1000) 71 | parser.add_argument('--num-episodes', type=int, default=1000000, 72 | help='max number of episodes to train') 73 | parser.add_argument('--max-episode-length', type=int, default=500, 74 | help='maximum length of an episode (default: 500)') 75 | parser.add_argument('--curriculum', type=int, default=0, 76 | help='number of iterations in curriculum. (default: 0, no curriculum)') 77 | parser.add_argument('--single-env', action='store_true') 78 | parser.add_argument('--entropy-coef', type=float, default=0., 79 | help='entropy term coefficient (default: 0.), use 0.0001 for mujoco') 80 | parser.add_argument('--value-loss-coef', type=float, default=0.5, 81 | help='value loss coefficient (default: 0.5)') 82 | parser.add_argument('--rew-loss-coef', type=float, default=0, 83 | help='reward loss coefficient (default: 0)') 84 | parser.add_argument('--lstm-dim', type=int, default=128, 85 | help='number of dimensions of lstm hidden state') 86 | parser.add_argument('--difficulty', type=int, default=-1, help='difficulty of maze') 87 | parser.add_argument('--clip-reward', action='store_true') 88 | parser.add_argument('--finetune-enc', action='store_true', 89 | help="allow the ActorCritic to change the observation space representation") 90 | parser.add_argument('--gae', action='store_true') 91 | parser.add_argument('--algo', default='a3c', 92 | help='algorithm to use: a3c') 93 | 94 | # General training settings 95 | parser.add_argument('--checkpoint', type=int, default=10000) 96 | parser.add_argument('--log-interval', type=int, default=100, 97 | help='interval between training status logs (default: 100)') 98 | parser.add_argument('-v', action='store_true', help='verbose logging') 99 | parser.add_argument('--gpu', action='store_true') 100 | parser.add_argument('--log-dir', type=str, default='/checkpoint/amyzhang/ddr/logs', 101 | help='The logging directory to record the logs and tensorboard summaries') 102 | parser.add_argument('--reset-dir', action='store_true', 103 | help="give this argument to delete the existing logs for the current set of parameters") 104 | 105 | # transfer 106 | parser.add_argument('--file-path', type=str, default=None, 107 | help='path to XML file for mujoco') 108 | parser.add_argument('--neg-reward', action='store_true', 109 | help='set reward negative for transfer') 110 | parser.add_argument('--random-start', action='store_true') 111 | 112 | # What to run 113 | parser.add_argument('--train-dynamics', action='store_true') 114 | parser.add_argument('--train-reward', action='store_true') 115 | parser.add_argument('--train-online', action='store_true', 116 | help='train both modules online') 117 | parser.add_argument('--dynamics-module', type=str, default=None, 118 | help='Encoder from dynamics module') 119 | parser.add_argument('--from-checkpoint', type=str, default=None, 120 | help='Start from stored model') 121 | parser.add_argument('--baseline', action='store_true', 122 | help='Running A3C baseline.') 123 | parser.add_argument('--planning', action='store_true', 124 | help='train with planning (reward and online only)') 125 | parser.add_argument('--transfer', action='store_true', 126 | help='Keep encoder and decoder static') 127 | parser.add_argument('--eval-every', type=float, default=10) 128 | parser.add_argument('--enc-dims', type=int, nargs='+', default=[256, 128]) 129 | parser.add_argument('--dec-dims', type=int, nargs='+', default=[128, 256]) 130 | parser.add_argument('--num-runs', type=int, default=5, 131 | help='number of models to train in parallel') 132 | parser.add_argument('--mcts', action='store_true', help='Monte Carlo Tree Search') 133 | parser.add_argument('--render', action='store_true') 134 | parser.add_argument('-b', type=int, default=4, help='branching factor') 135 | parser.add_argument('-d', type=int, default=3, help='planning depth') 136 | parser.add_argument('--eval', action='store_true') 137 | parser.add_argument('--local', action='store_true') 138 | 139 | args = parser.parse_args() 140 | return args 141 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import math 8 | import sys 9 | from datetime import datetime 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | from model import Encoder, D_Module 15 | 16 | pi = Variable(torch.FloatTensor([math.pi])) 17 | def get_prob(x, mu, sigma_sq): 18 | a = (-1*(Variable(x)-mu).pow(2)/(2*sigma_sq + 1e-5)).exp() 19 | b = 1/(2*sigma_sq*pi.expand_as(sigma_sq) + 1e-5).sqrt() 20 | return a*b 21 | 22 | 23 | def log(msg): 24 | print("[%s]\t%s" % (datetime.now().strftime("%Y-%m-%d %H:%M:%S"), msg)) 25 | sys.stdout.flush() 26 | 27 | 28 | def vlog(msg, v): 29 | if v: 30 | log(msg) 31 | 32 | 33 | def load_encoder(obs_space, args, freeze=True): 34 | enc = Encoder(obs_space, args.dim, 35 | use_conv=args.use_conv) 36 | enc_state = torch.load(args.dynamics_module, map_location=lambda storage, 37 | loc: storage)['enc'] 38 | enc.load_state_dict(enc_state) 39 | enc.eval() 40 | if freeze: 41 | for p in enc.parameters(): 42 | p.requires_grad = False 43 | return enc 44 | 45 | 46 | def load_d_module(action_space, args, freeze=True): 47 | d_module_state = torch.load(args.dynamics_module, map_location=lambda storage, 48 | loc: storage)['d_module'] 49 | d_module = D_Module(action_space, args.dim, args.discrete) 50 | d_module.load_state_dict(d_module_state) 51 | d_module.eval() 52 | if freeze: 53 | for p in d_module.parameters(): 54 | p.requires_grad = False 55 | return d_module 56 | 57 | 58 | def get_action(logit, discrete, v=False): 59 | """Compute action, entropy, and log prob for discrete and continuous case 60 | from logit. 61 | """ 62 | if discrete: 63 | prob = F.softmax(logit) 64 | log_prob = F.log_softmax(logit) 65 | # why entropy regularization ? 66 | entropy = -(log_prob * prob).sum(1, keepdim=True) 67 | action = prob.multinomial() 68 | log_prob = log_prob.gather(1, action) 69 | else: 70 | mu, sigma_sq = logit 71 | sigma_sq = F.softplus(sigma_sq) 72 | vlog('sigma_sq: %s' % str(sigma_sq.data), v) 73 | action = torch.normal(mu, sigma_sq) 74 | prob = get_prob(action.data, mu, sigma_sq) + 1e-5 75 | entropy = -0.5*((2 * sigma_sq * pi.expand_as(sigma_sq) + 1e-5).log() + 1) 76 | log_prob = prob.log() 77 | return action, entropy, log_prob 78 | 79 | 80 | def eval_action(logit, action, discrete, v=False): 81 | mu, sigma_sq = logit 82 | sigma_sq = F.softplus(sigma_sq) 83 | vlog('sigma_sq: %s' % str(sigma_sq.data), v) 84 | prob = get_prob(action.data, mu, sigma_sq) + 1e-5 85 | entropy = -0.5*((2 * sigma_sq * pi.expand_as(sigma_sq) + 1e-5).log() + 1) 86 | log_prob = prob.log() 87 | return entropy, log_prob 88 | 89 | 90 | def mcts(env, z_hat, r_module, d_module, enc, r_state, d_state, args, discrete, 91 | use_env=False): 92 | import torch 93 | import torch.nn.functional as F 94 | from torch.autograd import Variable 95 | 96 | from common import get_action 97 | from envs import get_obs 98 | 99 | (hx_r, cx_r) = r_state 100 | (hx_d, cx_d) = d_state 101 | parent_states = [(z_hat, [], (hx_r, cx_r), (hx_d, cx_d), [], [], [])] 102 | child_states = [] 103 | init_state = get_obs(env, args.framework) 104 | for i in range(args.d): 105 | actions = [] 106 | best_val = None 107 | for z_hat, trajectory, (hx_r, cx_r), (hx_d, cx_d), val, entropies, \ 108 | logprobs in parent_states: 109 | if best_val is None: 110 | best_val = val 111 | elif val < best_val: 112 | continue 113 | value, logit, (hx_r_prime, cx_r_prime) = r_module( 114 | (z_hat, (hx_r, cx_r))) 115 | val.append(value) 116 | if not discrete: 117 | for b in range(args.b): 118 | action, entropy, log_prob = get_action( 119 | logit, discrete=False, v=args.v) 120 | actions.append((action, entropy, log_prob)) 121 | else: 122 | prob = F.softmax(logit) 123 | actions = np.argpartition(prob.data.numpy(), args.b)[:b] 124 | for a, e, lp in actions: 125 | if not use_env: 126 | z_prime_hat, _, (hx_d_prime, cx_d_prime) = d_module( 127 | (z_hat, z_hat, a, (hx_d, cx_d))) 128 | else: 129 | state = get_obs(env, args.framework) 130 | for t in trajectory: 131 | env.step(t.data.numpy()) 132 | s_prime, _, _, _ = env.step(a.data.numpy()) 133 | s_prime = Variable(torch.from_numpy(s_prime).float()) 134 | z_prime_hat = enc(s_prime).unsqueeze(0) 135 | env.reset(state) 136 | hx_d_prime, cx_d_prime = hx_d, cx_d 137 | child_states.append( 138 | (z_prime_hat, trajectory + [a], (hx_r_prime, cx_r_prime), 139 | (hx_d_prime, cx_d_prime), val, entropies + [e], logprobs + [lp])) 140 | child_states = prune(child_states, b) 141 | parent_states = child_states 142 | child_states = [] 143 | 144 | # compute value of final state in each trajectory and choose best 145 | best_val = sum(parent_states[0][4]).data[0,0] 146 | best_ind = 0 147 | for ind, (z, traj, hr, hd, v, _, _) in enumerate(parent_states): 148 | vr, _, _ = r_module((z, hr)) 149 | v.append(vr) 150 | if sum(v).data[0,0] > best_val: 151 | best_ind = ind 152 | best_val = sum(v).data[0,0] 153 | return parent_states[best_ind] 154 | -------------------------------------------------------------------------------- /envs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import numpy as np 8 | 9 | import gym 10 | from gym.spaces.box import Box 11 | from rllab.envs.mujoco.swimmer_env import SwimmerEnv 12 | from rllab.envs.mujoco.ant_env import AntEnv 13 | from rllab.envs.mujoco.half_cheetah_env import HalfCheetahEnv 14 | from rllab.envs.mujoco.hopper_env import HopperEnv 15 | from rllab.envs.mujoco.humanoid_env import HumanoidEnv 16 | from rllab.envs.mujoco.simple_humanoid_env import SimpleHumanoidEnv 17 | from rllab.envs.mujoco.maze.point_maze_env import PointMazeEnv 18 | from rllab.envs.mujoco.maze.swimmer_maze_env import SwimmerMazeEnv 19 | from rllab.envs.mujoco.maze.ant_maze_env import AntMazeEnv 20 | from rllab.envs.mujoco.inverted_double_pendulum_env import InvertedDoublePendulumEnv 21 | from rllab.misc import ext 22 | from rllab.envs.normalized_env import normalize 23 | 24 | from common import * 25 | 26 | 27 | def create_env(env_str, framework='gym', args=None, eval_flag=False, norm=True, 28 | rank=0): 29 | if framework == 'gym': 30 | env = gym.make(env_str) 31 | if norm: 32 | env = NormalizedEnv(env) 33 | elif framework == 'rllab': 34 | if not hasattr(args, 'file_path'): 35 | args.file_path = None 36 | if env_str.endswith('MazeEnv'): 37 | if not hasattr(args, 'coef_inner_rew'): 38 | args.coef_inner_rew = 0. 39 | if not hasattr(args, 'maze_structure'): 40 | args.maze_structure = None 41 | if not hasattr(args, 'random_start'): 42 | args.random_start = False 43 | if not hasattr(args, 'difficulty'): 44 | args.difficulty = -1 45 | difficulty = args.difficulty 46 | if args.difficulty > 1 and not eval_flag: 47 | if args.difficulty <= 5: 48 | difficulty = np.random.choice(range( 49 | args.difficulty - 1, args.difficulty + 1)) 50 | elif args.difficulty == -1: 51 | difficulty = np.random.choice([1, 2, 3, 4, 5, -1]) 52 | env = eval(env_str)(maze_id=args.maze_id, length=args.maze_length, 53 | coef_inner_rew=args.coef_inner_rew, 54 | structure=args.maze_structure, 55 | file_path=args.file_path, 56 | random_start=args.random_start, 57 | difficulty=difficulty) 58 | env.horizon = args.max_episode_length 59 | vlog(args.maze_structure, args.v) 60 | else: 61 | env = eval(env_str)(file_path=args.file_path) 62 | if norm: 63 | env = normalize(env) 64 | else: 65 | raise("framework not supported") 66 | env.reset() 67 | set_seed(args.seed + rank, env, framework) 68 | return env 69 | 70 | 71 | def wrapper(env): 72 | def _wrap(): 73 | return env 74 | return _wrap 75 | 76 | 77 | def get_obs(env, framework): 78 | if framework == 'gym': 79 | state = env.unwrapped._get_obs() 80 | elif framework == 'rllab': 81 | state = env.get_current_obs() 82 | else: 83 | raise("framework not supported") 84 | return state 85 | 86 | 87 | def set_seed(seed, env, framework): 88 | if framework == 'gym': 89 | env.unwrapped.seed(seed) 90 | elif framework == 'rllab': 91 | ext.set_seed(seed) 92 | else: 93 | raise("framework not supported") 94 | return env 95 | 96 | 97 | def reset_env(env, args): 98 | """Reset env. Can differ based on env. e.g. in maze maybe we want to randomly 99 | deposit the agent in different locations?""" 100 | env.reset() 101 | return get_obs(env, args.framework) 102 | 103 | 104 | class NormalizedEnv(gym.ObservationWrapper): 105 | def __init__(self, env=None): 106 | super(NormalizedEnv, self).__init__(env) 107 | self.state_mean = 0 108 | self.state_std = 0 109 | self.alpha = 0.9999 110 | self.num_steps = 0 111 | 112 | def _observation(self, observation): 113 | self.num_steps += 1 114 | self.state_mean = self.state_mean * self.alpha + \ 115 | observation.mean() * (1 - self.alpha) 116 | self.state_std = self.state_std * self.alpha + \ 117 | observation.std() * (1 - self.alpha) 118 | 119 | unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps)) 120 | unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps)) 121 | 122 | return (observation - unbiased_mean) / (unbiased_std + 1e-8) 123 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import numpy as np 11 | import os 12 | import random 13 | from operator import itemgetter 14 | 15 | # Environment settings 16 | parser = argparse.ArgumentParser(description='Eval DDR') 17 | 18 | parser.add_argument('--dynamics-module', type=str, default=None, 19 | help='Dynamics module') 20 | parser.add_argument('--rewards-module', type=str, default=None, 21 | help='Rewards module') 22 | parser.add_argument('--num-processes', type=int, default=20, 23 | help='how many training processes to use (default: 20)') 24 | parser.add_argument('--N', type=int, default=1, 25 | help='Number of episodes') 26 | parser.add_argument('--rollout', type=int, default=20, help="rollout for goal") 27 | parser.add_argument('--seed', type=int, default=1, 28 | help='random seed (default: 1)') 29 | parser.add_argument('--render', action='store_true') 30 | parser.add_argument('--out', type=str, default=None) 31 | parser.add_argument('--max-episode-length', type=int, default=1000, 32 | help='maximum length of an episode') 33 | parser.add_argument('--framework', default='gym', 34 | help='framework of env (default: gym)') 35 | parser.add_argument('--env-name', default='InvertedPendulum-v1', 36 | help='environment to train on (default: InvertedPendulum-v1)') 37 | parser.add_argument('--maze-id', type=int, default=0) 38 | parser.add_argument('--maze-length', type=int, default=1) 39 | parser.add_argument('--log-interval', type=int, default=1) 40 | parser.add_argument('--baseline', action='store_true') 41 | parser.add_argument('--local', action='store_true', 42 | help='running locally to render, no multiprocessing') 43 | parser.add_argument('--single-env', action='store_true') 44 | parser.add_argument('--coef-inner-rew', type=float, default=1.) 45 | parser.add_argument('--mcts', action='store_true', help='Monte Carlo Tree Search') 46 | parser.add_argument('-b', type=int, default=4, help='branching factor') 47 | parser.add_argument('-d', type=int, default=3, help='planning depth') 48 | parser.add_argument('--file-path', type=str, default=None, 49 | help='path to XML file for mujoco') 50 | parser.add_argument('--save-figs', action='store_true') 51 | parser.add_argument('--neg-reward', action='store_true', 52 | help='set reward negative for transfer') 53 | parser.add_argument('--use-env', action='store_true', help='Use env with MCTS') 54 | parser.add_argument('-v', action='store_true', help='verbose logging') 55 | parser.add_argument('--difficulty', type=int, default=-1, help='difficulty of maze') 56 | 57 | 58 | def prune(states, b): 59 | """Prune states down to length b, sorting by val.""" 60 | return sorted(states, key=itemgetter(4))[:b] 61 | 62 | 63 | def test(block, args, d_args, r_args, d_module, r_module, enc, dec, q=None, rank=0): 64 | import torch 65 | from torch.autograd import Variable 66 | 67 | from envs import create_env, reset_env, get_obs 68 | from common import get_action, log 69 | 70 | seed = args.seed * 9823 + 194885 + rank # make sure doesn't copy train 71 | torch.manual_seed(seed) 72 | np.random.seed(seed) 73 | random.seed(seed) 74 | i = 1 75 | total_acc, total_reward = [], [] 76 | avg_succ, avg_reward, avg_len = 0, 0, 0 77 | while len(total_acc) < block: 78 | reward_sum, succ = 0, 0 79 | actions = [] 80 | if args.single_env and i > 1: 81 | reset_env(env, args) 82 | else: 83 | env = create_env(args.env_name, framework=args.framework, args=args, eval_flag=True) 84 | done = False 85 | step = 0 86 | 87 | # Should the two LSTMs share a hidden state? 88 | cx_r = Variable(torch.zeros(1, r_args.dim)) 89 | hx_r = Variable(torch.zeros(1, r_args.dim)) 90 | if not args.baseline: 91 | cx_d = Variable(torch.zeros(1, d_args.dim)) 92 | hx_d = Variable(torch.zeros(1, d_args.dim)) 93 | while step < args.max_episode_length and not done: 94 | # Encode state 95 | state = get_obs(env, r_args.framework) 96 | state = Variable(torch.from_numpy(state).float()) 97 | if not args.baseline: 98 | z = enc(state) 99 | z_prime_hat = z.unsqueeze(0) 100 | else: 101 | z_prime_hat = state.unsqueeze(0) 102 | actions = [] 103 | if args.mcts: 104 | z_prime_hat, actions, (hx_r, cx_r), (hx_d, cx_d), _, _, _ = mcts( 105 | env, z_prime_hat, r_module, d_module, enc, (hx_r, cx_r), 106 | (hx_d, cx_d), args, discrete=r_args.discrete, 107 | use_env=args.use_env) 108 | for r in range(args.rollout - args.d): 109 | value, logit, (hx_r, cx_r) = r_module( 110 | (z_prime_hat, (hx_r, cx_r))) 111 | action, entropy, log_prob = get_action( 112 | logit, discrete=r_args.discrete) 113 | actions.append(action) 114 | if not args.baseline: 115 | z_prime_hat, _, (hx_d, cx_d) = d_module( 116 | (z_prime_hat, z_prime_hat, action, (hx_d, cx_d))) 117 | if args.save_figs: 118 | s_prime_hat = dec(z_prime_hat) 119 | 120 | for action in actions[:args.rollout]: 121 | _, reward, done, _ = env.step(action.data.numpy()) 122 | if args.render: 123 | env.render() 124 | reward_sum += reward 125 | step += 1 126 | if done: 127 | succ = 1 128 | break 129 | U = 1. / i 130 | total_acc.append(succ) 131 | total_reward.append(reward_sum) 132 | avg_succ = avg_succ * (1 - U) + succ * U 133 | avg_reward = avg_reward * (1 - U) + reward_sum * U 134 | avg_len = avg_len * (1 - U) + (step + 1) * U 135 | if i % args.log_interval == 0: 136 | log("Eval: {:d} episodes, avg succ {:.2f}, avg reward {:.2f}, avg length {:.2f}".format( 137 | len(total_acc), avg_succ, reward_sum, step)) 138 | i += 1 139 | if args.local: 140 | return (sum(total_acc), len(total_acc), sum(total_reward), avg_len) 141 | q.put((sum(total_acc), len(total_acc), sum(total_reward))) 142 | 143 | 144 | if __name__ == '__main__': 145 | import torch 146 | import torch.multiprocessing as mp 147 | mp.set_start_method('spawn') 148 | 149 | from envs import * 150 | from model import * 151 | from common import * 152 | # from ppo.model import MLPPolicy 153 | 154 | os.environ['OMP_NUM_THREADS'] = '1' 155 | os.environ['CUDA_VISIBLE_DEVICES'] = "" 156 | 157 | args = parser.parse_args() 158 | if not args.mcts: 159 | args.d = 0 160 | log(args) 161 | torch.manual_seed(args.seed) 162 | 163 | d_args, d_module, enc, dec = None, None, None, None 164 | r_state_dict, r_args = torch.load(args.rewards_module, map_location=lambda storage, loc: storage) 165 | if args.single_env and hasattr(r_args, 'maze_structure'): 166 | args.maze_structure = r_args.maze_structure 167 | env = create_env(args.env_name, framework=args.framework, args=args, eval_flag=True) 168 | r_module = R_Module(env.action_space.shape[0], r_args.dim, 169 | discrete=r_args.discrete, baseline=r_args.baseline, 170 | state_space=env.observation_space.shape[0]) 171 | r_module.load_state_dict(r_state_dict) 172 | r_module.eval() 173 | if not args.baseline: 174 | if args.local: 175 | r_args.dynamics_module = '/Users/amyzhang/ddr_for_tl' + r_args.dynamics_module[24:] 176 | if args.dynamics_module is None: 177 | d_dict = torch.load(r_args.dynamics_module, map_location=lambda storage, loc: storage) 178 | else: 179 | d_dict = torch.load(args.dynamics_module, map_location=lambda storage, loc: storage) 180 | d_args = d_dict['args'] 181 | enc_state = d_dict['enc'] 182 | dec_state = d_dict['dec'] 183 | d_state_dict = d_dict['d_module'] 184 | d_module = D_Module(env.action_space.shape[0], d_args.dim, d_args.discrete) 185 | d_module.load_state_dict(d_state_dict) 186 | d_module.eval() 187 | 188 | enc = Encoder(env.observation_space.shape[0], d_args.dim, 189 | use_conv=d_args.use_conv) 190 | dec = Decoder(env.observation_space.shape[0], d_args.dim, 191 | use_conv=d_args.use_conv) 192 | enc.load_state_dict(enc_state) 193 | dec.load_state_dict(dec_state) 194 | enc.eval() 195 | dec.eval() 196 | 197 | block = int(args.N / args.num_processes) 198 | if args.local: 199 | all_succ, all_total, avg_reward = test( 200 | block, args, d_args, r_args, d_module, r_module, enc, dec) 201 | else: 202 | processes = [] 203 | queues = [] 204 | for rank in range(0, args.num_processes): 205 | q = mp.Queue() 206 | p = mp.Process(target=test, args=( 207 | block, args, d_args, r_args, d_module, r_module, enc, dec, q, rank)) 208 | p.start() 209 | processes.append(p) 210 | queues.append(q) 211 | 212 | for i, p in enumerate(processes): 213 | log("Exit process %d" % i) 214 | p.join() 215 | 216 | all_succ = 0 217 | all_total = 0 218 | total_reward = 0 219 | for q in queues: 220 | while not q.empty(): 221 | succ, total, total_r = q.get() 222 | all_succ += succ 223 | all_total += total 224 | total_reward += total_r 225 | log("Success: %s, %s, %s" % (all_succ / all_total, all_succ, all_total)) 226 | log("Average Reward: %s" % (total_reward / all_total)) 227 | if args.out: 228 | with open(args.out, 'a') as f: 229 | f.write("Success: %s \n" % (all_succ / all_total)) 230 | -------------------------------------------------------------------------------- /eval_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import math 8 | import numpy as np 9 | import os 10 | import time 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from torch.autograd import Variable 16 | 17 | from envs import * 18 | from model import Encoder, Decoder, D_Module, R_Module 19 | from common import * 20 | from tensorboardX import SummaryWriter 21 | from itertools import chain 22 | from eval import test 23 | 24 | 25 | def eval_reward(args, shared_model, writer_dir=None): 26 | """ 27 | For evaluation 28 | 29 | Arguments: 30 | - writer: the tensorboard summary writer directory (note: can't get it working directly with the SummaryWriter object) 31 | """ 32 | writer = SummaryWriter(log_dir=os.path.join(writer_dir,'eval')) if writer_dir is not None else None 33 | 34 | # current episode stats 35 | episode_reward = episode_value_mse = episode_td_error = episode_pg_loss = episode_length = 0 36 | 37 | # global stats 38 | i_episode = 0 39 | total_episode = total_steps = 0 40 | num_goals_achieved = 0 41 | 42 | # intilialize the env and models 43 | torch.manual_seed(args.seed) 44 | env = create_env(args.env_name, framework=args.framework, args=args) 45 | set_seed(args.seed , env, args.framework) 46 | 47 | shared_enc, shared_dec, shared_d_module, shared_r_module = shared_model 48 | 49 | enc = Encoder(env.observation_space.shape[0], args.dim, 50 | use_conv=args.use_conv) 51 | dec = Decoder(env.observation_space.shape[0], args.dim, 52 | use_conv=args.use_conv) 53 | d_module = D_Module(env.action_space.shape[0], args.dim, args.discrete) 54 | r_module = R_Module(env.action_space.shape[0], args.dim, 55 | discrete=args.discrete, baseline=False, 56 | state_space=env.observation_space.shape[0]) 57 | 58 | 59 | all_params = chain(enc.parameters(), dec.parameters(), 60 | d_module.parameters(), 61 | r_module.parameters()) 62 | 63 | if args.from_checkpoint is not None: 64 | model_state, _ = torch.load(args.from_checkpoint) 65 | model.load_state_dict(model_state) 66 | 67 | # set the model to evaluation mode 68 | enc.eval() 69 | dec.eval() 70 | d_module.eval() 71 | r_module.eval() 72 | 73 | # reset the state 74 | state = env.reset() 75 | state = Variable(torch.from_numpy(state).float()) 76 | 77 | start = time.time() 78 | 79 | while total_episode < args.num_episodes: 80 | 81 | # Sync with the shared model 82 | r_module.load_state_dict(shared_r_module.state_dict()) 83 | d_module.load_state_dict(shared_d_module.state_dict()) 84 | enc.load_state_dict(shared_enc.state_dict()) 85 | dec.load_state_dict(shared_dec.state_dict()) 86 | 87 | # reset stuff 88 | cd_p = Variable(torch.zeros(1, args.lstm_dim)) 89 | hd_p = Variable(torch.zeros(1, args.lstm_dim)) 90 | 91 | # for the reward 92 | cr_p = Variable(torch.zeros(1, args.lstm_dim)) 93 | hr_p = Variable(torch.zeros(1, args.lstm_dim)) 94 | 95 | i_episode += 1 96 | episode_length = 0 97 | episode_reward = 0 98 | args.local = True 99 | args.d = 0 100 | succ, _, episode_reward, episode_length = test( 101 | 1, args, args, args, d_module, r_module, enc) 102 | log("Eval: succ {:.2f}, reward {:.2f}, length {:.2f}".format( 103 | succ, episode_reward, episode_length)) 104 | # Episode has ended, write the summaries here 105 | if writer_dir is not None: 106 | # current episode stats 107 | writer.add_scalar('eval/episode_reward', episode_reward, i_episode) 108 | writer.add_scalar('eval/episode_length', episode_length, i_episode) 109 | writer.add_scalar('eval/success', succ, i_episode) 110 | 111 | time.sleep(args.eval_every) 112 | print("sleep") 113 | -------------------------------------------------------------------------------- /generate_dynamics_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import argparse 8 | import os 9 | 10 | parser = argparse.ArgumentParser(description='Generate Data') 11 | parser.add_argument('--env-name', default='InvertedPendulum-v1', 12 | help='environment to train on (default: InvertedPendulum-v1)') 13 | parser.add_argument('--N', type=int, default=1000000) 14 | parser.add_argument('--seed', type=int, default=1, 15 | help='random seed (default: 1)') 16 | parser.add_argument('--out', type=str, default='/data/ddr') 17 | parser.add_argument('--num-processes', type=int, default=40, 18 | help='how many training processes to use (default: 40)') 19 | parser.add_argument('--rollout', type=int, default=20, help="rollout for goal") 20 | parser.add_argument('--method', type=str, default='random', 21 | help='["random", "pixel_control"]') 22 | parser.add_argument('--render', action='store_true') 23 | parser.add_argument('--reset', action='store_true') 24 | parser.add_argument('--from-policy', type=str, default=None, 25 | help="use reward module as policy") 26 | parser.add_argument('--framework', default='gym', 27 | help='framework of env (default: gym)') 28 | parser.add_argument('--maze-id', type=int, default=0) 29 | parser.add_argument('--maze-length', type=int, default=1) 30 | parser.add_argument('--single-env', action='store_true') 31 | parser.add_argument('--random-start', action='store_true') 32 | parser.add_argument('-v', action='store_true', help='verbose logging') 33 | parser.add_argument('--max-episode-length', type=int, default=500, 34 | help='maximum length of an episode (default: 500)') 35 | parser.add_argument('--file-path', type=str, default=None, 36 | help='path to XML file for mujoco') 37 | 38 | 39 | def generate_data(rank, args, start, end): 40 | 41 | from envs import create_env, set_seed, get_obs 42 | from model import R_Module 43 | import torch 44 | 45 | print(rank, "started") 46 | 47 | env = create_env(args.env_name, framework=args.framework, args=args) 48 | env = set_seed(args.seed + rank, env, args.framework) 49 | state = get_obs(env, args.framework) 50 | 51 | if args.from_policy is not None: 52 | model_state, r_args = torch.load(args.from_policy) 53 | policy = R_Module(env.action_space.shape[0], 54 | r_args.dim, 55 | discrete=r_args.discrete, baseline=r_args.baseline, 56 | state_space=env.observation_space.shape[0]) 57 | policy.load_state_dict(model_state) 58 | policy.eval() 59 | 60 | 61 | states = [] 62 | actions = [] 63 | i = start 64 | 65 | done = False 66 | 67 | while i < end: 68 | if i % 100 == 0: 69 | print(rank, i) 70 | ep_states = [] 71 | ep_actions = [] 72 | if args.from_policy is not None: 73 | cx_p = Variable(torch.zeros(1, r_args.dim)) 74 | hx_p = Variable(torch.zeros(1, r_args.dim)) 75 | for j in range(args.rollout): 76 | if args.from_policy is not None: 77 | value, logit, (hx_p, cx_p) = policy( 78 | state.unsqueeze(0), (hx_p, cx_p)) 79 | a, _, _ = get_action(logit, r_args.discrete) 80 | else: 81 | a = env.action_space.sample() 82 | ep_actions.append(a) 83 | 84 | state = get_obs(env, args.framework) 85 | env.step(a) 86 | 87 | if args.render: 88 | env.render() 89 | 90 | ep_states.append(state) 91 | 92 | final_state = get_obs(env, args.framework) 93 | ep_states.append(final_state) 94 | states.append(ep_states) 95 | actions.append(ep_actions) 96 | i += 1 97 | 98 | # reset the environment here 99 | if done or args.reset: 100 | env.reset() 101 | done = False 102 | 103 | torch.save((states, actions), os.path.join( 104 | args.out_dir, 'states_actions_%s_%s.pt' % (start, end))) 105 | 106 | 107 | 108 | if __name__ == '__main__': 109 | import torch 110 | import torch.multiprocessing as mp 111 | mp.set_start_method('spawn') 112 | 113 | from torch.autograd import Variable 114 | from envs import create_env, set_seed, get_obs 115 | from model import R_Module 116 | os.environ['OMP_NUM_THREADS'] = '1' 117 | 118 | args = parser.parse_args() 119 | env_name = args.env_name 120 | env_name += '_rollout%s' % args.rollout 121 | if args.env_name.endswith('MazeEnv'): 122 | env_name += 'mazeid%slength%s' % (args.maze_id, args.maze_length) 123 | if args.single_env and args.maze_id == -1: 124 | env = create_env(args.env_name, framework=args.framework, args=args) 125 | env_name += '_single_env' 126 | args.maze_structure = env._env.MAZE_STRUCTURE 127 | if args.random_start: 128 | env_name += '_randomstart' 129 | if args.file_path is not None: 130 | env_name += '_transfer' 131 | if args.framework == 'mazebase': 132 | env_name += '_rollout_%s_length_%s' % (args.rollout, args.maze_length) 133 | args.out_dir = os.path.join(args.out, env_name) 134 | print(args) 135 | print(args.out_dir) 136 | os.makedirs(args.out_dir, exist_ok=True) 137 | processes = [] 138 | block = int(args.N / args.num_processes) 139 | for rank in range(0, args.num_processes): 140 | start = rank * block 141 | end = (rank + 1) * block 142 | p = mp.Process(target=generate_data, args=(rank, args, start, end)) 143 | p.start() 144 | processes.append(p) 145 | 146 | torch.save(args, os.path.join(args.out_dir, 'args.pt')) 147 | 148 | # exit cleanly 149 | for p in processes: 150 | p.join() 151 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from __future__ import print_function 8 | 9 | import datetime 10 | import os 11 | import time 12 | import shutil 13 | from itertools import chain 14 | import dill 15 | 16 | from arguments import get_args 17 | 18 | 19 | if __name__ == '__main__': 20 | import torch 21 | import torch.multiprocessing as mp 22 | mp.set_start_method('spawn') 23 | 24 | import my_optim 25 | from envs import create_env 26 | from model import * 27 | from test import test 28 | from train_reward_module import train_rewards 29 | from common import * 30 | from train_dynamics_module import train_dynamics 31 | from train_online import train_online 32 | from eval_modules import eval_reward 33 | from tensorboardX import SummaryWriter 34 | 35 | os.environ['OMP_NUM_THREADS'] = '1' 36 | args = get_args() 37 | log(args) 38 | 39 | if not args.gpu: 40 | os.environ['CUDA_VISIBLE_DEVICES'] = "" 41 | 42 | torch.manual_seed(args.seed) 43 | 44 | args_param = vars(args) 45 | toprint = ['seed', 'lr', 'entropy_coef', 'value_loss_coef', 'num_steps', 46 | 'dim'] 47 | 48 | if args.planning: 49 | toprint += ['rollout'] 50 | 51 | env_name = args.env_name 52 | if args.env_name.endswith("MazeEnv"): 53 | env_name += 'mazeid%slength%s' % (args.maze_id, args.maze_length) 54 | toprint += ['random_start', 'difficulty'] 55 | if args.baseline: 56 | model_type = 'baseline' 57 | if args.neg_reward: 58 | model_type += '_neg_reward' 59 | if args.file_path: 60 | model_type += '_dynamics_transfer' 61 | toprint += ['algo', 'gae', 'num_processes'] 62 | elif args.train_dynamics: 63 | model_type = 'dynamics_planning' 64 | toprint = ['lr', 'forward_loss_coef', 'dec_loss_coef', 'inv_loss_coef', 'rollout', 'dim', 65 | 'train_size'] 66 | # env_name = os.path.basename(args.train_set.strip('/')) 67 | if args.single_env: 68 | data_args = torch.load(os.path.join(args.train_set, 'args.pt')) 69 | args.maze_structure = data_args.maze_structure 70 | elif args.train_reward: 71 | model_type = 'reward' 72 | if args.neg_reward: 73 | model_type += '_neg_reward' 74 | if args.file_path: 75 | model_type += '_dynamics_transfer' 76 | toprint += ['algo', 'gae'] 77 | if args.planning: 78 | model_type += '_planning' 79 | elif args.train_online: 80 | model_type = 'online' 81 | toprint += ['lr', 'dec_loss_coef', 'inv_loss_coef', 'rollout', 'dim'] 82 | if args.transfer: 83 | model_type += '_transfer' 84 | 85 | name = '' 86 | for arg in toprint: 87 | name += '_{}{}'.format(arg, args_param[arg]) 88 | out_dir = os.path.join(args.out, env_name, model_type, name) 89 | args.out = out_dir 90 | 91 | dynamics_path = '' 92 | if args.dynamics_module is not None and not args.baseline: 93 | dynamics_path = args.dynamics_module.split('/') 94 | dynamics_path = dynamics_path[-4] + dynamics_path[-2] +\ 95 | '_' + dynamics_path[-1].strip('.pt') 96 | args.out = os.path.join(out_dir, dynamics_path) 97 | os.makedirs(args.out, exist_ok=True) 98 | 99 | # create the tensorboard summary writer here 100 | tb_log_dir = os.path.join(args.log_dir, env_name, model_type, name, 101 | dynamics_path, 'tb_logs') 102 | print(tb_log_dir) 103 | print(args.out) 104 | 105 | if args.reset_dir: 106 | shutil.rmtree(tb_log_dir, ignore_errors=True) 107 | os.makedirs(tb_log_dir, exist_ok=True) 108 | tb_writer = SummaryWriter(log_dir=tb_log_dir) 109 | 110 | # dump all the arguments in the tb_log_dir 111 | print(args, file=open(os.path.join(tb_log_dir, "arguments"), "w")) 112 | 113 | 114 | env = create_env(args.env_name, framework=args.framework, args=args) 115 | if args.train_dynamics: 116 | train_dynamics(env, args, None) # tb_writer 117 | if args.train_reward: 118 | model_name = 'rewards_module' 119 | if args.from_checkpoint is not None: # using curriculum 120 | model_name += 'curr' 121 | if args.single_env: 122 | model_name += '_single_env' 123 | args.maze_structure = env._env.MAZE_STRUCTURE 124 | args.model_name = model_name 125 | enc = None 126 | d_module = None 127 | assert args.dynamics_module is not None 128 | enc = load_encoder(env.observation_space.shape[0], args) 129 | if args.planning: 130 | d_module = load_d_module(env.action_space.shape[0], args) 131 | 132 | shared_model = R_Module(env.action_space.shape[0], args.dim, 133 | discrete=args.discrete, baseline=args.baseline, 134 | state_space=env.observation_space.shape[0]) 135 | 136 | # shared reward module for everyone 137 | shared_model.share_memory() 138 | 139 | if args.no_shared: 140 | optimizer = None 141 | else: 142 | optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr) 143 | optimizer.share_memory() 144 | 145 | processes = [] 146 | 147 | train_agent_method = None 148 | 149 | total_args = args 150 | train_agent_method = train_rewards 151 | 152 | for rank in range(0, args.num_processes): 153 | if rank==0: 154 | p = mp.Process(target=train_agent_method, args=( 155 | rank, total_args, shared_model, enc, optimizer, tb_log_dir, 156 | d_module)) 157 | else: 158 | p = mp.Process(target=train_agent_method, args=( 159 | rank, total_args, shared_model, enc, optimizer, None, d_module)) 160 | p.start() 161 | processes.append(p) 162 | 163 | for p in processes: 164 | p.join() 165 | 166 | torch.save((shared_model.state_dict(), args), os.path.join( 167 | args.out, model_name + '%s.pt' % args.num_episodes)) 168 | 169 | print(os.path.join(args.out, model_name)) 170 | if args.train_online: 171 | model_name = 'rewards_module' 172 | if args.from_checkpoint is not None: # using curriculum 173 | model_name += 'curr' 174 | if args.single_env: 175 | model_name += '_single_env' 176 | args.maze_structure = env._env.MAZE_STRUCTURE 177 | args.model_name = model_name 178 | shared_enc = Encoder(env.observation_space.shape[0], args.dim, 179 | use_conv=args.use_conv) 180 | shared_dec = Decoder(env.observation_space.shape[0], args.dim, 181 | use_conv=args.use_conv) 182 | shared_d_module = D_Module(env.action_space.shape[0], args.dim, 183 | args.discrete) 184 | shared_r_module = R_Module(env.action_space.shape[0], args.dim, 185 | discrete=args.discrete, baseline=args.baseline, 186 | state_space=env.observation_space.shape[0]) 187 | 188 | shared_enc = Encoder(env.observation_space.shape[0], args.dim, 189 | use_conv=args.use_conv) 190 | shared_dec = Decoder(env.observation_space.shape[0], args.dim, 191 | use_conv=args.use_conv) 192 | shared_d_module = D_Module(env.action_space.shape[0], args.dim, 193 | args.discrete) 194 | shared_r_module = R_Module(env.action_space.shape[0], args.dim, 195 | discrete=args.discrete, baseline=args.baseline, 196 | state_space=env.observation_space.shape[0]) 197 | 198 | shared_enc.share_memory() 199 | shared_dec.share_memory() 200 | shared_d_module.share_memory() 201 | shared_r_module.share_memory() 202 | all_params = chain(shared_enc.parameters(), shared_dec.parameters(), 203 | shared_d_module.parameters(), 204 | shared_r_module.parameters()) 205 | shared_model = [shared_enc, shared_dec, shared_d_module, shared_r_module] 206 | 207 | if args.single_env: 208 | model_name += '_single_env' 209 | args.maze_structure = env.MAZE_STRUCTURE 210 | 211 | if args.no_shared: 212 | optimizer = None 213 | else: 214 | optimizer = my_optim.SharedAdam(all_params, lr=args.lr) 215 | optimizer.share_memory() 216 | 217 | train_agent_method = train_online 218 | 219 | processes = [] 220 | for rank in range(0, args.num_processes): 221 | if rank==0: 222 | p = mp.Process(target=train_agent_method, args=( 223 | rank, args, shared_model, optimizer, tb_log_dir)) 224 | else: 225 | p = mp.Process(target=train_agent_method, args=( 226 | rank, args, shared_model, optimizer)) 227 | p.start() 228 | processes.append(p) 229 | 230 | # start an eval process here 231 | eval_agent_method = eval_reward 232 | p = mp.Process(target=eval_agent_method, args=( 233 | args, shared_model, tb_log_dir)) 234 | p.start() 235 | processes.append(p) 236 | 237 | for p in processes: 238 | p.join() 239 | results_dict = {'args': args} 240 | torch.save((shared_r_module.state_dict(), args), os.path.join( 241 | args.out, 'reward_module%s.pt' % args.num_episodes)) 242 | results_dict['enc'] = shared_enc.state_dict() 243 | results_dict['dec'] = shared_dec.state_dict() 244 | results_dict['d_module'] = shared_d_module.state_dict() 245 | torch.save(results_dict, 246 | os.path.join(args.out, 'dynamics_module%s.pt' % args.num_episodes)) 247 | log("Saved model %s" % os.path.join(args.out, model_name)) 248 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.init as init 12 | 13 | def normalized_columns_initializer(weights, std=1.0): 14 | out = torch.randn(weights.size()) 15 | out *= std / torch.sqrt(out.pow(2).sum(1, keepdim=True)) 16 | return out 17 | 18 | 19 | def weights_init(m): 20 | classname = m.__class__.__name__ 21 | if classname.find('Conv') != -1: 22 | weight_shape = list(m.weight.data.size()) 23 | fan_in = np.prod(weight_shape[1:4]) 24 | fan_out = np.prod(weight_shape[2:4]) * weight_shape[0] 25 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 26 | m.weight.data.uniform_(-w_bound, w_bound) 27 | m.bias.data.fill_(0) 28 | elif classname.find('Linear') != -1: 29 | weight_shape = list(m.weight.data.size()) 30 | fan_in = weight_shape[1] 31 | fan_out = weight_shape[0] 32 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 33 | m.weight.data.uniform_(-w_bound, w_bound) 34 | m.bias.data.fill_(0) 35 | 36 | 37 | class Encoder(torch.nn.Module): 38 | def __init__(self, obs_space, dim, use_conv=False): 39 | """ 40 | architecture should be input, so that we can pass multiple jobs ! 41 | """ 42 | super(Encoder, self).__init__() 43 | self.use_conv = use_conv 44 | self.obs_space = obs_space 45 | if use_conv: 46 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1) 47 | self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1) 48 | self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1) 49 | self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1) 50 | else: 51 | self.linear1 = nn.Linear(obs_space, dim) 52 | self.linear2 = nn.Linear(dim, 32 * 3 * 3) 53 | self.fc = nn.Linear(32 * 3 * 3, dim) 54 | self.apply(weights_init) 55 | self.train() 56 | 57 | def forward(self, inputs): 58 | # why elu and not relu ? 59 | if self.use_conv: 60 | x = F.elu(self.conv1(inputs)) 61 | x = F.elu(self.conv2(x)) 62 | x = F.elu(self.conv3(x)) 63 | x = F.elu(self.conv4(x)) 64 | else: 65 | x = F.elu(self.linear1(inputs)) 66 | x = F.elu(self.linear2(x)) 67 | 68 | x = F.tanh(self.fc(x)) 69 | 70 | return x 71 | 72 | 73 | class Decoder(torch.nn.Module): 74 | def __init__(self, obs_space, dim, use_conv=False): 75 | super(Decoder, self).__init__() 76 | self.use_conv = use_conv 77 | self.fc = nn.Linear(dim, 32 * 3 * 3) 78 | if self.use_conv: 79 | self.deconv1 = nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1) 80 | self.deconv2 = nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1) 81 | self.deconv3 = nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1) 82 | self.deconv4 = nn.ConvTranspose2d(32, 3, 3, stride=2, padding=1) 83 | else: 84 | self.linear1 = nn.Linear(32 * 3 * 3, dim) 85 | self.linear2 = nn.Linear(dim, obs_space) 86 | self.apply(weights_init) 87 | self.train() 88 | 89 | def forward(self, inputs): 90 | x = F.elu(self.fc(inputs)) 91 | if self.use_conv: 92 | x = F.elu(self.deconv1(x)) 93 | x = F.elu(self.deconv2(x)) 94 | x = F.elu(self.deconv3(x)) 95 | x = self.deconv4(x) 96 | else: 97 | x = F.elu(self.linear1(x)) 98 | x = self.linear2(x) 99 | return x 100 | 101 | 102 | class D_Module(torch.nn.Module): 103 | def __init__(self, action_space, dim, discrete=False): 104 | super(D_Module, self).__init__() 105 | self.dim = dim 106 | self.discrete = discrete 107 | 108 | self.za_embed = nn.Linear(2 * dim, dim) 109 | self.lstm_dynamics = nn.LSTMCell(dim, dim) 110 | self.z_embed = nn.Linear(dim, dim) 111 | 112 | self.inv = nn.Linear(2 * dim, dim) 113 | self.inv2 = nn.Linear(dim, action_space) 114 | 115 | self.action_linear = nn.Linear(action_space, dim) 116 | self.action_linear2 = nn.Linear(dim, dim) 117 | self.apply(weights_init) 118 | 119 | self.lstm_dynamics.bias_ih.data.fill_(0) 120 | self.lstm_dynamics.bias_hh.data.fill_(0) 121 | 122 | self.train() 123 | 124 | def forward(self, inputs): 125 | z, z_prime, actions, (hx_d, cx_d) = inputs 126 | z = z.view(-1, self.dim) 127 | 128 | a_embedding = F.elu(self.action_linear(actions)) 129 | a_embedding = self.action_linear2(a_embedding) 130 | 131 | za_embedding = self.za_embed( 132 | torch.cat([z, a_embedding.view(z.size())], 1)) 133 | hx_d, cx_d = self.lstm_dynamics(za_embedding, (hx_d, cx_d)) 134 | z_prime_hat = F.tanh(self.z_embed(hx_d)) 135 | 136 | # decode the action 137 | if z_prime is not None: 138 | z_prime = z_prime.view(-1, self.dim) 139 | else: 140 | z_prime = z_prime_hat 141 | a_hat = F.elu(self.inv(torch.cat([z, z_prime], 1))) 142 | a_hat = self.inv2(a_hat) 143 | return z_prime_hat, a_hat, (hx_d, cx_d) 144 | 145 | 146 | class R_Module(torch.nn.Module): 147 | def __init__(self, action_space, dim, discrete=False, baseline=False, 148 | state_space=None): 149 | super(R_Module, self).__init__() 150 | self.discrete = discrete 151 | self.baseline = baseline 152 | self.dim = dim 153 | 154 | if baseline: 155 | self.linear1 = nn.Linear(state_space, dim) 156 | self.linear2 = nn.Linear(dim, dim) 157 | self.lstm_policy = nn.LSTMCell(dim, dim) 158 | 159 | self.actor_linear = nn.Linear(dim, action_space) 160 | self.critic_linear = nn.Linear(dim, 1) 161 | self.rhat_linear = nn.Linear(dim, 1) 162 | if not discrete: 163 | self.actor_sigma_sq = nn.Linear(dim, action_space) 164 | 165 | self.apply(weights_init) 166 | 167 | self.actor_linear.weight.data = normalized_columns_initializer( 168 | self.actor_linear.weight.data, 0.01) 169 | self.actor_linear.bias.data.fill_(0) 170 | self.critic_linear.weight.data = normalized_columns_initializer( 171 | self.critic_linear.weight.data, 1.0) 172 | self.critic_linear.bias.data.fill_(0) 173 | 174 | # only forget should be 1 175 | self.lstm_policy.bias_ih.data.fill_(0) 176 | self.lstm_policy.bias_hh.data.fill_(0) 177 | 178 | if not discrete: 179 | self.actor_sigma_sq.weight.data = normalized_columns_initializer( 180 | self.actor_sigma_sq.weight.data, 0.01) 181 | self.actor_sigma_sq.bias.data.fill_(0) 182 | 183 | self.train() 184 | 185 | def forward(self, inputs): 186 | inputs, (hx_p, cx_p) = inputs 187 | if self.baseline: 188 | inputs = F.elu(self.linear1(inputs)) 189 | inputs = F.elu(self.linear2(inputs)) 190 | hx_p, cx_p = self.lstm_policy(inputs, (hx_p, cx_p)) 191 | x = hx_p 192 | if self.discrete: 193 | action = self.actor_linear(x) 194 | else: 195 | action = (self.actor_linear(x), self.actor_sigma_sq(x)) 196 | return self.critic_linear(x), action, (hx_p, cx_p) 197 | -------------------------------------------------------------------------------- /my_optim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import math 8 | 9 | import torch 10 | import torch.optim as optim 11 | 12 | 13 | class SharedAdam(optim.Adam): 14 | """Implements Adam algorithm with shared states. 15 | """ 16 | 17 | def __init__(self, 18 | params, 19 | lr=1e-3, 20 | betas=(0.9, 0.999), 21 | eps=1e-8, 22 | weight_decay=0): 23 | super(SharedAdam, self).__init__(params, lr, betas, eps, weight_decay) 24 | 25 | for group in self.param_groups: 26 | for p in group['params']: 27 | state = self.state[p] 28 | state['step'] = torch.zeros(1) 29 | state['exp_avg'] = p.data.new().resize_as_(p.data).zero_() 30 | state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_() 31 | 32 | def share_memory(self): 33 | for group in self.param_groups: 34 | for p in group['params']: 35 | state = self.state[p] 36 | state['step'].share_memory_() 37 | state['exp_avg'].share_memory_() 38 | state['exp_avg_sq'].share_memory_() 39 | 40 | def step(self, closure=None): 41 | """Performs a single optimization step. 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 45 | """ 46 | loss = None 47 | if closure is not None: 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | for p in group['params']: 52 | if p.grad is None: 53 | continue 54 | grad = p.grad.data 55 | state = self.state[p] 56 | 57 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 58 | beta1, beta2 = group['betas'] 59 | 60 | state['step'] += 1 61 | 62 | if group['weight_decay'] != 0: 63 | grad = grad.add(group['weight_decay'], p.data) 64 | 65 | # Decay the first and second moment running average coefficient 66 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 67 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 68 | 69 | denom = exp_avg_sq.sqrt().add_(group['eps']) 70 | 71 | bias_correction1 = 1 - beta1**state['step'][0] 72 | bias_correction2 = 1 - beta2**state['step'][0] 73 | step_size = group['lr'] * math.sqrt( 74 | bias_correction2) / bias_correction1 75 | 76 | p.data.addcdiv_(exp_avg, denom, value=-float(step_size.data.numpy()[0])) 77 | 78 | return loss 79 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import time 8 | from collections import deque 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | 14 | from envs import create_env 15 | from model import * 16 | 17 | 18 | def test(rank, args, shared_model, counter): 19 | torch.manual_seed(args.seed + rank) 20 | 21 | env = create_env(args.env_name) 22 | env.seed(args.seed + rank) 23 | 24 | model = ActorCritic(env.observation_space.shape[0], env.action_space) 25 | 26 | model.eval() 27 | 28 | state = env.reset() 29 | state = torch.from_numpy(state).float() 30 | reward_sum = 0 31 | done = True 32 | 33 | start_time = time.time() 34 | 35 | # a quick hack to prevent the agent from stucking 36 | actions = deque(maxlen=100) 37 | episode_length = 0 38 | while True: 39 | episode_length += 1 40 | # Sync with the shared model 41 | if done: 42 | model.load_state_dict(shared_model.state_dict()) 43 | cx_d = Variable(torch.zeros(1, 256), volatile=True) 44 | hx_d = Variable(torch.zeros(1, 256), volatile=True) 45 | cx_p = Variable(torch.zeros(1, 256), volatile=True) 46 | hx_p = Variable(torch.zeros(1, 256), volatile=True) 47 | else: 48 | cx_d = Variable(cx_d.data, volatile=True) 49 | hx_d = Variable(hx_d.data, volatile=True) 50 | cx_p = Variable(cx_p.data, volatile=True) 51 | hx_p = Variable(hx_p.data, volatile=True) 52 | 53 | value, logit, (hx_d, cx_d), (hx_p, cx_p) = model((Variable( 54 | state.unsqueeze(0), volatile=True), (hx_d, cx_d), (hx_p, cx_p))) 55 | if args.discrete: 56 | prob = F.softmax(logit) 57 | action = prob.max(1, keepdim=True)[1].data.numpy() 58 | else: 59 | mu, sigma_sq = logit 60 | sigma_sq = F.softplus(sigma_sq) 61 | eps = torch.randn(mu.size()) 62 | action = (mu + sigma_sq.sqrt()*Variable(eps)).data 63 | state, reward, done, _ = env.step(action[0, 0]) 64 | done = done or episode_length >= args.max_episode_length 65 | reward_sum += reward 66 | 67 | # a quick hack to prevent the agent from stucking 68 | actions.append(action[0, 0]) 69 | if actions.count(actions[0]) == actions.maxlen: 70 | done = True 71 | 72 | if done: 73 | print("Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format( 74 | time.strftime("%Hh %Mm %Ss", 75 | time.gmtime(time.time() - start_time)), 76 | counter.value, counter.value / (time.time() - start_time), 77 | reward_sum, episode_length)) 78 | reward_sum = 0 79 | episode_length = 0 80 | actions.clear() 81 | state = env.reset() 82 | time.sleep(60) 83 | 84 | state = torch.from_numpy(state).float() 85 | -------------------------------------------------------------------------------- /train_dynamics_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import math 8 | import numpy as np 9 | import os 10 | import time 11 | from itertools import chain 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | import torch.utils.data as data 18 | from torch.autograd import Variable 19 | 20 | from model import Encoder, Decoder, D_Module 21 | from common import * 22 | 23 | 24 | def get_dynamics_losses(s, s_hat, s_prime, s_prime_hat, z_prime, z_prime_hat, 25 | a_hat, curr_actions, discrete=False): 26 | # reconstruction loss 27 | recon_loss = F.mse_loss(s_hat, s) 28 | 29 | # next state prediction loss 30 | model_loss = F.mse_loss(s_prime_hat, s_prime) 31 | 32 | # net decoder loss 33 | dec_loss = (F.mse_loss(s_hat, s) + F.mse_loss(s_prime_hat, s_prime)) 34 | 35 | # action reconstruction loss 36 | if discrete: 37 | a_hat = F.log_softmax(a_hat) 38 | inv_loss = F.mse_loss(a_hat, curr_actions) 39 | 40 | # representation space constraint 41 | forward_loss = F.mse_loss(z_prime_hat, z_prime.detach()) 42 | return recon_loss, model_loss, dec_loss, inv_loss, forward_loss 43 | 44 | 45 | def get_maze_dynamics_losses(s, s_hat_logits, 46 | s_prime, s_prime_hat_logits, 47 | z_prime, z_prime_hat, 48 | a_hat_logits, curr_actions, discrete=True, 49 | dec_mask=None): 50 | """ 51 | dec_mask: if to reweigh the weights on the agent and goal locations, 52 | """ 53 | # reconstruction loss 54 | if dec_mask is not None: 55 | recon_loss = F.cross_entropy(s_hat_logits.view(-1, 2), s.view(-1).long(), reduce=False) 56 | recon_loss = (recon_loss * dec_mask).mean() 57 | else: 58 | recon_loss = F.cross_entropy(s_hat_logits.view(-1, 2), s.view(-1).long()) 59 | 60 | # next state prediction loss 61 | if dec_mask is not None: 62 | model_loss = F.cross_entropy(s_prime_hat_logits.view(-1, 2), s_prime.view(-1).long(), reduce=False) 63 | model_loss = (model_loss * dec_mask).mean() 64 | else: 65 | model_loss = F.cross_entropy(s_prime_hat_logits.view(-1, 2), s_prime.view(-1).long()) 66 | 67 | # net decoder loss 68 | dec_loss = recon_loss + model_loss 69 | 70 | # action reconstruction loss 71 | inv_loss = F.cross_entropy(a_hat_logits, curr_actions.view(-1).long()) 72 | 73 | # representation space constraint 74 | forward_loss = F.mse_loss(z_prime_hat, z_prime.detach()) 75 | 76 | return recon_loss, model_loss, dec_loss, inv_loss, forward_loss 77 | 78 | class DynamicsDataset(data.Dataset): 79 | def __init__(self, root, size, batch, rollout): 80 | self.size = size 81 | self.root = root 82 | self.actions = [] 83 | self.states = [] 84 | start = 0 85 | 86 | while len(self.actions) < size: 87 | end = start + batch 88 | states, actions = torch.load( 89 | os.path.join(self.root, 'states_actions_%s_%s.pt' % (start, end))) 90 | self.states += states 91 | self.actions += actions 92 | start = end 93 | rollout = len(actions[0]) 94 | self.actions = torch.Tensor(self.actions[:size]).view( 95 | self.size, rollout, -1) 96 | self.states = torch.Tensor(self.states[:size]).view( 97 | self.size, rollout + 1, -1) 98 | 99 | def __getitem__(self, index): 100 | assert index < self.size 101 | return self.states[index], self.actions[index] 102 | 103 | def __len__(self): 104 | return len(self.actions) 105 | 106 | 107 | class MazeDynamicsDataset(data.Dataset): 108 | def __init__(self, root, size, batch, rollout): 109 | """ 110 | batch: is the size of the blocks of the data 111 | size: total size of the dataset, num of trajectories 112 | rollout: length of the trajectory 113 | """ 114 | self.size = size 115 | self.root = root 116 | self.actions = [] 117 | self.states = [] 118 | start = 0 119 | 120 | while len(self.actions) < size: 121 | end = start + batch 122 | states, actions = torch.load( 123 | os.path.join(self.root, 'states_actions_%s_%s.pt' % (start, end))) 124 | self.states += states 125 | self.actions += actions 126 | start = end 127 | 128 | # convert the state and actions to the float 129 | self.states = np.asarray(self.states, dtype=np.float32) 130 | self.actions = np.asarray(self.actions, dtype=np.float32) 131 | 132 | # convert to tensors 133 | self.actions = torch.Tensor(self.actions).view( 134 | self.size, rollout, -1) 135 | self.states = torch.Tensor(self.states).view( 136 | self.size, rollout + 1, -1) 137 | 138 | def __getitem__(self, index): 139 | assert index < self.size 140 | return self.states[index], self.actions[index] 141 | 142 | def __len__(self): 143 | return len(self.actions) 144 | 145 | 146 | def forward(i, states, target_actions, enc, dec, d_module, args, 147 | d_init=None, dec_mask=None): 148 | if args.framework == "mazebase": 149 | # cx_d = Variable(torch.zeros(states.size(0), args.lstm_dim)) 150 | # hx_d = Variable(torch.zeros(states.size(0), args.lstm_dim)) 151 | hx_d, cx_d = d_init(Variable(states[:, 0, :]).contiguous().cuda()) 152 | else: 153 | cx_d = Variable(torch.zeros(states.size(0), args.dim)) 154 | hx_d = Variable(torch.zeros(states.size(0), args.dim)) 155 | 156 | if args.gpu: 157 | cx_d = cx_d.cuda() 158 | hx_d = hx_d.cuda() 159 | 160 | 161 | dec_loss = 0 162 | inv_loss = 0 163 | model_loss = 0 164 | recon_loss = 0 165 | forward_loss = 0 166 | 167 | 168 | current_epoch_actions = 0 169 | current_epoch_predicted_a_hat = 0 170 | 171 | s = None 172 | for r in range(args.rollout): 173 | curr_state = states[:, r, :] 174 | next_state = states[:, r + 1, :] 175 | if args.framework == "mazebase": 176 | curr_actions = Variable(target_actions[:, r].contiguous().view( 177 | -1, 1)) 178 | else: 179 | curr_actions = Variable(target_actions[:, r].contiguous().view( 180 | -1, args.action_space.shape[0])) 181 | if s is None: 182 | s = Variable(curr_state.contiguous()) 183 | if args.gpu: 184 | s = s.cuda() 185 | z = enc(s) 186 | s_prime = Variable(next_state.contiguous()) 187 | if args.gpu: 188 | s_prime = s_prime.cuda() 189 | z_prime = enc(s_prime) 190 | 191 | if args.gpu: 192 | curr_actions = curr_actions.cuda() 193 | 194 | if args.framework == "mazebase": 195 | s_hat, s_hat_binary = dec(z) 196 | z_prime_hat, a_hat, (hx_d, cx_d) = d_module( 197 | z, curr_actions.long(), z_prime.detach(), (hx_d, cx_d)) 198 | s_prime_hat, s_prime_hat_binary = dec(z_prime_hat) 199 | r_loss, m_loss, d_loss, i_loss, f_loss = get_maze_dynamics_losses( 200 | s, s_hat, s_prime, s_prime_hat, z_prime, z_prime_hat, a_hat, 201 | curr_actions, discrete=args.discrete, dec_mask= dec_mask) 202 | 203 | # caculate the accuracy here 204 | _, predicted_a = torch.max(F.sigmoid(a_hat),1) 205 | current_epoch_predicted_a_hat += (predicted_a == curr_actions.view(-1).long()).sum().data[0] 206 | current_epoch_actions += curr_actions.size(0) 207 | 208 | else: 209 | s_hat = dec(z) 210 | z_prime_hat, a_hat, (hx_d, cx_d) = d_module( 211 | (z, z_prime, curr_actions, (hx_d, cx_d))) 212 | s_prime_hat = dec(z_prime_hat) 213 | r_loss, m_loss, d_loss, i_loss, f_loss = get_dynamics_losses( 214 | s, s_hat, s_prime, s_prime_hat, z_prime, z_prime_hat, 215 | a_hat, curr_actions, discrete=args.discrete) 216 | 217 | inv_loss += i_loss 218 | dec_loss += d_loss 219 | forward_loss += f_loss 220 | recon_loss += r_loss 221 | model_loss += m_loss 222 | 223 | s = s_prime 224 | z = z_prime 225 | 226 | return forward_loss, inv_loss, dec_loss, recon_loss, model_loss, \ 227 | current_epoch_predicted_a_hat, current_epoch_actions 228 | 229 | 230 | def forward_planning(i, states, target_actions, enc, dec, d_module, args, 231 | d_init=None, dec_mask=None): 232 | cx_d = Variable(torch.zeros(states.size(0), args.dim)) 233 | hx_d = Variable(torch.zeros(states.size(0), args.dim)) 234 | 235 | if args.gpu: 236 | cx_d = cx_d.cuda() 237 | hx_d = hx_d.cuda() 238 | 239 | 240 | dec_loss = 0 241 | inv_loss = 0 242 | model_loss = 0 243 | recon_loss = 0 244 | forward_loss = 0 245 | 246 | 247 | current_epoch_actions = 0 248 | current_epoch_predicted_a_hat = 0 249 | 250 | s = None 251 | for r in range(args.rollout): 252 | curr_state = states[:, r, :] 253 | next_state = states[:, r + 1, :] 254 | curr_actions = Variable(target_actions[:, r].contiguous().view( 255 | -1, args.action_space.shape[0])) 256 | if s is None: 257 | s = Variable(curr_state.contiguous()) 258 | if args.gpu: 259 | s = s.cuda() 260 | z = enc(s) 261 | s_prime = Variable(next_state.contiguous()) 262 | if args.gpu: 263 | s_prime = s_prime.cuda() 264 | z_prime = enc(s_prime) 265 | 266 | if args.gpu: 267 | curr_actions = curr_actions.cuda() 268 | 269 | s_hat = dec(z) 270 | z_prime_hat, a_hat, (hx_d, cx_d) = d_module( 271 | (z, z_prime, curr_actions, (hx_d, cx_d))) 272 | s_prime_hat = dec(z_prime_hat) 273 | r_loss, m_loss, d_loss, i_loss, f_loss = get_dynamics_losses( 274 | s, s_hat, s_prime, s_prime_hat, z_prime, z_prime_hat, 275 | a_hat, curr_actions, discrete=args.discrete) 276 | 277 | inv_loss += i_loss 278 | dec_loss += d_loss 279 | forward_loss += f_loss 280 | recon_loss += r_loss 281 | model_loss += m_loss 282 | 283 | s = s_prime 284 | z = z_prime_hat 285 | 286 | return forward_loss, inv_loss, dec_loss, recon_loss, model_loss, \ 287 | current_epoch_predicted_a_hat, current_epoch_actions 288 | 289 | 290 | def multiple_forward(i, states, target_actions, enc, dec, d_module, args, 291 | d_init=None, dec_mask = None): 292 | cx_d = Variable(torch.zeros(states.size(0), args.dim)) 293 | hx_d = Variable(torch.zeros(states.size(0), args.dim)) 294 | 295 | if args.gpu: 296 | cx_d = cx_d.cuda() 297 | hx_d = hx_d.cuda() 298 | 299 | 300 | dec_loss = 0 301 | inv_loss = 0 302 | model_loss = 0 303 | recon_loss = 0 304 | forward_loss = 0 305 | 306 | 307 | current_epoch_actions = 0 308 | current_epoch_predicted_a_hat = 0 309 | 310 | s = None 311 | for r in range(args.rollout): 312 | curr_state = states[:, r, :] 313 | next_state = states[:, r + 1, :] 314 | if args.framework == "mazebase": 315 | curr_actions = Variable(target_actions[:, r].contiguous().view( 316 | -1, 1)) 317 | else: 318 | curr_actions = Variable(target_actions[:, r].contiguous().view( 319 | -1, args.action_space.shape[0])) 320 | if s is None: 321 | s = Variable(curr_state.contiguous()) 322 | if args.gpu: 323 | s = s.cuda() 324 | z = enc(s) 325 | s_prime = Variable(next_state.contiguous()) 326 | if args.gpu: 327 | s_prime = s_prime.cuda() 328 | z_prime = enc(s_prime) 329 | 330 | if args.gpu: 331 | curr_actions = curr_actions.cuda() 332 | 333 | if args.framework == "mazebase": 334 | s_hat, s_hat_binary = dec(z) 335 | z_prime_hat, a_hat, (hx_d, cx_d) = d_module( 336 | z, curr_actions.long(), z_prime.detach(), (hx_d, cx_d)) 337 | s_prime_hat, s_prime_hat_binary = dec(z_prime_hat) 338 | r_loss, m_loss, d_loss, i_loss, f_loss = get_maze_dynamics_losses( 339 | s, s_hat, s_prime, s_prime_hat, z_prime, z_prime_hat, a_hat, 340 | curr_actions, discrete=args.discrete, dec_mask= dec_mask) 341 | 342 | # caculate the accuracy here 343 | _, predicted_a = torch.max(F.sigmoid(a_hat),1) 344 | current_epoch_predicted_a_hat += (predicted_a == curr_actions.view(-1).long()).sum().data[0] 345 | current_epoch_actions += curr_actions.size(0) 346 | 347 | else: 348 | s_hat = dec(z) 349 | z_prime_hat, a_hat, (hx_d, cx_d) = d_module( 350 | (z, z_prime, curr_actions, (hx_d, cx_d))) 351 | s_prime_hat = dec(z_prime_hat) 352 | r_loss, m_loss, d_loss, i_loss, f_loss = get_dynamics_losses( 353 | s, s_hat, s_prime, s_prime_hat, z_prime, z_prime_hat, a_hat, 354 | curr_actions, discrete=args.discrete) 355 | 356 | inv_loss += i_loss 357 | dec_loss += d_loss 358 | forward_loss += f_loss 359 | recon_loss += r_loss 360 | model_loss += m_loss 361 | 362 | s = s_prime 363 | z = z_prime_hat 364 | 365 | return forward_loss, inv_loss, dec_loss, recon_loss, model_loss, \ 366 | current_epoch_predicted_a_hat, current_epoch_actions 367 | 368 | def train_dynamics(env, args, writer=None): 369 | """ 370 | Trains the Dynamics module. Supervised. 371 | 372 | Arguments: 373 | env: the initialized environment (rllab/gym) 374 | args: input arguments 375 | writer: initialized summary writer for tensorboard 376 | """ 377 | args.action_space = env.action_space 378 | 379 | # Initialize models 380 | enc = Encoder(env.observation_space.shape[0], args.dim, 381 | use_conv=args.use_conv) 382 | dec = Decoder(env.observation_space.shape[0], args.dim, 383 | use_conv=args.use_conv) 384 | d_module = D_Module(env.action_space.shape[0], args.dim, args.discrete) 385 | 386 | if args.from_checkpoint is not None: 387 | results_dict = torch.load(args.from_checkpoint) 388 | enc.load_state_dict(results_dict['enc']) 389 | dec.load_state_dict(results_dict['dec']) 390 | d_module.load_state_dict(results_dict['d_module']) 391 | 392 | all_params = chain(enc.parameters(), dec.parameters(), d_module.parameters()) 393 | 394 | if args.transfer: 395 | for p in enc.parameters(): 396 | p.requires_grad = False 397 | 398 | for p in dec.parameters(): 399 | p.requires_grad = False 400 | all_params = d_module.parameters() 401 | 402 | optimizer = torch.optim.Adam(all_params, lr=args.lr, 403 | weight_decay=args.weight_decay) 404 | 405 | if args.gpu: 406 | enc = enc.cuda() 407 | dec = dec.cuda() 408 | d_module = d_module.cuda() 409 | 410 | # Initialize datasets 411 | val_loader = None 412 | train_dataset = DynamicsDataset( 413 | args.train_set, args.train_size, batch=args.train_batch, 414 | rollout=args.rollout) 415 | val_dataset = DynamicsDataset(args.test_set, 5000, batch=args.test_batch, 416 | rollout=args.rollout) 417 | val_loader = torch.utils.data.DataLoader( 418 | dataset=val_dataset, batch_size=args.batch_size, shuffle=False, 419 | num_workers=args.num_workers) 420 | 421 | train_loader = torch.utils.data.DataLoader( 422 | dataset=train_dataset, batch_size=args.batch_size, shuffle=True, 423 | num_workers=args.num_workers) 424 | 425 | results_dict = { 426 | 'dec_losses': [], 427 | 'forward_losses': [], 428 | 'inverse_losses': [], 429 | 'total_losses': [], 430 | 'enc': None, 431 | 'dec': None, 432 | 'd_module': None, 433 | 'd_init':None, 434 | 'args': args 435 | } 436 | 437 | total_action_taken = 0 438 | correct_predicted_a_hat = 0 439 | 440 | # create the mask here for re-weighting 441 | dec_mask = None 442 | if args.dec_mask is not None: 443 | dec_mask = torch.ones(9) 444 | game_vocab = dict([(b, a) for a, b in enumerate(sorted(env.game.all_possible_features()))]) 445 | dec_mask[game_vocab['Agent']] = args.dec_mask 446 | dec_mask[game_vocab['Goal']] = args.dec_mask 447 | dec_mask = dec_mask.expand(args.batch_size, args.maze_length,args.maze_length,9).contiguous().view(-1) 448 | dec_mask = Variable(dec_mask, requires_grad = False) 449 | if args.gpu: 450 | dec_mask = dec_mask.cuda() 451 | 452 | for epoch in range(1, args.num_epochs + 1): 453 | enc.train() 454 | dec.train() 455 | d_module.train() 456 | 457 | if args.framework == "mazebase": 458 | d_init.train() 459 | 460 | # for measuring the accuracy 461 | train_acc = 0 462 | current_epoch_actions = 0 463 | current_epoch_predicted_a_hat = 0 464 | 465 | start = time.time() 466 | for i, (states, target_actions) in enumerate(train_loader): 467 | 468 | optimizer.zero_grad() 469 | 470 | if args.framework != "mazebase": 471 | forward_loss, inv_loss, dec_loss, recon_loss, model_loss, _, _ = forward_planning( 472 | i, states, target_actions, enc, dec, d_module, args) 473 | else: 474 | forward_loss, inv_loss, dec_loss, recon_loss, model_loss, current_epoch_predicted_a_hat, current_epoch_actions = multiple_forward( 475 | i, states, target_actions, enc, dec, d_module, args, d_init, dec_mask ) 476 | 477 | loss = forward_loss + args.inv_loss_coef * inv_loss + \ 478 | args.dec_loss_coef * dec_loss 479 | 480 | 481 | if i % args.log_interval == 0: 482 | log( 483 | 'Epoch [{}/{}]\tIter [{}/{}]\t'.format( 484 | epoch, args.num_epochs, i+1, len( 485 | train_dataset)//args.batch_size) + \ 486 | 'Time: {:.2f}\t'.format(time.time() - start) + \ 487 | 'Decoder Loss: {:.2f}\t'.format(dec_loss.data[0]) + \ 488 | 'Forward Loss: {:.2f}\t'.format(forward_loss.data[0] ) + \ 489 | 'Inverse Loss: {:.2f}\t'.format(inv_loss.data[0]) + \ 490 | 'Loss: {:.2f}\t'.format(loss.data[0])) 491 | 492 | results_dict['dec_losses'].append(dec_loss.data[0]) 493 | results_dict['forward_losses'].append(forward_loss.data[0]) 494 | results_dict['inverse_losses'].append(inv_loss.data[0]) 495 | results_dict['total_losses'].append(loss.data[0]) 496 | 497 | # write the summaries here 498 | if writer: 499 | writer.add_scalar('dynamics/total_loss', loss.data[0], epoch) 500 | writer.add_scalar('dynamics/decoder', dec_loss.data[0], epoch) 501 | writer.add_scalar( 502 | 'dynamics/reconstruction_loss', recon_loss.data[0], epoch) 503 | writer.add_scalar( 504 | 'dynamics/next_state_prediction_loss', 505 | model_loss.data[0], epoch) 506 | writer.add_scalar('dynamics/inv_loss', inv_loss.data[0], epoch) 507 | writer.add_scalar( 508 | 'dynamics/forward_loss', forward_loss.data[0], epoch) 509 | 510 | writer.add_scalars( 511 | 'dynamics/all_losses', 512 | {"total_loss":loss.data[0], 513 | "reconstruction_loss":recon_loss.data[0], 514 | "next_state_prediction_loss":model_loss.data[0], 515 | "decoder_loss":dec_loss.data[0], 516 | "inv_loss":inv_loss.data[0], 517 | "forward_loss":forward_loss.data[0], 518 | } , epoch) 519 | 520 | loss.backward() 521 | 522 | correct_predicted_a_hat += current_epoch_predicted_a_hat 523 | total_action_taken += current_epoch_actions 524 | 525 | # does it not work at all without grad clipping ? 526 | torch.nn.utils.clip_grad_norm(all_params, args.max_grad_norm) 527 | optimizer.step() 528 | 529 | # maybe add the generated image to add the logs 530 | # writer.add_image() 531 | 532 | # Run validation 533 | if val_loader is not None: 534 | enc.eval() 535 | dec.eval() 536 | d_module.eval() 537 | forward_loss, inv_loss, dec_loss = 0, 0, 0 538 | for i, (states, target_actions) in enumerate(val_loader): 539 | f_loss, i_loss, d_loss, _, _, _, _ = forward_planning( 540 | i, states, target_actions, enc, dec, d_module, args) 541 | forward_loss += f_loss 542 | inv_loss += i_loss 543 | dec_loss += d_loss 544 | loss = forward_loss + args.inv_loss_coef * inv_loss + \ 545 | args.dec_loss_coef * dec_loss 546 | if writer: 547 | writer.add_scalar('val/forward_loss', forward_loss.data[0] / i, epoch) 548 | writer.add_scalar('val/inverse_loss', inv_loss.data[0] / i, epoch) 549 | writer.add_scalar('val/decoder_loss', dec_loss.data[0] / i, epoch) 550 | log( 551 | '[Validation]\t' + \ 552 | 'Decoder Loss: {:.2f}\t'.format(dec_loss.data[0] / i) + \ 553 | 'Forward Loss: {:.2f}\t'.format(forward_loss.data[0] / i) + \ 554 | 'Inverse Loss: {:.2f}\t'.format(inv_loss.data[0] / i) + \ 555 | 'Loss: {:.2f}\t'.format(loss.data[0] / i)) 556 | if epoch % args.checkpoint == 0: 557 | results_dict['enc'] = enc.state_dict() 558 | results_dict['dec'] = dec.state_dict() 559 | results_dict['d_module'] = d_module.state_dict() 560 | if args.framework == "mazebase": 561 | results_dict['d_init'] = d_init.state_dict() 562 | torch.save(results_dict, 563 | os.path.join(args.out, 'dynamics_module_epoch%s.pt' % epoch)) 564 | log('Saved model %s' % epoch) 565 | 566 | results_dict['enc'] = enc.state_dict() 567 | results_dict['dec'] = dec.state_dict() 568 | results_dict['d_module'] = d_module.state_dict() 569 | torch.save(results_dict, 570 | os.path.join(args.out, 'dynamics_module_epoch%s.pt' % epoch)) 571 | print(os.path.join(args.out, 'dynamics_module_epoch%s.pt' % epoch)) 572 | -------------------------------------------------------------------------------- /train_online.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import math 8 | import numpy as np 9 | import os 10 | import time 11 | from itertools import chain 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | from torch.autograd import Variable 17 | 18 | from envs import * 19 | from model import Encoder, Decoder, D_Module, R_Module 20 | from train_dynamics_module import D_Module, get_dynamics_losses 21 | from common import * 22 | from tensorboardX import SummaryWriter 23 | 24 | def ensure_shared_grads(model, shared_model): 25 | for param, shared_param in zip(model.parameters(), 26 | shared_model.parameters()): 27 | if shared_param.grad is not None: 28 | return 29 | shared_param._grad = param.grad 30 | 31 | 32 | def train_online(rank, args, shared_model, optimizer=None, writer_dir=None): 33 | """ 34 | 35 | Arguments: 36 | - writer: the tensorboard summary writer directory (note: can't get it working directly with the SummaryWriter object) 37 | """ 38 | # create writer here itself 39 | writer = None 40 | if writer_dir is not None: 41 | writer = SummaryWriter(log_dir=writer_dir) 42 | 43 | shared_enc, shared_dec, shared_d_module, shared_r_module = shared_model 44 | running_t, running_reward, running_value_loss, running_policy_loss, \ 45 | running_reward_loss = 0, 0, 0, 0, 0 46 | 47 | torch.manual_seed(args.seed + rank) 48 | env = create_env(args.env_name, framework=args.framework, args=args) 49 | set_seed(args.seed + rank, env, args.framework) 50 | enc = Encoder(env.observation_space.shape[0], args.dim, 51 | use_conv=args.use_conv) 52 | dec = Decoder(env.observation_space.shape[0], args.dim, 53 | use_conv=args.use_conv) 54 | d_module = D_Module(env.action_space.shape[0], args.dim, args.discrete) 55 | r_module = R_Module(env.action_space.shape[0], args.dim, 56 | discrete=args.discrete, baseline=False, 57 | state_space=env.observation_space.shape[0]) 58 | 59 | all_params = chain(enc.parameters(), dec.parameters(), 60 | d_module.parameters(), 61 | r_module.parameters()) 62 | # no shared adam ? 63 | if optimizer is None: 64 | optimizer = optim.Adam(all_params, lr=args.lr) 65 | 66 | enc.train() 67 | dec.train() 68 | d_module.train() 69 | r_module.train() 70 | 71 | results_dict = { 72 | 'enc': None, 73 | 'dec': None, 74 | 'd_module': None, 75 | 'args': args, 76 | 'reward': [], 77 | 'policy_loss': [], 78 | 'value_loss': [], 79 | 'mean_entropy': [], 80 | 'mean_predicted_value': [], 81 | 'dec_losses': [], 82 | 'forward_losses': [], 83 | 'inverse_losses': [], 84 | 'total_losses': [], 85 | } 86 | episode_length = 0 87 | i_episode, total_episode = 0, 0 88 | done = True 89 | start = time.time() 90 | while total_episode < args.num_episodes: 91 | # Sync with the shared model 92 | r_module.load_state_dict(shared_r_module.state_dict()) 93 | d_module.load_state_dict(shared_d_module.state_dict()) 94 | enc.load_state_dict(shared_enc.state_dict()) 95 | dec.load_state_dict(shared_dec.state_dict()) 96 | if done: 97 | cx_p = Variable(torch.zeros(1, args.dim)) 98 | hx_p = Variable(torch.zeros(1, args.dim)) 99 | cx_d = Variable(torch.zeros(1, args.dim)) 100 | hx_d = Variable(torch.zeros(1, args.dim)) 101 | i_episode += 1 102 | episode_length = 0 103 | total_episode = args.num_processes * (i_episode - 1) + rank 104 | start = time.time() 105 | last_episode_length = episode_length 106 | if not args.single_env and args.env_name.endswith('MazeEnv'): # generate new maze 107 | env = create_env( 108 | args.env_name, framework=args.framework, args=args) 109 | s = env.reset() 110 | s = Variable(torch.from_numpy(s).float()) 111 | else: 112 | cx_p = Variable(cx_p.data) 113 | hx_p = Variable(hx_p.data) 114 | cx_d = Variable(cx_d.data) 115 | hx_d = Variable(hx_d.data) 116 | s = Variable(s.data) 117 | z = enc(s).unsqueeze(0) 118 | s_hat = dec(z) 119 | 120 | values = [] 121 | rhats = [] 122 | log_probs = [] 123 | rewards = [] 124 | entropies = [] 125 | dec_loss = 0 126 | inv_loss = 0 127 | model_loss = 0 128 | recon_loss = 0 129 | forward_loss = 0 130 | for step in range(args.num_steps): 131 | episode_length += 1 132 | value, rhat, logit, (hx_p, cx_p) = r_module(( 133 | z.detach(), (hx_p, cx_p))) 134 | action, entropy, log_prob = get_action(logit, discrete=args.discrete) 135 | vlog("Action: %s\t Bounds: %s" % (str(action), str((env.action_space.low, env.action_space.high))), args.v) 136 | entropies.append(entropy) 137 | s_prime, reward, done, _ = env.step(action.data.numpy()) 138 | s_prime = Variable(torch.from_numpy(s_prime).float()) 139 | done = done or episode_length >= args.max_episode_length 140 | 141 | z_prime = enc(s_prime) 142 | z_prime_hat, a_hat, (hx_d, cx_d) = d_module( 143 | (z, z_prime, action, (hx_d, cx_d))) 144 | s_prime_hat = dec(z_prime_hat) 145 | r_loss, m_loss, d_loss, i_loss, f_loss = get_dynamics_losses( 146 | s, s_hat, s_prime, s_prime_hat, z_prime, z_prime_hat, a_hat, 147 | action) 148 | values.append(value) 149 | rhats.append(rhat) 150 | log_probs.append(log_prob) 151 | rewards.append(reward) 152 | dec_loss += d_loss 153 | inv_loss += i_loss 154 | model_loss += m_loss 155 | recon_loss += r_loss 156 | forward_loss += f_loss 157 | z = z_prime_hat 158 | s = s_prime 159 | s_hat = s_prime_hat 160 | if done: 161 | break 162 | R = torch.zeros(1, 1) 163 | if not done: 164 | value, _, _, _ = r_module((z, (hx_p, cx_p))) 165 | R = value.data 166 | 167 | values.append(Variable(R)) 168 | policy_loss = 0 169 | value_loss = 0 170 | rew_loss = 0 171 | pred_reward_loss = 0 172 | R = Variable(R) 173 | gae = torch.zeros(1, 1) 174 | vlog("values: %s" % str([v.data[0,0] for v in values]), args.v) 175 | vlog("rhats: %s" % str(rhats), args.v) 176 | for i in reversed(range(len(rewards))): 177 | R = args.gamma * R + rewards[i] 178 | advantage = R - values[i] 179 | value_loss += 0.5 * advantage.pow(2) 180 | 181 | # reward loss 182 | rew_loss += F.mse_loss(rhats[i], Variable(torch.from_numpy( 183 | np.array([rewards[i]])).float())) 184 | 185 | # Generalized Advantage Estimation 186 | delta_t = rewards[i] + args.gamma * values[i + 1].data \ 187 | - values[i].data 188 | gae = gae * args.gamma * args.tau + delta_t 189 | if args.discrete: 190 | policy_loss = policy_loss - log_probs[i] * Variable(gae) \ 191 | - args.entropy_coef * entropies[i] 192 | else: 193 | policy_loss = policy_loss - (log_probs[i] * Variable(gae).expand_as( 194 | log_probs[i])).sum() - (args.entropy_coef * entropies[i]).sum() 195 | 196 | optimizer.zero_grad() 197 | U = 1. / min(i_episode, 100) 198 | running_reward = running_reward * (1 - U) + sum(rewards) * U 199 | running_t = running_t * (1 - U) + episode_length * U 200 | running_policy_loss = running_policy_loss * (1 - U) + policy_loss.data[0] * U 201 | running_value_loss = running_value_loss * (1 - U) + \ 202 | args.value_loss_coef * value_loss.data[0, 0] * U 203 | running_reward_loss = running_reward_loss * (1 - U) + \ 204 | args.rew_loss_coef * rew_loss.data[0] * U 205 | mean_entropy = np.mean([e.sum().data[0] for e in entropies]) 206 | 207 | mean_predicted_value = np.mean([v.sum().data[0] for v in values]) 208 | loss = policy_loss + args.value_loss_coef * value_loss + \ 209 | args.rew_loss_coef * rew_loss + args.inv_loss_coef * inv_loss + \ 210 | args.dec_loss_coef * dec_loss + forward_loss 211 | if total_episode % args.log_interval == 0 and done: 212 | if not args.discrete: 213 | sample_logits = (list(logit[0].data[0].numpy()), 214 | list(logit[1].data[0].numpy())) 215 | else: 216 | sample_logits = list(logit.data[0].numpy()) 217 | log( 218 | 'Episode {}\t'.format(total_episode) + \ 219 | 'Avg reward: {:.2f}\tAverage length: {:.2f}\t'.format( 220 | running_reward, running_t) + \ 221 | 'Entropy: {:.2f}\tTime: {:.2f}\tRank: {}\t'.format( 222 | mean_entropy, time.time() - start, rank) + \ 223 | 'Policy Loss: {:.2f}\t'.format(running_policy_loss) + \ 224 | 'Reward Loss: {:.2f}\t'.format(running_reward_loss) + \ 225 | 'Weighted Value Loss: {:.2f}\t'.format(running_value_loss) + \ 226 | 'Sample Action: %s\t' % str(list(action.data.numpy())) + \ 227 | 'Logits: %s\t' % str(sample_logits) + \ 228 | 'Decoder Loss: {:.2f}\t'.format(dec_loss.data[0]) + \ 229 | 'Forward Loss: {:.2f}\t'.format(forward_loss.data[0]) + \ 230 | 'Inverse Loss: {:.2f}\t'.format(inv_loss.data[0]) + \ 231 | 'Loss: {:.2f}\t'.format(loss.data[0, 0])) 232 | 233 | # write summaries here 234 | if writer_dir is not None and done: 235 | log('writing to tensorboard') 236 | 237 | # running losses 238 | writer.add_scalar('reward/running_reward', running_reward, i_episode) 239 | writer.add_scalar('reward/running_policy_loss', running_policy_loss, i_episode) 240 | writer.add_scalar('reward/running_value_loss', running_value_loss, i_episode) 241 | 242 | # current episode stats 243 | writer.add_scalar('reward/episode_reward', sum(rewards), i_episode) 244 | writer.add_scalar('reward/episode_policy_loss', policy_loss.data[0], i_episode) 245 | writer.add_scalar('reward/episode_value_loss', value_loss.data[0,0], i_episode) 246 | writer.add_scalar('reward/mean_entropy', mean_entropy, i_episode) 247 | writer.add_scalar('reward/mean_predicted_value', mean_predicted_value, i_episode) 248 | writer.add_scalar('dynamics/total_loss', loss.data[0], i_episode) 249 | writer.add_scalar('dynamics/decoder', dec_loss.data[0], i_episode) 250 | writer.add_scalar('dynamics/reconstruction_loss', recon_loss.data[0], i_episode) 251 | writer.add_scalar('dynamics/next_state_prediction_loss', model_loss.data[0], i_episode) 252 | writer.add_scalar('dynamics/inv_loss', inv_loss.data[0], i_episode) 253 | writer.add_scalar('dynamics/forward_loss', forward_loss.data[0], i_episode) 254 | 255 | results_dict['reward'].append(sum(rewards)) 256 | results_dict['policy_loss'].append(policy_loss.data[0]) 257 | results_dict['value_loss'].append(value_loss.data[0,0]) 258 | results_dict['mean_entropy'].append(mean_entropy) 259 | results_dict['mean_predicted_value'].append(mean_predicted_value) 260 | results_dict['dec_losses'].append(dec_loss.data[0]) 261 | results_dict['forward_losses'].append(forward_loss.data[0]) 262 | results_dict['inverse_losses'].append(inv_loss.data[0]) 263 | results_dict['total_losses'].append(loss.data[0]) 264 | 265 | loss.backward() 266 | torch.nn.utils.clip_grad_norm(all_params, args.max_grad_norm) 267 | ensure_shared_grads(r_module, shared_r_module) 268 | ensure_shared_grads(d_module, shared_d_module) 269 | ensure_shared_grads(enc, shared_enc) 270 | ensure_shared_grads(dec, shared_dec) 271 | optimizer.step() 272 | 273 | if total_episode % args.checkpoint_interval == 0: 274 | args.curr_iter = total_episode 275 | args.dynamics_module = os.path.join( 276 | args.out, 'dynamics_module%s.pt' % total_episode) 277 | torch.save((shared_r_module.state_dict(), args), os.path.join( 278 | args.out, 'reward_module%s.pt' % total_episode)) 279 | results_dict['enc'] = shared_enc.state_dict() 280 | results_dict['dec'] = shared_dec.state_dict() 281 | results_dict['d_module'] = shared_d_module.state_dict() 282 | torch.save(results_dict, 283 | os.path.join(args.out, 'dynamics_module%s.pt' % total_episode)) 284 | log("Saved model %d" % total_episode) 285 | 286 | if writer_dir is not None and i_episode % \ 287 | (args.checkpoint_interval // args.num_processes) == 0: 288 | torch.save(results_dict, 289 | os.path.join(args.out, 'results_dict.pt')) 290 | print(os.path.join(args.out, 'results_dict.pt')) 291 | 292 | if writer_dir is not None: 293 | torch.save(results_dict, 294 | os.path.join(args.out, 'results_dict.pt')) 295 | print(os.path.join(args.out, 'results_dict.pt')) 296 | -------------------------------------------------------------------------------- /train_reward_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import math 8 | import numpy as np 9 | import os 10 | import time 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from torch.autograd import Variable 16 | 17 | from envs import * 18 | from model import R_Module 19 | from common import * 20 | from tensorboardX import SummaryWriter 21 | 22 | def ensure_shared_grads(model, shared_model): 23 | for param, shared_param in zip(model.parameters(), 24 | shared_model.parameters()): 25 | if shared_param.grad is not None: 26 | return 27 | shared_param._grad = param.grad 28 | 29 | 30 | def train_rewards(rank, args, shared_model, enc, optimizer=None, writer_dir=None, 31 | d_module=None): 32 | """ 33 | 34 | Arguments: 35 | - writer: the tensorboard summary writer directory (note: can't get it working directly with the SummaryWriter object) 36 | """ 37 | # create writer here itself 38 | writer = None 39 | if writer_dir is not None: 40 | writer = SummaryWriter(log_dir=writer_dir) 41 | 42 | results_dict = { 43 | 'reward': [], 44 | 'policy_loss': [], 45 | 'value_loss': [], 46 | 'mean_entropy': [], 47 | 'mean_predicted_value': [] 48 | } 49 | 50 | running_t, running_reward, running_value_loss, running_policy_loss, \ 51 | running_reward_loss = 0, 0, 0, 0, 0 52 | 53 | torch.manual_seed(args.seed + rank) 54 | env = create_env(args.env_name, framework=args.framework, args=args) 55 | set_seed(args.seed + rank, env, args.framework) 56 | model = R_Module(env.action_space.shape[0], args.dim, 57 | discrete=args.discrete, baseline=args.baseline, 58 | state_space=env.observation_space.shape[0]) 59 | max_rollout = 0 60 | if args.planning: 61 | max_rollout = args.rollout 62 | 63 | if args.from_checkpoint is not None: 64 | model_state, _ = torch.load(args.from_checkpoint, map_location=lambda storage, loc: storage) 65 | model.load_state_dict(model_state) 66 | 67 | # no shared adam ? 68 | if optimizer is None: 69 | optimizer = optim.Adam(shared_model.parameters(), lr=args.lr, eps=args.eps) 70 | 71 | model.train() 72 | 73 | done = True 74 | episode_length = 0 75 | i_episode, total_episode = 0, 0 76 | start = time.time() 77 | while total_episode < args.num_episodes: 78 | # Sync with the shared model 79 | model.load_state_dict(shared_model.state_dict()) 80 | if done: 81 | cx_p = Variable(torch.zeros(1, args.dim)) 82 | hx_p = Variable(torch.zeros(1, args.dim)) 83 | cx_d = Variable(torch.zeros(1, args.dim)) 84 | hx_d = Variable(torch.zeros(1, args.dim)) 85 | i_episode += 1 86 | episode_length = 0 87 | total_episode = args.num_steps * (i_episode - 1) + rank 88 | start = time.time() 89 | last_episode_length = episode_length 90 | if not args.single_env and args.env_name.endswith('MazeEnv'): # generate new maze 91 | env = create_env( 92 | args.env_name, framework=args.framework, args=args) 93 | state = env.reset() 94 | state = Variable(torch.from_numpy(state).float()) 95 | if not args.baseline: 96 | state = enc(state) 97 | else: 98 | cx_p = Variable(cx_p.data) 99 | hx_p = Variable(hx_p.data) 100 | cx_d = Variable(cx_d.data) 101 | hx_d = Variable(hx_d.data) 102 | 103 | values = [] 104 | value_preds = [] 105 | log_probs = [] 106 | rewards = [] 107 | total_actions = [] 108 | entropies = [] 109 | obses = [] 110 | hx_ps = [] 111 | cx_ps = [] 112 | step = 0 113 | while step < args.num_steps: 114 | episode_length += 1 115 | if args.planning: 116 | _, actions, (hx_p, cx_p), (hx_d, cx_d), values, es, \ 117 | lps = mcts( 118 | env, state, model, d_module, enc, (hx_p, cx_p), (hx_d, cx_d), 119 | args, discrete=args.discrete) 120 | log_probs += lps 121 | entropies += es 122 | actions = actions[:1] 123 | else: 124 | obses.append(state.unsqueeze(0)) 125 | hx_ps.append(hx_p) 126 | cx_ps.append(cx_p) 127 | value, logit, (hx_p, cx_p) = model(( 128 | state.unsqueeze(0), (hx_p, cx_p))) 129 | action, entropy, log_prob = get_action( 130 | logit, discrete=args.discrete) 131 | vlog("Action: %s\t Bounds: %s" % (str(action), str( 132 | (env.action_space.low, env.action_space.high))), args.v) 133 | entropies.append(entropy.mean().data) 134 | actions = [action] 135 | values.append(value) 136 | log_probs.append(log_prob) 137 | for action in actions: 138 | state, reward, done, _ = env.step(action.data.numpy()) 139 | if args.neg_reward: 140 | reward = -reward 141 | state = Variable(torch.from_numpy(state).float()) 142 | if args.clip_reward: 143 | reward = max(min(reward, 1), -1) 144 | if not args.baseline: 145 | state = enc(state) 146 | rewards.append(reward) 147 | total_actions.append(action) 148 | step += 1 149 | if done: 150 | break 151 | if done: 152 | break 153 | R = torch.zeros(1, 1) 154 | if not done: 155 | value, _, _ = model((state.unsqueeze(0), (hx_p, cx_p))) 156 | R = value.data 157 | done = True 158 | 159 | values.append(Variable(R)) 160 | policy_loss = 0 161 | value_loss = 0 162 | advantages = np.zeros_like(rewards, dtype=float) 163 | R = Variable(R) 164 | gae = torch.zeros(1, 1) 165 | Rs = np.zeros_like(rewards, dtype=float) 166 | vlog("values: %s" % str([v.data[0,0] for v in values]), args.v) 167 | for i in reversed(range(len(rewards))): 168 | R = args.gamma * R + rewards[i] 169 | Rs[i] = R 170 | advantage = R - values[i] 171 | advantages[i] = advantage 172 | if args.algo == 'a3c': 173 | value_loss += 0.5 * advantage.pow(2) 174 | # Generalized Advantage Estimation 175 | if args.gae: 176 | delta_t = rewards[i] + args.gamma * values[i + 1].data \ 177 | - values[i].data 178 | gae = gae * args.gamma * args.tau + delta_t 179 | policy_loss -= (log_probs[i] * Variable(gae).expand_as( 180 | log_probs[i])).mean() 181 | else: 182 | policy_loss -= advantage * (log_probs[i].mean()) 183 | if args.algo == 'a3c': 184 | optimizer.zero_grad() 185 | (policy_loss + args.value_loss_coef * value_loss - \ 186 | args.entropy_coef * np.mean(entropies)).backward() 187 | torch.nn.utils.clip_grad_norm(model.parameters(), args.max_grad_norm) 188 | ensure_shared_grads(model, shared_model) 189 | optimizer.step() 190 | 191 | ########Bookkeeping and logging############# 192 | U = 1. / min(i_episode, 100) 193 | running_reward = running_reward * (1 - U) + sum(rewards) * U 194 | running_t = running_t * (1 - U) + episode_length * U 195 | running_policy_loss = running_policy_loss * (1 - U) + policy_loss.squeeze().data[0] * U 196 | running_value_loss = running_value_loss * (1 - U) + \ 197 | args.value_loss_coef * value_loss.squeeze().data[0] * U 198 | mean_entropy = np.mean([e.mean().data[0] for e in entropies]) 199 | 200 | mean_predicted_value = np.mean([v.sum().data[0] for v in values]) 201 | if total_episode % args.log_interval == 0 and done: 202 | if not args.discrete: 203 | sample_logits = (list(logit[0].data[0].numpy()), 204 | list(logit[1].data[0].numpy())) 205 | else: 206 | sample_logits = list(logit.data[0].numpy()) 207 | log( 208 | 'Frames {}\t'.format(total_episode) + \ 209 | 'Avg reward: {:.2f}\tAverage length: {:.2f}\t'.format( 210 | running_reward, running_t) + \ 211 | 'Entropy: {:.2f}\tTime: {:.2f}\tRank: {}\t'.format( 212 | mean_entropy, time.time() - start, rank) + \ 213 | 'Policy Loss: {:.2f}\t'.format(running_policy_loss) + \ 214 | # 'Reward Loss: {:.2f}\t'.format(running_reward_loss) + \ 215 | 'Weighted Value Loss: {:.2f}\t'.format(running_value_loss)) 216 | vlog('Sample Action: %s\t' % str(list(action.data.numpy())) + \ 217 | 'Logits: %s\t' % str(sample_logits), args.v) 218 | 219 | # write summaries here 220 | if writer_dir is not None and done: 221 | log('writing to tensorboard') 222 | # running losses 223 | writer.add_scalar('reward/running_reward', running_reward, i_episode) 224 | writer.add_scalar('reward/running_policy_loss', running_policy_loss, i_episode) 225 | writer.add_scalar('reward/running_value_loss', running_value_loss, i_episode) 226 | 227 | # current episode stats 228 | writer.add_scalar('reward/episode_reward', sum(rewards), i_episode) 229 | writer.add_scalar('reward/episode_policy_loss', policy_loss.squeeze().data[0], i_episode) 230 | writer.add_scalar('reward/episode_value_loss', value_loss.squeeze().data[0], i_episode) 231 | writer.add_scalar('reward/mean_entropy', mean_entropy, i_episode) 232 | writer.add_scalar('reward/mean_predicted_value', mean_predicted_value, i_episode) 233 | 234 | results_dict['reward'].append(sum(rewards)) 235 | results_dict['policy_loss'].append(policy_loss.squeeze().data[0]) 236 | results_dict['value_loss'].append(value_loss.squeeze().data[0]) 237 | results_dict['mean_entropy'].append(mean_entropy) 238 | results_dict['mean_predicted_value'].append(mean_predicted_value) 239 | 240 | if total_episode % args.checkpoint_interval == 0: 241 | args.curr_iter = total_episode 242 | args.optimizer = optimizer 243 | torch.save((shared_model.state_dict(), args), os.path.join( 244 | args.out, args.model_name + '%s.pt' % total_episode)) 245 | log("Saved model %d rank %s" % (total_episode, rank)) 246 | log(os.path.join( 247 | args.out, args.model_name + '%s.pt' % total_episode)) 248 | 249 | if writer_dir is not None and i_episode % \ 250 | (args.checkpoint_interval // args.num_processes) == 0: 251 | torch.save(results_dict, 252 | os.path.join(args.out, 'results_dict.pt')) 253 | log(os.path.join(args.out, 'results_dict.pt')) 254 | 255 | if writer_dir is not None: 256 | torch.save(results_dict, 257 | os.path.join(args.out, 'results_dict.pt')) 258 | log(os.path.join(args.out, 'results_dict.pt')) 259 | --------------------------------------------------------------------------------