├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── babyai
│   ├── __init__.py
│   ├── agents
│   │   └── __init__.py
│   ├── levels
│   │   ├── __init__.py
│   │   ├── instr_gen.py
│   │   ├── instrs.py
│   │   ├── levels.py
│   │   ├── roomgrid.py
│   │   └── verifier.py
│   ├── model.py
│   └── utils
│       ├── __init__.py
│       ├── agent.py
│       ├── demos.py
│       ├── format.py
│       ├── log.py
│       └── model.py
├── env.yml
├── scripts
│   ├── __init__.py
│   ├── enjoy.py
│   ├── enjoy_zf.py
│   ├── evaluate.py
│   ├── evaluate_all_demos.py
│   ├── evaluate_all_models.py
│   ├── gen_samples.py
│   ├── gen_samples_bidir.py
│   ├── make_agent_demos.py
│   ├── make_human_demos.py
│   ├── rl_zforcing.py
│   ├── rl_zforcing_dec.py
│   ├── train_curclm.py
│   ├── train_rl.py
│   └── zforcing_main_state_dec.py
└── setup.py

--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contributing to modeling_long_term_future
We want to make contributing to this project as easy and transparent as
possible.

## Our Development Process
... (in particular how this is synced with internal changes to the project)

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here:

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to modeling_long_term_future, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
     wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More considerations
     for the public:
     wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public
License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.

Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.

  d. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  g. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  i. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  j. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  k. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  l. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.

Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.
       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.

Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

       4. If You Share Adapted Material You produce, the Adapter's
          License You apply must not prevent recipients of the Adapted
          Material from complying with this Public License.

Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This repo contains code for our paper [Learning Dynamics Model in Reinforcement Learning by Incorporating the Long Term Future](https://arxiv.org/abs/1903.01599).

The code base contains multiple branches:

- The main branch contains experiments for the BabyAI tasks.
- The mujoco branch contains experiments for the Mujoco tasks.
- The carracing branch contains experiments for the CarRacing task.

It is based on the code base for the BabyAI project at Mila: https://github.com/mila-iqia/babyai

Follow an installation procedure similar to the one described at https://github.com/mila-iqia/babyai.

## Installation

Requirements:
- Python 3.5+
- OpenAI Gym
- NumPy
- PyQT5
- PyTorch 0.4.1+

Start by manually installing PyTorch. See the [PyTorch website](http://pytorch.org/)
for installation instructions specific to your platform.

Then, clone this repository and install the other dependencies with `pip3`:

    git clone https://github.com/facebookresearch/modeling_long_term_future.git
    cd modeling_long_term_future
    pip3 install --editable .

Alternatively, create a new conda env from the env.yml file in the repo (`conda env create -f env.yml`).

## Training the teacher

We use the BabyAI Unlock-Pickup game (`BabyAI-UnlockPickup-v0`).

First, train the teacher (the expert for imitation learning) using PPO with curriculum learning, starting with a room size of 6 and working up to a room size of 15:

    python3 -m scripts.train_curclm --env BabyAI-UnlockPickup-v0 --algo ppo --arch cnn1 --tb --seed 1 --save-interval 10 --room-size 6
    python3 -m scripts.train_curclm --env BabyAI-UnlockPickup-v0 --algo ppo --arch cnn1 --tb --seed 1 --save-interval 10 --room-size 8 --model MODEL_ROOM6_PRETRAINED
    python3 -m scripts.train_curclm --env BabyAI-UnlockPickup-v0 --algo ppo --arch cnn1 --tb --seed 1 --save-interval 10 --room-size 10 --model MODEL_ROOM8_PRETRAINED
    python3 -m scripts.train_curclm --env BabyAI-UnlockPickup-v0 --algo ppo --arch cnn1 --tb --seed 1 --save-interval 10 --room-size 12 --model MODEL_ROOM10_PRETRAINED
    python3 -m scripts.train_curclm --env BabyAI-UnlockPickup-v0 --algo ppo --arch cnn1 --tb --seed 1 --save-interval 10 --room-size 15 --model MODEL_ROOM12_PRETRAINED

## Generate expert trajectories

Generate expert trajectories from the experts trained with curriculum learning:

    mkdir data
    python3 -m scripts.gen_samples --episodes 10000 --env BabyAI-UnlockPickup-v0 --model pretrained_model_room_10 --room 10

## Training the student to imitate the expert

To run our model:

    python3 -m scripts.zforcing_main_state_dec --env BabyAI-UnlockPickup-v0 --datafile EXPERT_DATA_TO_LOAD --model MODEL_NAME --eval-episodes 100 --eval-interval 200 --bwd-weight 0.0 --lr 1e-4 --aux-weight-start 0.0001 --bwd-l2-weight 1. --kld-weight-start 0.2 --aux-weight-end 0.0001 --room 10

To run the baseline:

    python3 -m scripts.zforcing_main_state_dec --datafile EXPERT_DATA_TO_LOAD --env BabyAI-UnlockPickup-v0 --model MODEL_NAME --eval-episodes 100 --eval-interval 200 --bwd-weight 0.0 --lr 1e-4 --aux-weight-start 0.000 --aux-weight-end 0.0 --room 10

## License

See the [LICENSE](LICENSE) file.

--------------------------------------------------------------------------------
/babyai/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# Import levels so that the OpenAI Gym environments get registered
# when the babyai package is imported
from . import levels
from . import utils

--------------------------------------------------------------------------------
/babyai/agents/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
--------------------------------------------------------------------------------
/babyai/levels/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
from .levels import level_dict

--------------------------------------------------------------------------------
/babyai/levels/instr_gen.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
__author__ = "Saizheng Zhang"
__credits__ = "Saizheng Zhang, Lucas Willems, Thien Huu Nguyen"

from collections import namedtuple, defaultdict
from copy import deepcopy
import itertools
import random

from .instrs import *
from gym_minigrid.minigrid import COLOR_NAMES

CONCEPTS = {
    'action': {'pickup', 'goto', 'drop', 'open'},
    'object': {'door', 'wall', 'ball', 'key', 'box'},
    'attr': {'color', 'loc', 'state'},
    'loc': {'loc_rel'},
    'color': COLOR_NAMES,
    'loc_abs': {'east', 'west', 'north', 'south'},
    'loc_rel': {'left', 'right', 'front', 'behind'},
    'state': {'locked'}}

# A constraint pair (a, b) means that a and b may appear together
# in the same instruction.
CONSTRAINTS = \
    {('key', v) for v in {'goto', 'pickup', 'drop'}} | \
    {('wall', 'goto')} | \
    {('door', v) for v in {'goto', 'open', 'locked'}} | \
    {('ball', v) for v in {'goto', 'pickup', 'drop'}} | \
    {('box', v) for v in {'goto', 'pickup', 'drop', 'open', 'locked'}} | \
    {('object', 'color'), ('object', 'loc')}


def check_valid_concept(name):
    if name in CONCEPTS:
        return True
    for c in CONCEPTS:
        if name in CONCEPTS[c]:
            return True
    raise ValueError("Incorrect concept name: {}".format(name))


def parent_concepts(name):
    check_valid_concept(name)
    parents = set()
    for k in CONCEPTS:
        if name in CONCEPTS[k]:
            parents = parents | {k}
    return parents


def ancestor_concepts(name):
    parent_c = parent_concepts(name)
    if parent_c == set():
        return set()
    else:
        ancestor_c = set()
        for pa in parent_c:
            ancestor_c |= ancestor_concepts(pa)
        return parent_concepts(name) | ancestor_c


def is_ancestor(x, y):
    return x in ancestor_concepts(y)


def child_concepts(name):
    check_valid_concept(name)
    return CONCEPTS[name]


def root_concepts(name):
    check_valid_concept(name)
    if parent_concepts(name) == set():
        return {name}
    else:
        roots = set()
        for pa in parent_concepts(name):
            roots |= root_concepts(pa)
        return roots


def is_consistent(m, n):
    check_valid_concept(m)
    check_valid_concept(n)

    # ancestor reduction rule
    def rule_anc_reduct(x, y):
        prod_xy = itertools.product(ancestor_concepts(x) | {x}, ancestor_concepts(y) | {y})
        return any([(p_xy in CONSTRAINTS) or ((p_xy[1], p_xy[0]) in CONSTRAINTS) for p_xy in prod_xy])

    if rule_anc_reduct(m, n):
        return True

    # action-object-attribute rule
    def rule_act_obj_attr(x, y):
        for o in CONCEPTS['object']:
            if rule_anc_reduct(x, o) and rule_anc_reduct(y, o) and \
                    root_concepts(x) in [{'action'}, {'attr'}] and \
                    root_concepts(y) in [{'action'}, {'attr'}] and \
                    root_concepts(x) != root_concepts(y):
                return True
        return False

    if rule_act_obj_attr(m, n):
        return True

    # exclusive rule
    def rule_exclusive(x, y):
        if x == y:
            return True
        if is_ancestor(x, y) or is_ancestor(y, x):
            return True

        exclusive_list = {'action', 'color', 'loc'}
        if ancestor_concepts(x) & ancestor_concepts(y) & exclusive_list:
            return False
        if 'attr' in ancestor_concepts(x) and 'attr' in ancestor_concepts(y):
            return True
        return False

    if rule_exclusive(m, n):
        return True

    return False


def extract_cands_in_generate(type, constraints=set()):
    cands = []
    for t in CONCEPTS[type]:
        if all([is_consistent(t, c) for c in constraints]) or not constraints:
            cands.append(t)
    return cands


def gen_instr_seq(seed, constraintss=[set()]):
    random.seed(seed)
    return [gen_ainstr(constraints) for constraints in constraintss]


def gen_ainstr(constraints=set()):
    act = gen_action(constraints)
    obj = gen_object(act, constraints)
    return Instr(action=act, object=obj)


def gen_action(constraints=set()):
    action_cands = extract_cands_in_generate('action', constraints)
    action = random.choice(action_cands)
    return action


def gen_object(act=None, constraints=set()):
    o_cands = extract_cands_in_generate('object', constraints | (set() if not act else {act}))
    o = random.choice(o_cands)

    o_color = gen_color(constraints=constraints)
    o_loc = gen_loc(constraints=constraints)
    o_state = gen_state(obj=o, constraints=constraints)

    return Object(type=o, color=o_color, loc=o_loc, state=o_state)


def gen_subattr(type, constraints=set()):
    cands = extract_cands_in_generate(type, constraints)
    if not cands:
        return None

    if any([is_ancestor(type, c) or type == c for c in constraints]):
        return random.choice(cands)
    else:
        if random.choice([True, False]):
            return random.choice(cands)
        else:
            return None


def gen_color(obj=None, constraints=set()):
    return gen_subattr('color', constraints)


def gen_loc(obj=None, act=None, constraints=set()):
    subloc = gen_subattr('loc', constraints)
    if not subloc:
        return None
    if subloc == 'loc_abs':
        return gen_locabs(obj=obj, act=act, constraints=constraints | {'loc_abs'})
    if subloc == 'loc_rel':
        return gen_locrel(obj=obj, act=act, constraints=constraints | {'loc_rel'})


def gen_locabs(obj=None, act=None, constraints=set()):
    return gen_subattr('loc_abs', constraints)


def gen_locrel(obj=None, act=None, constraints=set()):
    return gen_subattr('loc_rel', constraints)


def gen_state(obj=None, act=None, constraints=set()):
    return gen_subattr('state', constraints | (set() if not obj else {obj}))


def gen_surface(ntup, conditions={}, seed=0, lang_variation=None):
    # Create a private RNG to avoid interfering with the global Python RNG
    rng = random.Random(seed)

    def choice(elems, lang_variation=None):
        if lang_variation is not None:
            elems = elems[:lang_variation]
        return rng.choice(elems)

    def gen(ntup, conditions={}):
        if ntup is None:
            return ''

        if isinstance(ntup, list):
            s_instr = ''
            for i, ainstr in enumerate(ntup):
                if i > 0:
                    s_instr += choice([' and then', ', then'], lang_variation) + ' '
                s_instr += gen(ainstr)
            return s_instr

        if isinstance(ntup, Instr):
            s_ainstr = gen(ntup.action)
            if ntup.action != 'drop':
                s_ainstr += ' ' + gen(ntup.object)
            return s_ainstr

        if isinstance(ntup, Object):
            s_obj = ntup.type
            if s_obj is None:
                s_obj = "thing"
            s_attrs = [ntup.color, ntup.loc, ntup.state]
            rng.shuffle(s_attrs)
            for f in s_attrs:
                if not f:
                    continue
                if f == ntup.color:
                    cond = choice(['pre'], lang_variation)
                elif f == ntup.loc:
                    cond = choice(['after', 'which is', 'that is'], lang_variation)
                else:
                    cond = choice(['pre', 'after', 'which is', 'that is'], lang_variation)
                if cond == 'pre':
                    s_obj = gen(f, conditions={cond}) + ' ' + s_obj
                if cond == 'after':
                    if 'which is' in s_obj or 'that is' in s_obj:
                        s_obj = s_obj + ' and ' + gen(f, conditions={cond})
                    elif f in CONCEPTS['state']:
                        s_obj = s_obj + ' which is ' + gen(f, conditions={cond})
                    else:
                        s_obj = s_obj + ' ' + gen(f, conditions={cond})
                if cond in ['which is', 'that is']:
                    if 'which is' in s_obj or 'that is' in s_obj:
                        s_obj = s_obj + ' and ' + gen(f, conditions={cond})
                    else:
                        s_obj = s_obj + ' {} '.format(cond) + gen(f, conditions={cond})
            return 'the ' + s_obj

        if ntup == 'goto':
            return choice(['go to', 'reach', 'find', 'walk to'], lang_variation)

        if ntup == 'pickup':
            return choice(
                ['pickup', 'pick up', 'grasp', 'go pick', 'go grasp', 'go get', 'get', 'go fetch', 'fetch'],
                lang_variation)

        if ntup == 'drop':
            return choice(['drop', 'drop down', 'put down'], lang_variation)

        if ntup == 'open':
            return choice(['open'], lang_variation)

        if ntup in CONCEPTS['color']:
            if {'pre'} & conditions:
                return ntup
            if {'after'} & conditions:
                return choice(['in {}'.format(ntup)], lang_variation)
            if {'which is', 'that is'} & conditions:
                return ntup

        if ntup in CONCEPTS['loc_abs']:
            if {'pre'} & conditions:
                return ntup
            if {'after', 'which is', 'that is'} & conditions:
                return choice(['to the {}'.format(ntup), 'on the {} direction'.format(ntup)], lang_variation)

        if ntup in CONCEPTS['loc_rel']:
            if {'pre'} & conditions:
                return ntup
            if {'which is', 'that is', 'after'} & conditions:
                if ntup == 'front':
                    return 'in front of you'
                return choice(['on your {}'.format(ntup)], lang_variation)

        if ntup in CONCEPTS['state']:
            return choice([ntup], lang_variation)

    return gen(ntup, conditions)


def test():
    for i in range(10):
        seed = i
        instr = gen_instr_seq(seed)
        # print(instr)
        gen_surface(instr, seed=seed)

    for i in range(10):
        instr = gen_instr_seq(i, constraintss=[{'pickup', 'key'}, {'drop'}])
        # print(instr)
        gen_surface(instr, seed=i)

    # Same seed must yield the same instructions and string
    str1 = gen_surface(gen_instr_seq(seed), seed=7)
    str2 = gen_surface(gen_instr_seq(seed), seed=7)
    assert str1 == str2


if __name__ == "__main__":
    test()
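
For a sense of what this generator produces, here is a minimal usage sketch (an illustration, not part of the repo: it assumes the package is installed as described in the README, and the exact surface string depends on the seeds and the installed gym_minigrid):

    from babyai.levels.instr_gen import gen_instr_seq, gen_surface

    # Sample a one-instruction mission constrained to "pick up a key",
    # then render it as a natural-language string. Both calls are
    # seeded, so the same seeds always yield the same mission text.
    instrs = gen_instr_seq(0, constraintss=[{'pickup', 'key'}])
    mission = gen_surface(instrs, seed=7, lang_variation=2)
    print(mission)
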
--------------------------------------------------------------------------------
/babyai/levels/instrs.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
from collections import namedtuple

from gym_minigrid.minigrid import COLOR_NAMES

ACTION_NAMES = ['goto', 'open', 'pickup', 'drop']
TYPE_NAMES = ['door', 'locked_door', 'box', 'ball', 'key', 'wall']
LOC_NAMES = ['left', 'right', 'front', 'behind']
STATE_NAMES = ['locked']


class Instr:
    def __init__(self, action, object):
        assert action in ACTION_NAMES

        self.action = action
        self.object = object


class Object:
    def __init__(self, type=None, color=None, loc=None, state=None):
        assert type in [None, *TYPE_NAMES]
        assert color in [None, *COLOR_NAMES]
        assert loc in [None, *LOC_NAMES]
        assert state in [None, *STATE_NAMES]

        self.type = type
        self.color = color
        self.loc = loc
        self.state = state

        # A locked_door is normalized to a plain door; the locked state
        # is carried by the 'state' attribute instead.
        if type == 'locked_door':
            self.type = 'door'
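
A small sketch of how these containers behave, including the locked_door normalization above (all names come from this file; COLOR_NAMES is provided by gym_minigrid):

    from babyai.levels.instrs import Instr, Object

    # "pick up the red key on your left"
    instr = Instr(action='pickup',
                  object=Object(type='key', color='red', loc='left'))

    # A locked_door is stored as a plain door; the locked state is
    # carried by the state attribute instead.
    door = Object(type='locked_door', state='locked')
    assert door.type == 'door' and door.state == 'locked'
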
--------------------------------------------------------------------------------
/babyai/levels/levels.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
import random
from collections import OrderedDict
from copy import deepcopy

import gym

from gym_minigrid.envs import Key, Ball, Box

from .roomgrid import RoomGrid
from .instrs import *
from .instr_gen import gen_instr_seq, gen_object, gen_surface
from .verifier import InstrSeqVerifier, OpenVerifier, PickupVerifier


class RoomGridLevel(RoomGrid):
    """
    Base for levels based on RoomGrid
    A level, given a random seed, generates missions from
    one or more patterns. Levels should produce a family of missions
    of approximately similar difficulty.
    """

    def __init__(
        self,
        lang_variation=1,
        room_size=6,
        max_steps=None,
        **kwargs
    ):
        # Default max steps computation
        if max_steps is None:
            max_steps = 4 * (room_size ** 2)

        self.lang_variation = lang_variation
        super().__init__(
            room_size=room_size,
            max_steps=max_steps,
            **kwargs
        )

    def reset(self, **kwargs):
        obs = super().reset(**kwargs)

        # Recreate the verifier
        self.verifier = InstrSeqVerifier(self, self.instrs)

        return obs

    def step(self, action):
        obs, reward, done, info = super().step(action)

        # If we've successfully completed the mission
        if self.verifier.step() is True:
            done = True
            reward = self._reward()

        return obs, reward, done, info

    def _gen_grid(self, width, height):
        super()._gen_grid(width, height)

        # Generate the mission
        self.gen_mission()

        # Generate the surface form for the instructions
        seed = self._rand_int(0, 0xFFFFFFFF)
        self.surface = gen_surface(self.instrs, seed=seed, lang_variation=self.lang_variation)
        self.mission = self.surface

    def gen_mission(self):
        """
        Generate a mission (instructions and matching environment).
        Derived level classes should implement this method.
        """
        raise NotImplementedError

    @property
    def level_name(self):
        return self.__class__.level_name

    @property
    def gym_id(self):
        return self.__class__.gym_id


class Level_OpenRedDoor(RoomGridLevel):
    """
    Go to the red door
    (always unlocked, in the current room)
    Note: this level is meant for debugging and is
    intentionally kept very simple.
    """

    def __init__(self, seed=None):
        super().__init__(
            num_rows=1,
            num_cols=2,
            room_size=5,
            seed=seed
        )

    def gen_mission(self):
        obj, _ = self.add_door(0, 0, 0, 'red', locked=False)
        self.place_agent(0, 0)
        self.instrs = [Instr(action="open", object=Object(obj.type, obj.color))]


class Level_OpenDoor(RoomGridLevel):
    """
    Go to the door
    The door to open is given by its color or by its location.
    (always unlocked, in the current room)
    """

    def __init__(self, select_by=None, seed=None):
        self.select_by = select_by

        super().__init__(
            seed=seed
        )

    def gen_mission(self):
        door_colors = self._rand_subset(COLOR_NAMES, 4)
        objs = []

        for i, color in enumerate(door_colors):
            obj, _ = self.add_door(1, 1, door_idx=i, color=color, locked=False)
            objs.append(obj)

        select_by = self.select_by
        if select_by is None:
            select_by = self._rand_elem(["color", "loc"])
        if select_by == "color":
            object = Object(objs[0].type, color=objs[0].color)
        elif select_by == "loc":
            object = Object(objs[0].type, loc=self._rand_elem(LOC_NAMES))

        self.place_agent(1, 1)
        self.instrs = [Instr(action="open", object=object)]


class Level_OpenDoorDebug(Level_OpenDoor):
    """
    Same as OpenDoor but the level stops when any door is opened
    """

    def reset(self, **kwargs):
        obs = super().reset(**kwargs)

        # Recreate the verifier
        self.verifier = InstrSeqVerifier(self, self.instrs)
        # Recreate the open verifier
        self.open_verifier = OpenVerifier(self, Object("door"))

        return obs

    def step(self, action):
        obs, reward, done, info = super().step(action)

        # If we've successfully completed the mission
        if self.verifier.step() is True:
            done = True
            reward = self._reward()
        # If we've opened the wrong door
        elif self.open_verifier.step() is True:
            done = True

        return obs, reward, done, info


class Level_OpenDoorColor(Level_OpenDoor):
    """
    Go to the door
    The door is selected by color.
    (always unlocked, in the current room)
    """

    def __init__(self, seed=None):
        super().__init__(
            select_by="color",
            seed=seed
        )


class Level_OpenDoorColorDebug(Level_OpenDoorColor, Level_OpenDoorDebug):
    """
    Same as OpenDoorColor but the level stops when any door is opened
    """

    pass


class Level_OpenDoorLoc(Level_OpenDoor):
    """
    Go to the door
    The door is selected by location.
    (always unlocked, in the current room)
    """

    def __init__(self, seed=None):
        super().__init__(
            select_by="loc",
            seed=seed
        )


class Level_OpenDoorLocDebug(Level_OpenDoorLoc, Level_OpenDoorDebug):
    """
    Same as OpenDoorLoc but the level stops when any door is opened
    """

    pass


class Level_GoToObjDoor(RoomGridLevel):
    """
    Go to an object or door
    (of a given type and color, in the current room)
    """

    def __init__(self, seed=None):
        super().__init__(
            room_size=7,
            lang_variation=2,
            seed=seed
        )

    def gen_mission(self):
        objs = self.add_distractors(num_distractors=5, room_i=1, room_j=1)
        for _ in range(4):
            door, _ = self.add_door(1, 1)
            objs.append((door.type, door.color))
        self.place_agent(1, 1)

        type, color = self._rand_elem(objs)
        self.instrs = [Instr(action="goto", object=Object(type, color))]


class Level_ActionObjDoor(RoomGridLevel):
    """
    [pick up an object] or
    [go to an object or door] or
    [open a door]
    (in the current room)
    """

    def __init__(self, seed=None):
        super().__init__(
            room_size=7,
            lang_variation=2,
            seed=seed
        )

    def gen_mission(self):
        objs = self.add_distractors(num_distractors=5, room_i=1, room_j=1)
        for _ in range(4):
            door, _ = self.add_door(1, 1, locked=False)
            objs.append((door.type, door.color))

        self.place_agent(1, 1)

        type, color = self._rand_elem(objs)
        if type == door.type:
            action = self._rand_elem(['goto', 'open'])
        else:
            action = self._rand_elem(['goto', 'pickup'])
        self.instrs = [Instr(action=action, object=Object(type, color))]


class Level_Unlock(RoomGridLevel):
    """
    Fetch a key and unlock a door
    (in the current room)
    """

    def __init__(self, distractors=False, seed=None):
        self.distractors = distractors
        super().__init__(
            seed=seed
        )

    def gen_mission(self):
        door, _ = self.add_door(1, 1, locked=True)
        self.add_object(1, 1, 'key', door.color)
        if self.distractors:
            self.add_distractors(num_distractors=3, room_i=1, room_j=1)
        self.place_agent(1, 1)

        self.instrs = [Instr(action="open", object=Object(door.type))]


class Level_UnlockDist(Level_Unlock):
    """
    Fetch a key and unlock a door
    (in the current room, with distractors)
    """

    def __init__(self, seed=None):
        super().__init__(distractors=True, seed=seed)


class Level_KeyInBox(RoomGridLevel):
    """
    Unlock a door. Key is in a box (in the current room).
    """

    def __init__(self, seed=None):
        super().__init__(
            seed=seed
        )

    def gen_mission(self):
        door, _ = self.add_door(1, 1, locked=True)

        # Put the key in the box, then place the box in the room
        key = Key(door.color)
        box = Box(self._rand_color(), key)
        self.place_in_room(1, 1, box)

        self.place_agent(1, 1)

        self.instrs = [Instr(action="open", object=Object(door.type))]


class Level_UnlockPickup(RoomGridLevel):
    """
    Unlock a door, then pick up a box in another room
    """

    def __init__(self, room_size=6, distractors=False, seed=None):
        self.distractors = distractors
        super().__init__(
            num_rows=1,
            num_cols=2,
            room_size=room_size,
            max_steps=8*room_size**2,
            seed=seed
        )

    def set_roomsize(self, room_size):
        #self.room_size = room_size
        #self.max_steps = 8 * self.room_size ** 2
        #self.gen_mission()
        super().__init__(
            num_rows=1,
            num_cols=2,
            room_size=room_size,
            max_steps=8*room_size**2
        )

    def gen_mission(self):
        # Add a random object to the room on the right
        obj, _ = self.add_object(1, 0, kind="box")
        # Make sure the two rooms are directly connected by a locked door
        door, _ = self.add_door(0, 0, 0, locked=True)
        # Add a key to unlock the door
        self.add_object(0, 0, 'key', door.color)
        if self.distractors:
            self.add_distractors(num_distractors=4)

        self.place_agent(0, 0)

        self.instrs = [Instr(action="pickup", object=Object(obj.type, obj.color))]


class Level_UnlockPickupDist(Level_UnlockPickup):
    """
    Unlock a door, then pick up an object in another room
    (with distractors)
    """

    def __init__(self, seed=None):
        super().__init__(distractors=True, seed=seed)


class Level_BlockedUnlockPickup(RoomGridLevel):
    """
    Unlock a door blocked by a ball, then pick up a box
    in another room
    """

    def __init__(self, room_size=6, seed=None):
        super().__init__(
            num_rows=1,
            num_cols=2,
            room_size=room_size,
            max_steps=16*room_size**2,
            seed=seed
        )

    def gen_mission(self):
        # Add a box to the room on the right
        obj, _ = self.add_object(1, 0, kind="box")
        # Make sure the two rooms are directly connected by a locked door
        door, pos = self.add_door(0, 0, 0, locked=True)
        # Block the door with a ball
        color = self._rand_color()
        self.grid.set(pos[0]-1, pos[1], Ball(color))
        # Add a key to unlock the door
        self.add_object(0, 0, 'key', door.color)

        self.place_agent(0, 0)

        self.instrs = [Instr(action="pickup", object=Object(obj.type))]


class Level_UnlockToUnlock(RoomGridLevel):
    """
    Unlock door A, which requires unlocking door B first
    """

    def __init__(self, seed=None, room_size=6):
        super().__init__(
            num_rows=1,
            num_cols=3,
            room_size=room_size,
            max_steps=30*room_size**2,
            seed=seed
        )

    def gen_mission(self):
        colors = self._rand_subset(COLOR_NAMES, 2)

        # Add a door of color A connecting left and middle room
        self.add_door(0, 0, door_idx=0, color=colors[0], locked=True)

        # Add a key of color A in the room on the right
        self.add_object(2, 0, kind="key", color=colors[0])

        # Add a door of color B connecting middle and right room
        self.add_door(1, 0, door_idx=0, color=colors[1], locked=True)

        # Add a key of color B in the middle room
        self.add_object(1, 0, kind="key", color=colors[1])

        obj, _ = self.add_object(0, 0, kind="ball")

        self.place_agent(1, 0)

        self.instrs = [Instr(action="pickup", object=Object(obj.type))]


class Level_PickupDist(RoomGridLevel):
    """
    Pick up an object
    The object to pick up is given by its type only, or
    by its color, or by its type and color.
    (in the current room, with distractors)
    """

    def __init__(self, seed=None):
        super().__init__(
            num_rows=1,
            num_cols=1,
            room_size=7,
            lang_variation=2,
            seed=seed
        )

    def gen_mission(self):
        # Add 5 random objects in the room
        objs = self.add_distractors(5)
        self.place_agent(0, 0)
        type, color = self._rand_elem(objs)

        select_by = self._rand_elem(["type", "color", "both"])
        if select_by == "color":
            type = None
        elif select_by == "type":
            color = None

        self.instrs = [Instr(action="pickup", object=Object(type, color))]


class Level_PickupDistDebug(Level_PickupDist):
    """
    Same as PickupDist but the level stops when any object is picked
    """

    def reset(self, **kwargs):
        obs = super().reset(**kwargs)

        # Recreate the verifier
        self.verifier = InstrSeqVerifier(self, self.instrs)
        # Recreate the pickup verifier
        self.pickup_verifier = PickupVerifier(self, Object())

        return obs

    def step(self, action):
        obs, reward, done, info = super().step(action)

        # If we've successfully completed the mission
        if self.verifier.step() is True:
            done = True
            reward = self._reward()
        # If we've picked up the wrong object
        elif self.pickup_verifier.step() is True:
            done = True

        return obs, reward, done, info


class Level_PickupAbove(RoomGridLevel):
    """
    Pick up an object (in the room above)
    Solving this task effectively requires using the compass.
    """

    def __init__(self, seed=None, room_size=6):
        super().__init__(
            room_size=room_size,
            max_steps=8*room_size**2,
            lang_variation=2,
            seed=seed
        )

    def gen_mission(self):
        # Add a random object to the top-middle room
        obj, pos = self.add_object(1, 0)
        # Make sure the two rooms are directly connected
        self.add_door(1, 1, 3, locked=False)
        self.place_agent(1, 1)
        self.connect_all()

        self.instrs = [Instr(action="pickup", object=Object(obj.type, obj.color))]


class Level_OpenTwoDoors(RoomGridLevel):
    """
    Open door X, then open door Y
    The two doors are facing opposite directions, so that the agent
    can't see whether the door behind it is open.
    Solving this task effectively requires memory (a recurrent policy).
    """

    def __init__(self, room_size=6, first_color=None, second_color=None, seed=None):
        self.first_color = first_color
        self.second_color = second_color

        super().__init__(
            room_size=room_size,
            max_steps=20*room_size**2,
            lang_variation=2,
            seed=seed
        )

    def gen_mission(self):
        colors = self._rand_subset(COLOR_NAMES, 2)

        first_color = self.first_color
        if first_color is None:
            first_color = colors[0]
        second_color = self.second_color
        if second_color is None:
            second_color = colors[1]

        door1, _ = self.add_door(1, 1, 2, color=first_color, locked=False)
        door2, _ = self.add_door(1, 1, 0, color=second_color, locked=False)

        self.place_agent(1, 1)

        self.instrs = [
            Instr(action="open", object=Object(door1.type, door1.color)),
            Instr(action="open", object=Object(door2.type, door2.color))
        ]


class Level_OpenTwoDoorsDebug(Level_OpenTwoDoors):
    """
    Same as OpenTwoDoors but the level stops when the second door is opened
    """

    def reset(self, **kwargs):
        obs = super().reset(**kwargs)

        # Recreate the verifier
        self.verifier = InstrSeqVerifier(self, self.instrs)
        # Recreate the open second verifier
        second_color = self.instrs[1].object.color
        self.open_second_verifier = OpenVerifier(self, Object("door", second_color))

        return obs

    def step(self, action):
        obs, reward, done, info = super().step(action)

        # If we've successfully completed the mission
        if self.verifier.step() is True:
            done = True
            reward = self._reward()
        # If we've opened the wrong door
        elif self.open_second_verifier.step() is True:
            done = True

        return obs, reward, done, info


class Level_OpenRedBlueDoors(Level_OpenTwoDoors):
    """
    Open the red door, then open the blue door
    The two doors are facing opposite directions, so that the agent
    can't see whether the door behind it is open.
    Solving this task effectively requires memory (a recurrent policy).
    """

    def __init__(self, seed=None):
        super().__init__(
            first_color="red",
            second_color="blue",
            seed=seed
        )


class Level_OpenRedBlueDoorsDebug(Level_OpenTwoDoorsDebug):
    """
    Same as OpenRedBlueDoors but the level stops when the blue door is opened
    """

    def __init__(self, seed=None):
        super().__init__(
            first_color="red",
            second_color="blue",
            seed=seed
        )


class Level_FindObjS5(RoomGridLevel):
    """
    Pick up an object (in a random room)
    Rooms have a size of 5
    This level requires potentially exhaustive exploration
    """

    def __init__(self, room_size=5, seed=None):
        super().__init__(
            room_size=room_size,
            max_steps=20*room_size**2,
            seed=seed
        )

    def gen_mission(self):
        # Add a random object to a random room
        i = self._rand_int(0, self.num_rows)
        j = self._rand_int(0, self.num_cols)
        obj, _ = self.add_object(i, j)
        self.place_agent(1, 1)
        self.connect_all()

        self.instrs = [Instr(action="pickup", object=Object(obj.type))]


class Level_FindObjS6(Level_FindObjS5):
    """
    Same as the FindObjS5 level, but rooms have a size of 6
    """

    def __init__(self, seed=None):
        super().__init__(
            room_size=6,
            seed=seed
        )


class Level_FindObjS7(Level_FindObjS5):
    """
    Same as the FindObjS5 level, but rooms have a size of 7
    """

    def __init__(self, seed=None):
        super().__init__(
            room_size=7,
            seed=seed
        )


class Level_FourObjsS5(RoomGridLevel):
    """
    Four identical objects in four different rooms. The task is
    to pick up the correct one.
    The object to pick up is given by its location.
    Rooms have a size of 5.
    """

    def __init__(self, room_size=5, seed=None):
        super().__init__(
            room_size=room_size,
            max_steps=20*room_size**2,
            lang_variation=2,
            seed=seed
        )

    def gen_mission(self):
        obj, _ = self.add_object(1, 0)
        self.add_object(1, 2, obj.type, obj.color)
        self.add_object(0, 1, obj.type, obj.color)
        self.add_object(2, 1, obj.type, obj.color)

        # Make sure the start room is directly connected to the
        # four adjacent rooms
        for i in range(0, 4):
            _, _ = self.add_door(1, 1, i, locked=False)

        self.place_agent(1, 1)

        # Choose a random location for the object to pick up
        loc = self._rand_elem(LOC_NAMES)
        rand_obj = Object(obj.type, obj.color, loc)
        self.instrs = [Instr(action="pickup", object=rand_obj)]


class Level_FourObjsS6(Level_FourObjsS5):
    """
    Same as the FourObjsS5 level, but rooms have a size of 6
    """

    def __init__(self, seed=None):
        super().__init__(
            room_size=6,
            seed=seed
        )


class Level_FourObjsS7(Level_FourObjsS5):
    """
    Same as the FourObjsS5 level, but rooms have a size of 7
    """

    def __init__(self, seed=None):
        super().__init__(
            room_size=7,
            seed=seed
        )


class KeyCorridor(RoomGridLevel):
    """
    A ball is behind a locked door; the key is placed in a
    random room.
    """

    def __init__(
        self,
        num_rows=3,
        obj_type="ball",
        room_size=6,
        seed=None
    ):
        self.obj_type = obj_type

        super().__init__(
            room_size=room_size,
            num_rows=num_rows,
            max_steps=30*room_size**2,
            lang_variation=3,
            seed=seed,
        )

    def gen_mission(self):
        # Connect the middle column rooms into a hallway
        for j in range(1, self.num_rows):
            self.remove_wall(1, j, 3)

        # Add a locked door on the bottom right
        # Add an object behind the locked door
        room_idx = self._rand_int(0, self.num_rows)
        door, _ = self.add_door(2, room_idx, 2, locked=True)
        obj, _ = self.add_object(2, room_idx, kind=self.obj_type)

        # Add a key in a random room on the left side
        self.add_object(0, self._rand_int(0, self.num_rows), 'key', door.color)

        # Place the agent in the middle
        self.place_agent(1, self.num_rows // 2)

        # Make sure all rooms are accessible
        self.connect_all()

        self.instrs = [Instr(action="pickup", object=Object(obj.type))]


class Level_KeyCorridorS3R1(KeyCorridor):
    def __init__(self, seed=None):
        super().__init__(
            room_size=3,
            num_rows=1,
            seed=seed
        )

class Level_KeyCorridorS3R2(KeyCorridor):
    def __init__(self, seed=None):
        super().__init__(
            room_size=3,
            num_rows=2,
            seed=seed
        )

class Level_KeyCorridorS3R3(KeyCorridor):
    def __init__(self, seed=None):
        super().__init__(
            room_size=3,
            num_rows=3,
            seed=seed
        )

class Level_KeyCorridorS4R3(KeyCorridor):
    def __init__(self, seed=None):
        super().__init__(
            room_size=4,
            num_rows=3,
            seed=seed
        )

class Level_KeyCorridorS5R3(KeyCorridor):
    def __init__(self, seed=None):
        super().__init__(
            room_size=5,
            num_rows=3,
            seed=seed
        )

class Level_KeyCorridorS6R3(KeyCorridor):
    def __init__(self, seed=None):
        super().__init__(
            room_size=6,
            num_rows=3,
            seed=seed
        )

class Level_1RoomS8(RoomGridLevel):
    """
    Pick up the ball
    Rooms have a size of 8
    """

    def __init__(self, room_size=8, seed=None):
        super().__init__(
            room_size=room_size,
            num_rows=1,
            num_cols=1,
            seed=seed
        )

    def gen_mission(self):
        obj, _ = self.add_object(0, 0, kind="ball")
        self.place_agent()
        self.instrs = [Instr(action="pickup", object=Object(obj.type))]


class Level_1RoomS12(Level_1RoomS8):
    """
    Pick up the ball
    Rooms have a size of 12
    """

    def __init__(self, seed=None):
        super().__init__(
            room_size=12,
            seed=seed
        )


class Level_1RoomS16(Level_1RoomS8):
    """
    Pick up the ball
    Rooms have a size of 16
    """

    def __init__(self, seed=None):
        super().__init__(
            room_size=16,
            seed=seed
        )


class Level_1RoomS20(Level_1RoomS8):
    """
    Pick up the ball
    Rooms have a size of 20
    """

    def __init__(self, seed=None):
        super().__init__(
            room_size=20,
            seed=seed
        )


# Dictionary of levels, indexed by name, lexically sorted
level_dict = OrderedDict()

# Iterate through global names
for global_name in sorted(list(globals().keys())):
    if not global_name.startswith('Level_'):
        continue

    module_name = __name__
    level_name = global_name.split('Level_')[-1]
    level_class = globals()[global_name]

    # Register the levels with OpenAI Gym
    gym_id = 'BabyAI-%s-v0' % (level_name)
    entry_point = '%s:%s' % (module_name, global_name)
    gym.envs.registration.register(
        id=gym_id,
        entry_point=entry_point,
    )

    # Add the level to the dictionary
    level_dict[level_name] = level_class

    # Store the name and gym id on the level class
    level_class.level_name = level_name
    level_class.gym_id = gym_id


def test():
    for idx, level_name in enumerate(level_dict.keys()):
        print('Level %s (%d/%d)' % (level_name, idx+1, len(level_dict)))

        level = level_dict[level_name]

        # Run the mission for a few episodes
        rng = random.Random(0)
        num_episodes = 0
        for i in range(0, 20):
            mission = level(seed=i)
            assert isinstance(mission.surface, str)
            assert len(mission.surface) > 0

            obs = mission.reset()
            assert obs['mission'] == mission.surface

            while True:
                action = rng.randint(0, mission.action_space.n - 1)
                obs, reward, done, info = mission.step(action)
                if done:
                    obs = mission.reset()
                    break

            num_episodes += 1

        # The same seed should always yield the same mission
        m0 = level(seed=0)
        m1 = level(seed=0)
        grid1 = m0.unwrapped.grid
        grid2 = m1.unwrapped.grid
        assert grid1 == grid2
        assert m0.surface == m1.surface
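
The registration loop above exposes every Level_* class under a BabyAI-*-v0 gym id, and test() encodes the reproducibility contract. A minimal interaction sketch (assuming gym and gym_minigrid are installed; BabyAI-UnlockPickup-v0 is the id used in the README):

    import gym
    import babyai  # the import itself registers all BabyAI-* environments

    env = gym.make('BabyAI-UnlockPickup-v0')
    obs = env.reset()
    print(obs['mission'])  # natural-language instruction for this episode

    # Old-style (4-tuple) gym step API, as used throughout this code base
    obs, reward, done, info = env.step(env.action_space.sample())
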
56 | """ 57 | 58 | def __init__( 59 | self, 60 | room_size=7, 61 | num_rows=3, 62 | num_cols=3, 63 | max_steps=100, 64 | seed=0 65 | ): 66 | assert room_size > 0 67 | assert room_size >= 3 68 | assert num_rows > 0 69 | assert num_cols > 0 70 | self.room_size = room_size 71 | self.num_rows = num_rows 72 | self.num_cols = num_cols 73 | 74 | height = (room_size - 1) * num_rows + 1 75 | width = (room_size - 1) * num_cols + 1 76 | grid_size = max(width, height) 77 | 78 | # By default, this environment has no mission 79 | self.mission = '' 80 | 81 | super().__init__( 82 | grid_size=grid_size, 83 | max_steps=max_steps, 84 | see_through_walls=False, 85 | seed=seed 86 | ) 87 | 88 | def room_from_pos(self, x, y): 89 | """Get the room a given position maps to""" 90 | 91 | assert x >= 0 92 | assert y >= 0 93 | 94 | i = x // self.room_size 95 | j = y // self.room_size 96 | 97 | assert i < self.num_cols 98 | assert j < self.num_rows 99 | 100 | return self.room_grid[j][i] 101 | 102 | def get_room(self, i, j): 103 | assert i < self.num_cols 104 | assert j < self.num_rows 105 | return self.room_grid[j][i] 106 | 107 | def _gen_grid(self, width, height): 108 | # Create the grid 109 | self.grid = Grid(width, height) 110 | 111 | self.room_grid = [] 112 | 113 | # For each row of rooms 114 | for j in range(0, self.num_rows): 115 | row = [] 116 | 117 | # For each column of rooms 118 | for i in range(0, self.num_cols): 119 | room = Room( 120 | (i * (self.room_size-1), j * (self.room_size-1)), 121 | (self.room_size, self.room_size) 122 | ) 123 | row.append(room) 124 | 125 | # Generate the walls for this room 126 | self.grid.wall_rect(*room.top, *room.size) 127 | 128 | self.room_grid.append(row) 129 | 130 | # For each row of rooms 131 | for j in range(0, self.num_rows): 132 | # For each column of rooms 133 | for i in range(0, self.num_cols): 134 | room = self.room_grid[j][i] 135 | 136 | x_l, y_l = (room.top[0] + 1, room.top[1] + 1) 137 | x_m, y_m = (room.top[0] + room.size[0] - 1, room.top[1] + room.size[1] - 1) 138 | 139 | # Door positions, order is right, down, left, up 140 | if i < self.num_cols - 1: 141 | room.neighbors[0] = self.room_grid[j][i+1] 142 | room.door_pos[0] = (x_m, self._rand_int(y_l, y_m)) 143 | if j < self.num_rows - 1: 144 | room.neighbors[1] = self.room_grid[j+1][i] 145 | room.door_pos[1] = (self._rand_int(x_l, x_m), y_m) 146 | if i > 0: 147 | room.neighbors[2] = self.room_grid[j][i-1] 148 | room.door_pos[2] = room.neighbors[2].door_pos[0] 149 | if j > 0: 150 | room.neighbors[3] = self.room_grid[j-1][i] 151 | room.door_pos[3] = room.neighbors[3].door_pos[1] 152 | 153 | # The agent starts in the middle, facing right 154 | self.start_pos = ( 155 | (self.num_cols // 2) * (self.room_size-1) + (self.room_size // 2), 156 | (self.num_rows // 2) * (self.room_size-1) + (self.room_size // 2) 157 | ) 158 | self.start_dir = 0 159 | 160 | def place_in_room(self, i, j, obj): 161 | """ 162 | Add an existing object to room (i, j) 163 | """ 164 | 165 | room = self.get_room(i, j) 166 | 167 | pos = self.place_obj( 168 | obj, 169 | room.top, 170 | room.size, 171 | reject_fn=reject_next_to 172 | ) 173 | 174 | room.objs.append(obj) 175 | 176 | return obj, pos 177 | 178 | def add_object(self, i, j, kind=None, color=None): 179 | """ 180 | Add a new object to room (i, j) 181 | """ 182 | 183 | if kind == None: 184 | kind = self._rand_elem(['key', 'ball', 'box']) 185 | 186 | if color == None: 187 | color = self._rand_color() 188 | 189 | # TODO: we probably want to add an Object.make helper function 190 | assert kind in 
['key', 'ball', 'box'] 191 | if kind == 'key': 192 | obj = Key(color) 193 | elif kind == 'ball': 194 | obj = Ball(color) 195 | elif kind == 'box': 196 | obj = Box(color) 197 | 198 | return self.place_in_room(i, j, obj) 199 | 200 | def add_door(self, i, j, door_idx=None, color=None, locked=None): 201 | """ 202 | Add a door to a room, connecting it to a neighbor 203 | """ 204 | 205 | room = self.get_room(i, j) 206 | 207 | if door_idx is None: 208 | # Need to make sure that there is a neighbor along this wall 209 | # and that there is not already a door 210 | while True: 211 | door_idx = self._rand_int(0, 4) 212 | if room.neighbors[door_idx] and room.doors[door_idx] is None: 213 | break 214 | 215 | if color is None: 216 | color = self._rand_color() 217 | 218 | if locked is None: 219 | locked = self._rand_bool() 220 | 221 | assert room.doors[door_idx] is None, "door already exists" 222 | 223 | if locked: 224 | door = LockedDoor(color) 225 | room.locked = True 226 | else: 227 | door = Door(color) 228 | 229 | pos = room.door_pos[door_idx] 230 | self.grid.set(*pos, door) 231 | 232 | neighbor = room.neighbors[door_idx] 233 | room.doors[door_idx] = door 234 | neighbor.doors[(door_idx+2) % 4] = door 235 | 236 | return door, pos 237 | 238 | def remove_wall(self, i, j, wall_idx): 239 | """ 240 | Remove a wall between two rooms 241 | """ 242 | 243 | room = self.get_room(i, j) 244 | 245 | assert wall_idx >= 0 and wall_idx < 4 246 | assert room.doors[wall_idx] is None, "door exists on this wall" 247 | assert room.neighbors[wall_idx], "invalid wall" 248 | 249 | neighbor = room.neighbors[wall_idx] 250 | 251 | tx, ty = room.top 252 | w, h = room.size 253 | 254 | # Ordering of walls is right, down, left, up 255 | if wall_idx == 0: 256 | for k in range(1, h - 1): 257 | self.grid.set(tx + w - 1, ty + k, None) 258 | elif wall_idx == 1: 259 | for k in range(1, w - 1): 260 | self.grid.set(tx + k, ty + h - 1, None) 261 | elif wall_idx == 2: 262 | for k in range(1, h - 1): 263 | self.grid.set(tx, ty + k, None) 264 | elif wall_idx == 3: 265 | for k in range(1, w - 1): 266 | self.grid.set(tx + k, ty, None) 267 | else: 268 | assert False, "invalid wall index" 269 | 270 | # Mark the rooms as connected 271 | room.doors[wall_idx] = True 272 | neighbor.doors[(wall_idx+2) % 4] = True 273 | 274 | def place_agent(self, i=None, j=None, rand_dir=True): 275 | """ 276 | Place the agent in a room 277 | """ 278 | 279 | if i is None: 280 | i = self._rand_int(0, self.num_cols) 281 | if j is None: 282 | j = self._rand_int(0, self.num_rows) 283 | 284 | room = self.room_grid[j][i] 285 | 286 | # Find a position that is not right in front of an object 287 | while True: 288 | super().place_agent(room.top, room.size, rand_dir) 289 | pos = self.start_pos 290 | dir = DIR_TO_VEC[self.start_dir] 291 | front_pos = pos + dir 292 | front_cell = self.grid.get(*front_pos) 293 | if front_cell is None or front_cell.type == 'wall': 294 | break 295 | 296 | def connect_all(self): 297 | """ 298 | Make sure that all rooms are reachable by the agent from its 299 | starting position 300 | """ 301 | 302 | start_room = self.room_from_pos(*self.start_pos) 303 | 304 | added_doors = [] 305 | 306 | def find_reach(): 307 | reach = set() 308 | stack = [start_room] 309 | while len(stack) > 0: 310 | room = stack.pop() 311 | if room in reach: 312 | continue 313 | reach.add(room) 314 | for i in range(0, 4): 315 | if room.doors[i]: 316 | stack.append(room.neighbors[i]) 317 | return reach 318 | 319 | while True: 320 | # If all rooms are reachable, stop 321 | reach =
find_reach() 322 | if len(reach) == self.num_rows * self.num_cols: 323 | break 324 | 325 | # Pick a random room and door position 326 | i = self._rand_int(0, self.num_cols) 327 | j = self._rand_int(0, self.num_rows) 328 | k = self._rand_int(0, 4) 329 | room = self.get_room(i, j) 330 | 331 | # If there is already a door there, skip 332 | if not room.door_pos[k] or room.doors[k]: 333 | continue 334 | 335 | if room.locked or room.neighbors[k].locked: 336 | continue 337 | 338 | color = self._rand_elem(COLOR_NAMES) 339 | door, _ = self.add_door(i, j, k, color, False) 340 | added_doors.append(door) 341 | 342 | return added_doors 343 | 344 | def add_distractors(self, num_distractors=10, room_i=None, room_j=None): 345 | """ 346 | Add random objects that can potentially distract/confuse the agent. 347 | """ 348 | 349 | # Collect a list of existing objects 350 | objs = [] 351 | for row in self.room_grid: 352 | for room in row: 353 | for obj in room.objs: 354 | objs.append((obj.type, obj.color)) 355 | 356 | num_added = 0 357 | while num_added < num_distractors: 358 | color = self._rand_elem(COLOR_NAMES) 359 | kind = self._rand_elem(['key', 'ball', 'box']) 360 | obj = (kind, color) 361 | 362 | if obj in objs: 363 | continue 364 | 365 | # Add the object to a random room if no room specified 366 | if room_i is None: 367 | i = self._rand_int(0, self.num_cols) 368 | else: 369 | i = room_i 370 | if room_j is None: 371 | j = self._rand_int(0, self.num_rows) 372 | else: 373 | j = room_j 374 | self.add_object(i, j, *obj) 375 | 376 | objs.append(obj) 377 | num_added += 1 378 | 379 | return objs 380 | -------------------------------------------------------------------------------- /babyai/levels/verifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from abc import ABC, abstractmethod 3 | from collections import namedtuple 4 | 5 | State = namedtuple("State", ["dir", "pos", "carry"]) 6 | 7 | 8 | def dot_product(v1, v2): 9 | """ 10 | Compute the dot product of the vectors v1 and v2. 11 | """ 12 | 13 | return sum([i*j for i, j in zip(v1, v2)]) 14 | 15 | 16 | class Verifier(ABC): 17 | def __init__(self, env): 18 | self.env = env 19 | self.startDirVec = env.dir_vec 20 | 21 | @abstractmethod 22 | def step(self): 23 | """ 24 | Update the verifier's internal state and return True 25 | iff the agent did what it was expected to do. 26 | """ 27 | 28 | return 29 | 30 | def _obj_desc_to_poss(self, obj_desc): 31 | """ 32 | Get the positions of all the objects that match the description.
33 | """ 34 | 35 | poss = [] 36 | 37 | for i in range(self.env.grid.width): 38 | for j in range(self.env.grid.height): 39 | cell = self.env.grid.get(i, j) 40 | if cell == None: 41 | continue 42 | 43 | if cell.type == "locked_door": 44 | type = "door" 45 | state = "locked" 46 | else: 47 | type = cell.type 48 | state = None 49 | 50 | # Check if object's type matches description 51 | if obj_desc.type != None and type != obj_desc.type: 52 | continue 53 | 54 | # Check if object's state matches description 55 | if obj_desc.state != None and state != obj_desc.state: 56 | continue 57 | 58 | # Check if object's color matches description 59 | if obj_desc.color != None and cell.color != obj_desc.color: 60 | continue 61 | 62 | # Check if object's position matches description 63 | if obj_desc.loc in ["left", "right", "front", "behind"]: 64 | # Direction from the agent to the object 65 | v = (i-self.env.start_pos[0], j-self.env.start_pos[1]) 66 | 67 | # (d1, d2) is an oriented orthonormal basis 68 | d1 = self.startDirVec 69 | d2 = (-d1[1], d1[0]) 70 | 71 | # Check if object's position matches with location 72 | pos_matches = { 73 | "left": dot_product(v, d2) < 0, 74 | "right": dot_product(v, d2) > 0, 75 | "front": dot_product(v, d1) > 0, 76 | "behind": dot_product(v, d1) < 0 77 | } 78 | 79 | if not(pos_matches[obj_desc.loc]): 80 | continue 81 | 82 | poss.append((i, j)) 83 | 84 | return poss 85 | 86 | def _get_in_front_of_pos(self): 87 | """ 88 | Get the position in front of the agent. 89 | The agent's state is the 2-tuple (agent_dir, agent_pos). 90 | """ 91 | pos = self.env.agent_pos 92 | d = self.env.dir_vec 93 | pos = (pos[0] + d[0], pos[1] + d[1]) 94 | 95 | return pos 96 | 97 | 98 | class InstrSeqVerifier(Verifier): 99 | def __init__(self, env, instr): 100 | super().__init__(env) 101 | 102 | self.instr = instr 103 | self.instr_index = 0 104 | 105 | self.obj_to_drop = None 106 | self.intermediary_state = None 107 | 108 | self._load_next_verifier() 109 | 110 | def step(self): 111 | if self.verifier != None and self.verifier.step(): 112 | self._close_verifier() 113 | self._load_next_verifier() 114 | return self.verifier == None 115 | 116 | def _load_next_verifier(self): 117 | if self.instr_index >= len(self.instr): 118 | return 119 | 120 | instr = self.instr[self.instr_index] 121 | self.instr_index += 1 122 | 123 | if instr.action == "open": 124 | self.verifier = OpenVerifier(self.env, instr.object) 125 | elif instr.action == "goto": 126 | self.verifier = GotoVerifier(self.env, instr.object) 127 | elif instr.action == "pickup": 128 | self.verifier = PickupVerifier(self.env, instr.object) 129 | elif instr.action == "drop": 130 | self.verifier = DropVerifier(self.env, self.obj_to_drop) 131 | 132 | self.verifier.state = self.intermediary_state 133 | 134 | def _close_verifier(self): 135 | if isinstance(self.verifier, PickupVerifier): 136 | self.obj_to_drop = self.verifier.state.carry 137 | 138 | self.intermediary_state = self.verifier.state 139 | 140 | self.verifier = None 141 | 142 | 143 | class InstrVerifier(Verifier): 144 | def __init__(self, env): 145 | super().__init__(env) 146 | 147 | self.previous_state = None 148 | self.state = None 149 | 150 | def step(self): 151 | """ 152 | Update verifier's internal state and returns true 153 | iff the agent did what he was expected to. 
154 | """ 155 | 156 | self.previous_state = self.state 157 | self.state = State( 158 | dir=self.env.agent_dir, 159 | pos=self.env.agent_pos, 160 | carry=self.env.carrying 161 | ) 162 | 163 | return self._done() 164 | 165 | @abstractmethod 166 | def _done(self): 167 | """ 168 | Check if the agent did what he was expected to. 169 | """ 170 | 171 | return 172 | 173 | 174 | class GotoVerifier(InstrVerifier): 175 | def __init__(self, env, obj): 176 | super().__init__(env) 177 | 178 | self.obj_poss = self._obj_desc_to_poss(obj) 179 | self.obj_cells = [self.env.grid.get(*pos) for pos in self.obj_poss] 180 | 181 | def _done(self): 182 | on_cell = self.env.grid.get(*self.state.pos) 183 | ifo_pos = self._get_in_front_of_pos() 184 | ifo_cell = self.env.grid.get(*ifo_pos) 185 | 186 | check_on_goal = on_cell != None and on_cell.type == "goal" 187 | check_goto_goal = check_on_goal and on_cell in self.obj_cells 188 | 189 | check_not_ifo_goal = ifo_cell == None or ifo_cell.type != "goal" 190 | check_goto_not_goal = check_not_ifo_goal and ifo_cell in self.obj_cells 191 | 192 | return check_goto_goal or check_goto_not_goal 193 | 194 | 195 | class PickupVerifier(InstrVerifier): 196 | def __init__(self, env, obj): 197 | super().__init__(env) 198 | 199 | self.obj_poss = self._obj_desc_to_poss(obj) 200 | self.obj_cells = [self.env.grid.get(*pos) for pos in self.obj_poss] 201 | 202 | def _done(self): 203 | check_wasnt_carrying = self.previous_state == None or self.previous_state.carry == None 204 | check_carrying = self.state.carry in self.obj_cells 205 | 206 | return check_wasnt_carrying and check_carrying 207 | 208 | 209 | class OpenVerifier(InstrVerifier): 210 | def __init__(self, env, obj): 211 | super().__init__(env) 212 | 213 | self.obj_poss = self._obj_desc_to_poss(obj) 214 | self.obj_cells = [self.env.grid.get(*pos) for pos in self.obj_poss] 215 | 216 | def _done(self): 217 | ifo_pos = self._get_in_front_of_pos() 218 | ifo_cell = self.env.grid.get(*ifo_pos) 219 | 220 | check_opened = ifo_cell in self.obj_cells and ifo_cell.is_open 221 | 222 | return check_opened 223 | 224 | 225 | class DropVerifier(InstrVerifier): 226 | def __init__(self, env, obj_to_drop): 227 | super().__init__(env) 228 | 229 | self.obj_to_drop = obj_to_drop 230 | 231 | def _done(self): 232 | check_was_carrying = self.previous_state != None and self.previous_state.carry == self.obj_to_drop 233 | check_isnt_carrying = self.state.carry == None 234 | 235 | return check_was_carrying and check_isnt_carrying 236 | -------------------------------------------------------------------------------- /babyai/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.distributions.categorical import Categorical 6 | import torch_rl 7 | 8 | 9 | # Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py 10 | def initialize_parameters(m): 11 | classname = m.__class__.__name__ 12 | if classname.find('Linear') != -1: 13 | m.weight.data.normal_(0, 1) 14 | m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True)) 15 | if m.bias is not None: 16 | 17 | m.bias.data.fill_(0) 18 | 19 | 20 | class ACModel(nn.Module, torch_rl.RecurrentACModel): 21 | def __init__(self, obs_space, action_space, use_instr=None, use_memory=False, arch="cnn1"): 22 | super().__init__() 23 | 24 | # Decide which components are enabled 25 | self.use_instr = use_instr 26 | self.use_memory = use_memory 27 | 28 | self.obs_space = obs_space 29 | 30 | # Define image embedding 31 | self.image_embedding_size = 64 32 | if arch == "cnn1": 33 | self.image_conv = nn.Sequential( 34 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(2, 2)), 35 | nn.ReLU(), 36 | nn.MaxPool2d(kernel_size=(2, 2), stride=2), 37 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(2, 2)), 38 | nn.ReLU(), 39 | nn.Conv2d(in_channels=32, out_channels=self.image_embedding_size, kernel_size=(2, 2)), 40 | nn.ReLU() 41 | ) 42 | elif arch == "cnn2": 43 | self.image_conv = nn.Sequential( 44 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3)), 45 | nn.ReLU(), 46 | nn.MaxPool2d(kernel_size=(2, 2), stride=2, ceil_mode=True), 47 | nn.Conv2d(in_channels=16, out_channels=self.image_embedding_size, kernel_size=(3, 3)), 48 | nn.ReLU() 49 | ) 50 | else: 51 | raise ValueError("Incorrect architecture name: {}".format(arch)) 52 | 53 | # Define memory 54 | if self.use_memory: 55 | self.memory_rnn = nn.LSTMCell(self.image_embedding_size, self.semi_memory_size) 56 | 57 | # Define instruction embedding 58 | if self.use_instr == 'gru' or self.use_instr == 'conv': 59 | self.word_embedding_size = 32 60 | self.word_embedding = nn.Embedding(obs_space["instr"], self.word_embedding_size) 61 | if self.use_instr == 'gru': 62 | self.instr_embedding_size = 128 63 | self.instr_rnn = nn.GRU(self.word_embedding_size, self.instr_embedding_size, batch_first=True) 64 | else: 65 | kernel_dim = 64 66 | kernel_sizes = [3,4] 67 | self.instr_convs = nn.ModuleList([nn.Conv2d(1, kernel_dim, (K, self.word_embedding_size)) for K in kernel_sizes]) 68 | self.instr_embedding_size = kernel_dim * len(kernel_sizes) 69 | 70 | elif self.use_instr == 'bow': 71 | self.instr_embedding_size = 128 72 | hidden_units = [obs_space["instr"], 64, self.instr_embedding_size] 73 | layers = [] 74 | for n_in, n_out in zip(hidden_units, hidden_units[1:]): 75 | layers.append(nn.Linear(n_in, n_out)) 76 | layers.append(nn.ReLU()) 77 | self.instr_bow = nn.Sequential(*layers) 78 | 79 | # Resize image embedding 80 | self.embedding_size = self.semi_memory_size 81 | if self.use_instr is not None: 82 | self.embedding_size += self.instr_embedding_size 83 | 84 | # Define actor's model 85 | self.actor = nn.Sequential( 86 | nn.Linear(self.embedding_size, 64), 87 | nn.Tanh(), 88 | nn.Linear(64, action_space.n) 89 | ) 90 | 91 | # Define critic's model 92 | self.critic = nn.Sequential( 93 | nn.Linear(self.embedding_size, 64), 94 | nn.Tanh(), 95 | nn.Linear(64, 1) 96 | ) 97 | 98 | # Initialize parameters correctly 99 | self.apply(initialize_parameters) 100 | 101 | @property 102 | def memory_size(self): 103 | return 2 * self.semi_memory_size 104 | 105 | 
@property 106 | def semi_memory_size(self): 107 | return self.image_embedding_size 108 | 109 | def forward(self, obs, memory): 110 | if self.use_instr is not None: 111 | embed_instr = self._get_embed_instr(obs.instr) 112 | 113 | x = torch.transpose(torch.transpose(obs.image, 1, 3), 2, 3) 114 | x = self.image_conv(x) 115 | x = x.reshape(x.shape[0], -1) 116 | 117 | if self.use_memory: 118 | hidden = (memory[:, :self.semi_memory_size], memory[:, self.semi_memory_size:]) 119 | hidden = self.memory_rnn(x, hidden) 120 | embedding = hidden[0] 121 | memory = torch.cat(hidden, dim=1) 122 | else: 123 | embedding = x 124 | 125 | if self.use_instr is not None: 126 | embedding = torch.cat((embedding, embed_instr), dim=1) 127 | 128 | x = self.actor(embedding) 129 | dist = Categorical(logits=F.log_softmax(x, dim=1)) 130 | 131 | x = self.critic(embedding) 132 | value = x.squeeze(1) 133 | 134 | return dist, value, memory 135 | 136 | def _get_embed_instr(self, instr): 137 | if self.use_instr == 'gru': 138 | self.instr_rnn.flatten_parameters() 139 | _, hidden = self.instr_rnn(self.word_embedding(instr)) 140 | return hidden[-1] 141 | elif self.use_instr == 'conv': 142 | inputs = self.word_embedding(instr).unsqueeze(1) # (B,1,T,D) 143 | inputs = [F.relu(conv(inputs)).squeeze(3) for conv in self.instr_convs] 144 | inputs = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in inputs] 145 | 146 | return torch.cat(inputs, 1) 147 | elif self.use_instr == 'bow': 148 | 149 | device = torch.device("cuda" if instr.is_cuda else "cpu") 150 | input_dim = self.obs_space["instr"] 151 | bow = torch.zeros((instr.size(0), input_dim), device=device) 152 | idx = torch.arange(instr.size(0), dtype=torch.int64) 153 | bow[idx.unsqueeze(1), instr] = 1. 154 | return self.instr_bow(bow) 155 | else: 156 | raise ValueError("Undefined instruction architecture: {}".format(self.use_instr)) 157 | -------------------------------------------------------------------------------- /babyai/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | import random 4 | import numpy 5 | import torch 6 | from babyai.utils.agent import load_agent 7 | from babyai.utils.demos import load_demos, save_demos, synthesize_demos 8 | from babyai.utils.format import ObssPreprocessor, reshape_reward 9 | from babyai.utils.log import get_log_dir, synthesize, get_logger 10 | from babyai.utils.model import get_model_dir, load_model, save_model 11 | 12 | 13 | def storage_dir(): 14 | # defines the storage directory to be in the root (same level as the babyai folder) 15 | if "BABYAI_STORAGE" in os.environ: 16 | return os.environ["BABYAI_STORAGE"] 17 | current_directory = os.path.dirname(os.path.realpath(__file__)) 18 | parent_directory = os.path.dirname(current_directory) 19 | root_directory = os.path.dirname(parent_directory) 20 | return os.path.join(root_directory, 'storage') 21 | 22 | 23 | def create_folders_if_necessary(path): 24 | dirname = os.path.dirname(path) 25 | if not(os.path.isdir(dirname)): 26 | os.makedirs(dirname) 27 | 28 | 29 | def seed(seed): 30 | random.seed(seed) 31 | numpy.random.seed(seed) 32 | torch.manual_seed(seed) 33 | if torch.cuda.is_available(): 34 | torch.cuda.manual_seed_all(seed) 35 | -------------------------------------------------------------------------------- /babyai/utils/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc.
and its affiliates. 2 | from abc import ABC, abstractmethod 3 | import torch 4 | 5 | from .. import utils 6 | 7 | 8 | class Agent(ABC): 9 | @abstractmethod 10 | def get_action(self, obs): 11 | pass 12 | 13 | @abstractmethod 14 | def analyze_feedback(self, reward, done): 15 | pass 16 | 17 | 18 | class ModelAgent(Agent): 19 | def __init__(self, model_name, observation_space, deterministic): 20 | self.obss_preprocessor = utils.ObssPreprocessor(model_name, observation_space) 21 | self.model = utils.load_model(model_name) 22 | self.deterministic = deterministic 23 | 24 | if self.model.recurrent: 25 | self._initialize_memory() 26 | 27 | def _initialize_memory(self): 28 | self.memory = torch.zeros(1, self.model.memory_size) 29 | 30 | def get_action(self, obs): 31 | preprocessed_obs = self.obss_preprocessor([obs]) 32 | 33 | with torch.no_grad(): 34 | if self.model.recurrent: 35 | dist, _, self.memory = self.model(preprocessed_obs, self.memory) 36 | else: 37 | dist, _ = self.model(preprocessed_obs) 38 | 39 | if self.deterministic: 40 | action = dist.probs.max(1, keepdim=True)[1] 41 | else: 42 | action = dist.sample() 43 | 44 | return action.item() 45 | 46 | def analyze_feedback(self, reward, done): 47 | if done and self.model.recurrent: 48 | self._initialize_memory() 49 | 50 | 51 | class DemoAgent(Agent): 52 | def __init__(self, env_name, origin): 53 | self.demos = utils.load_demos(env_name, origin) 54 | self.demo_id = 0 55 | self.step_id = 0 56 | 57 | @staticmethod 58 | def check_obss_equality(obs1, obs2): 59 | if not(obs1.keys() == obs2.keys()): 60 | return False 61 | for key in obs1.keys(): 62 | if type(obs1[key]) in (str, int): 63 | if not(obs1[key] == obs2[key]): 64 | return False 65 | else: 66 | if not (obs1[key] == obs2[key]).all(): 67 | return False 68 | return True 69 | 70 | def get_action(self, obs): 71 | if self.demo_id >= len(self.demos): 72 | raise ValueError("No demonstration remaining") 73 | 74 | expected_obs = self.demos[self.demo_id][self.step_id][0] 75 | assert DemoAgent.check_obss_equality(obs, expected_obs), "The observations do not match" 76 | 77 | return self.demos[self.demo_id][self.step_id][1] 78 | 79 | def analyze_feedback(self, reward, done): 80 | self.step_id += 1 81 | 82 | if done: 83 | self.demo_id += 1 84 | self.step_id = 0 85 | 86 | 87 | def load_agent(args, env): 88 | if args.model is not None: 89 | return ModelAgent(args.model, env.observation_space, args.deterministic) 90 | elif args.demos_origin is not None: 91 | return DemoAgent(args.env, args.demos_origin) -------------------------------------------------------------------------------- /babyai/utils/demos.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | import pickle 4 | 5 | from .. 
import utils 6 | 7 | 8 | def get_demos_path(env_name, origin): 9 | return os.path.join(utils.storage_dir(), 'demos', env_name+"_"+origin+".pkl") 10 | 11 | 12 | def load_demos(env_name, origin): 13 | path = get_demos_path(env_name, origin) 14 | if os.path.exists(path): 15 | return pickle.load(open(path, "rb")) 16 | return [] 17 | 18 | 19 | def save_demos(demos, env_name, origin): 20 | path = get_demos_path(env_name, origin) 21 | utils.create_folders_if_necessary(path) 22 | pickle.dump(demos, open(path, "wb")) 23 | 24 | 25 | def synthesize_demos(demos): 26 | print('{} demonstrations saved'.format(len(demos))) 27 | num_frames_per_episode = [len(demo) for demo in demos] 28 | if len(demos) > 0: 29 | print('Demo num frames: {}'.format(num_frames_per_episode)) 30 | -------------------------------------------------------------------------------- /babyai/utils/format.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | import json 4 | import numpy 5 | import re 6 | import torch 7 | import torch_rl 8 | 9 | from .. import utils 10 | 11 | 12 | def get_vocab_path(model_name): 13 | return os.path.join(utils.get_model_dir(model_name), "vocab.json") 14 | 15 | 16 | class Vocabulary: 17 | def __init__(self, model_name): 18 | self.path = get_vocab_path(model_name) 19 | self.max_size = 100 20 | self.vocab = {} 21 | if os.path.exists(self.path): 22 | self.vocab = json.load(open(self.path)) 23 | 24 | def __getitem__(self, token): 25 | if not(token in self.vocab.keys()): 26 | if len(self.vocab) >= self.max_size: 27 | raise ValueError("Maximum vocabulary capacity reached") 28 | self.vocab[token] = len(self.vocab) + 1 29 | return self.vocab[token] 30 | 31 | def save(self): 32 | utils.create_folders_if_necessary(self.path) 33 | json.dump(self.vocab, open(self.path, "w")) 34 | 35 | 36 | class ObssPreprocessor: 37 | def __init__(self, model_name, obs_space): 38 | self.vocab = Vocabulary(model_name) 39 | self.obs_space = { 40 | "image": 147, 41 | "instr": self.vocab.max_size 42 | } 43 | 44 | def __call__(self, obss, device=None): 45 | obs_ = torch_rl.DictList() 46 | 47 | if "image" in self.obs_space.keys(): 48 | images = numpy.array([obs["image"] for obs in obss]) 49 | images = torch.tensor(images, device=device, dtype=torch.float) 50 | 51 | obs_.image = images 52 | 53 | if "instr" in self.obs_space.keys(): 54 | raw_instrs = [] 55 | max_instr_len = 0 56 | 57 | for obs in obss: 58 | tokens = re.findall("([a-z]+)", obs["mission"].lower()) 59 | instr = numpy.array([self.vocab[token] for token in tokens]) 60 | raw_instrs.append(instr) 61 | max_instr_len = max(len(instr), max_instr_len) 62 | 63 | instrs = numpy.zeros((len(obss), max_instr_len)) 64 | 65 | for i, instr in enumerate(raw_instrs): 66 | instrs[i, :len(instr)] = instr 67 | 68 | instrs = torch.tensor(instrs, device=device, dtype=torch.long) 69 | 70 | obs_.instr = instrs 71 | 72 | return obs_ 73 | 74 | 75 | def reshape_reward(obs, action, reward, done): 76 | return 5 * reward 77 | -------------------------------------------------------------------------------- /babyai/utils/log.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | import sys 4 | import numpy 5 | import logging 6 | from collections import OrderedDict 7 | 8 | from .. 
import utils 9 | 10 | 11 | def get_log_dir(log_name): 12 | return os.path.join(utils.storage_dir(), "logs", log_name) 13 | 14 | 15 | def get_log_path(log_name): 16 | return os.path.join(get_log_dir(log_name), "log.log") 17 | 18 | 19 | def synthesize(array): 20 | stats = OrderedDict() 21 | stats['mean'] = numpy.mean(array) 22 | stats['std'] = numpy.std(array) 23 | stats['min'] = numpy.min(array) 24 | stats['max'] = numpy.max(array) 25 | return stats 26 | '''return { 27 | "mean": numpy.mean(array), 28 | "std": numpy.std(array), 29 | "min": numpy.amin(array), 30 | "max": numpy.amax(array) 31 | }''' 32 | 33 | 34 | def get_logger(log_name): 35 | path = get_log_path(log_name) 36 | utils.create_folders_if_necessary(path) 37 | 38 | logging.basicConfig( 39 | level=logging.INFO, 40 | format="%(message)s", 41 | handlers=[ 42 | logging.FileHandler(filename=path), 43 | logging.StreamHandler(sys.stdout) 44 | ] 45 | ) 46 | 47 | return logging.getLogger() 48 | -------------------------------------------------------------------------------- /babyai/utils/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | import torch 4 | 5 | from .. import utils 6 | 7 | 8 | def get_model_dir(model_name): 9 | return os.path.join(utils.storage_dir(), "models", model_name) 10 | 11 | 12 | def get_model_path(model_name): 13 | return os.path.join(get_model_dir(model_name), "model.pt") 14 | 15 | 16 | def load_model(model_name, raise_not_found=True): 17 | path = get_model_path(model_name) 18 | try: 19 | model = torch.load(path) 20 | model.eval() 21 | return model 22 | except FileNotFoundError: 23 | if raise_not_found: 24 | raise FileNotFoundError("No model found at {}".format(path)) 25 | 26 | 27 | def save_model(model, model_name): 28 | path = get_model_path(model_name) 29 | utils.create_folders_if_necessary(path) 30 | torch.save(model, path) 31 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: bbAI 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - atk=2.25.90=hf2eb9ee_1001 9 | - backcall=0.1.0=py_0 10 | - blas=2.4=openblas 11 | - bzip2=1.0.6=h470a237_2 12 | - ca-certificates=2019.1.23=0 13 | - cairo=1.14.12=h8948797_3 14 | - certifi=2019.3.9=py36_0 15 | - cffi=1.11.5=py36he75722e_1 16 | - click=7.0=py36_0 17 | - cuda90=1.0=h6433d27_0 18 | - cudatoolkit=9.0=h13b8566_0 19 | - dbus=1.13.6=h746ee38_0 20 | - decorator=4.4.0=py_0 21 | - expat=2.2.6=he6710b0_0 22 | - ffmpeg=4.0.2=ha0c5888_2 23 | - fontconfig=2.13.1=he4413a7_1000 24 | - freetype=2.9.1=h8a8886c_1 25 | - gdk-pixbuf=2.36.12=h4f1c04b_1001 26 | - gettext=0.19.8.1=hc5be6a0_1002 27 | - giflib=5.1.4=h470a237_1 28 | - glib=2.56.2=hd408876_0 29 | - gmp=6.1.2=hfc679d8_0 30 | - gnutls=3.5.19=h2a4e5f8_1 31 | - gobject-introspection=1.56.1=py36hbc4ca2d_2 32 | - graphite2=1.3.12=hfc679d8_1 33 | - gst-plugins-base=1.14.0=hbbd80ab_1 34 | - gstreamer=1.14.0=hb453b48_1 35 | - gtk2=2.24.31=h5baeb44_1000 36 | - harfbuzz=1.9.0=he243708_1001 37 | - hdf5=1.10.2=hc401514_3 38 | - icu=58.2=h9c2bf20_1 39 | - intel-openmp=2019.0=118 40 | - ipython=7.4.0=py36h24bf2e0_0 41 | - ipython_genutils=0.2.0=py_1 42 | - jasper=1.900.1=hff1ad4c_5 43 | - jedi=0.13.3=py36_0 44 | - jpeg=9c=h470a237_1 45 | - libblas=3.8.0=4_openblas 46 | - libcblas=3.8.0=4_openblas 47 | - libedit=3.1.20170329=h6b74fdf_2 48 | - 
libffi=3.2.1=hd88cf55_4 49 | - libgcc=7.2.0=h69d50b8_2 50 | - libgcc-ng=8.2.0=hdf63c60_1 51 | - libgfortran=3.0.0=1 52 | - libgfortran-ng=7.3.0=hdf63c60_0 53 | - libiconv=1.15=h470a237_3 54 | - liblapack=3.8.0=4_openblas 55 | - liblapacke=3.8.0=4_openblas 56 | - libopenblas=0.2.20=h9ac9557_7 57 | - libpng=1.6.35=hbc83047_0 58 | - libprotobuf=3.7.0=h8b12597_2 59 | - libstdcxx-ng=8.2.0=hdf63c60_1 60 | - libtiff=4.0.9=he85c1e1_2 61 | - libuuid=2.32.1=h14c3975_1000 62 | - libwebp=0.5.2=7 63 | - libxcb=1.13=h1bed415_1 64 | - libxml2=2.9.8=h26e45fe_1 65 | - mkl=2019.0=118 66 | - ncurses=6.1=hf484d3e_0 67 | - nettle=3.3=0 68 | - ninja=1.8.2=py36h6bb024c_1 69 | - numpy=1.16.2=py36h8b7e671_1 70 | - olefile=0.46=py36_0 71 | - openblas=0.3.5=h9ac9557_1001 72 | - opencv=3.4.1=py36h6fd60c2_1 73 | - openh264=1.8.0=hd28b015_0 74 | - openssl=1.1.1=h7b6447c_0 75 | - pango=1.40.14=hf0c64fd_1003 76 | - parso=0.3.4=py_0 77 | - pcre=8.42=h439df22_0 78 | - pexpect=4.6.0=py36_1000 79 | - pickleshare=0.7.5=py36_1000 80 | - pillow=5.3.0=py36h34e0f95_0 81 | - pip=18.1=py36_0 82 | - pixman=0.34.0=h470a237_3 83 | - prompt_toolkit=2.0.9=py_0 84 | - protobuf=3.7.0=py36he1b5a44_1 85 | - ptyprocess=0.6.0=py36_1000 86 | - pycparser=2.19=py36_0 87 | - pygments=2.3.1=py_0 88 | - pyqt=5.9.2=py36h05f1152_2 89 | - python=3.6.8=h0371630_0 90 | - pytorch=0.4.1=py36_py35_py27__9.0.176_7.1.2_2 91 | - qt=5.9.7=h5867ecd_1 92 | - readline=7.0=h7b6447c_5 93 | - setuptools=40.6.2=py36_0 94 | - sip=4.19.8=py36hf484d3e_0 95 | - six=1.11.0=py36_1 96 | - sqlite=3.26.0=h67949de_1001 97 | - tensorboardx=1.6=py_0 98 | - tk=8.6.8=hbc83047_0 99 | - torchvision=0.2.1=py_2 100 | - traitlets=4.3.2=py36_1000 101 | - wcwidth=0.1.7=py_1 102 | - wheel=0.32.3=py36_0 103 | - x264=1!152.20180717=h470a237_1 104 | - xorg-kbproto=1.0.7=h14c3975_1002 105 | - xorg-libice=1.0.9=h516909a_1004 106 | - xorg-libsm=1.2.3=h84519dc_1000 107 | - xorg-libx11=1.6.7=h14c3975_1000 108 | - xorg-libxext=1.3.4=h516909a_0 109 | - xorg-libxrender=0.9.10=h516909a_1002 110 | - xorg-libxt=1.1.5=h14c3975_1002 111 | - xorg-renderproto=0.11.1=h14c3975_1002 112 | - xorg-xextproto=7.3.0=h14c3975_1002 113 | - xorg-xproto=7.0.31=h14c3975_1007 114 | - xz=5.2.4=h14c3975_4 115 | - zlib=1.2.11=ha838bed_2 116 | - pip: 117 | - chardet==3.0.4 118 | - cycler==0.10.0 119 | - easyprocess==0.2.3 120 | - future==0.17.1 121 | - gym==0.10.9 122 | - idna==2.7 123 | - ipdb==0.12 124 | - kiwisolver==1.0.1 125 | - matplotlib==3.0.3 126 | - opencv-python==4.0.0.21 127 | - pyglet==1.3.2 128 | - pyparsing==2.3.1 129 | - pyqt5==5.11.3 130 | - pyqt5-sip==4.19.13 131 | - python-dateutil==2.8.0 132 | - pyvirtualdisplay==0.2.1 133 | - requests==2.20.1 134 | - scipy==1.1.0 135 | - urllib3==1.24.1 136 | prefix: /private/home/nke001/anaconda3/envs/bbAI 137 | 138 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /scripts/enjoy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
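# Before the script below runs an agent, raw observations go through
# ObssPreprocessor (babyai/utils/format.py). A hedged sketch (not in the
# original script) of what it returns, assuming a `model_name` and a made
# `env`; the image becomes a float tensor and the mission string a padded
# tensor of vocabulary token ids:
def _example_preprocess(model_name, env):
    import babyai.utils as utils
    preprocessor = utils.ObssPreprocessor(model_name, env.observation_space)
    obs = env.reset()
    batch = preprocessor([obs])
    # batch.image: (1, 7, 7, 3) floats (147 = 7*7*3 values per observation)
    # batch.instr: (1, T) long tensor of word ids from the mission string
    return batch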
2 | #!/usr/bin/env python3 3 | 4 | import argparse 5 | import gym 6 | import time 7 | 8 | import babyai.utils as utils 9 | from pyvirtualdisplay import Display 10 | 11 | display_ = Display(visible=0, size=(550, 500)) 12 | display_.start() 13 | 14 | 15 | 16 | # Parse arguments 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--env", required=True, 20 | help="name of the environment to be run (REQUIRED)") 21 | parser.add_argument("--model", default=None, 22 | help="name of the trained model (REQUIRED or --demos-origin REQUIRED)") 23 | parser.add_argument("--demos-origin", default=None, 24 | help="origin of the demonstrations: human | agent (REQUIRED or --model REQUIRED)") 25 | parser.add_argument("--seed", type=int, default=None, 26 | help="random seed (default: 0 if model agent, 1 if demo agent)") 27 | parser.add_argument("--shift", type=int, default=0, 28 | help="number of times the environment is reset at the beginning (default: 0)") 29 | parser.add_argument("--deterministic", action="store_true", default=False, 30 | help="action with highest probability is selected for model agent") 31 | parser.add_argument("--pause", type=float, default=0.1, 32 | help="the pause between two consecutive actions of an agent") 33 | 34 | args = parser.parse_args() 35 | 36 | assert args.model is not None or args.demos_origin is not None, "--model or --demos-origin must be specified." 37 | if args.seed is None: 38 | args.seed = 0 if args.model is not None else 1 39 | 40 | # Set seed for all randomness sources 41 | 42 | utils.seed(args.seed) 43 | 44 | # Generate environment 45 | 46 | env = gym.make(args.env) 47 | env.seed(args.seed) 48 | for _ in range(args.shift): 49 | env.reset() 50 | 51 | # Define agent 52 | 53 | agent = utils.load_agent(args, env) 54 | 55 | # Run the agent 56 | 57 | done = True 58 | import cv2 59 | import os; os.makedirs('rendered_image', exist_ok=True) 60 | episode = 0 61 | step = 0 62 | while True: 63 | time.sleep(args.pause) 64 | image = env.render("rgb_array") 65 | image = cv2.resize(image, dsize=(512, 512), interpolation=cv2.INTER_CUBIC) 66 | 67 | file_name = 'rendered_image/episodes_'+str(episode) + '_step_' +str(step) + '.png' 68 | cv2.imwrite(file_name, image[:,:,::-1]) 69 | step += 1 70 | if done: 71 | obs = env.reset() 72 | print("Mission: {}".format(obs["mission"])) 73 | episode += 1 74 | step = 0 75 | action = agent.get_action(obs) 76 | obs, reward, done, _ = env.step(action) 77 | agent.analyze_feedback(reward, done) 78 | 79 | if done: 80 | print("Reward:", reward) 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /scripts/enjoy_zf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
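# The --deterministic flag above switches between the two action-selection
# modes implemented in ModelAgent.get_action (babyai/utils/agent.py); a
# minimal sketch, where `dist` is the Categorical distribution the model returns:
def _example_select_action(dist, deterministic=False):
    if deterministic:
        # greedy: index of the highest-probability action
        return dist.probs.max(1, keepdim=True)[1].item()
    # stochastic: sample from the action distribution
    return dist.sample().item()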
2 | #!/usr/bin/env python3 3 | 4 | import argparse 5 | import gym 6 | import time 7 | 8 | import babyai.utils as utils 9 | from pyvirtualdisplay import Display 10 | import scipy.optimize 11 | import random 12 | import scipy.misc 13 | import torch 14 | from scripts.rl_zforcing import ZForcing 15 | import os 16 | 17 | import matplotlib 18 | matplotlib.use('Agg') 19 | import matplotlib.pyplot as plt 20 | import matplotlib.cm as cm 21 | import numpy as np 22 | 23 | 24 | display_ = Display(visible=0, size=(550, 500)) 25 | display_.start() 26 | 27 | 28 | 29 | # Parse arguments 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--env", required=True, 33 | help="name of the environment to be run (REQUIRED)") 34 | parser.add_argument("--demos-origin", default=None, 35 | help="origin of the demonstrations: human | agent (REQUIRED or --model REQUIRED)") 36 | parser.add_argument("--seed", type=int, default=None, 37 | help="random seed (default: 0 if model agent, 1 if demo agent)") 38 | parser.add_argument("--model", default=None, 39 | help="name of the trained model (REQUIRED or --demos-origin REQUIRED)") 40 | parser.add_argument("--shift", type=int, default=0, 41 | help="number of times the environment is reset at the beginning (default: 0)") 42 | parser.add_argument("--deterministic", action="store_true", default=False, 43 | help="action with highest probability is selected for model agent") 44 | parser.add_argument("--pause", type=float, default=0.1, 45 | help="the pause between two consecutive actions of an agent") 46 | 47 | args = parser.parse_args() 48 | 49 | def load_param(model, model_file_name): 50 | model.load_state_dict(torch.load(model_file_name)) 51 | return model 52 | 53 | if args.seed is None: 54 | args.seed = 1 55 | 56 | # Set seed for all randomness sources 57 | 58 | utils.seed(args.seed) 59 | num_actions = 7 60 | 61 | # Generate environment 62 | 63 | env = gym.make(args.env) 64 | env.seed(args.seed) 65 | for _ in range(args.shift): 66 | env.reset() 67 | 68 | # Define agent 69 | random.seed(20) 70 | 71 | agent = utils.load_agent(args, env) 72 | 73 | zf_model = 'BabyAI-UnlockPickup-v0_model/zforce_2opt_room_10_lr0.0001_bwd_w_1.0_aux_w_1e-06_kld_w_0.0_491.pkl' 74 | #zf_model = 'BabyAI-UnlockPickup-v0_model/zforce_2opt_room_10_lr0.0001_bwd_w_0.0_aux_w_0.0_kld_w_0.0_976.pkl' 75 | #model_name = 'zforce_2opt_room_10_lr0.0001_bwd_w_0.0_aux_w_0.0_kld_w_0.0_976' + str(random.randint(1,5000)) 76 | model_name = 'zforce_2opt_room_10_lr0.0001_bwd_w_1.0_aux_w_1e-06_kld_w_0.0_491_' + str(random.randint(1,5111)) 77 | image_dir = 'zf_rendered_image' 78 | 79 | 80 | 81 | model_image_dir = os.path.join(image_dir, model_name) 82 | os.makedirs(model_image_dir, exist_ok=True) 83 | 84 | zf = ZForcing(emb_dim=512, rnn_dim=512, z_dim=256, 85 | mlp_dim=256, out_dim=num_actions, z_force=True, cond_ln=True, return_loss=True) 86 | zf = load_param(zf, zf_model) 87 | # Run the agent 88 | 89 | done = True 90 | import cv2 91 | 92 | episode = 0 93 | step = 0 94 | logs = {"num_frames_per_episode": [], "return_per_episode": []} 95 | returnn = 0 96 | 97 | zf.float().cuda() 98 | hidden = zf.init_hidden(1) 99 | obs = env.reset() 100 | num_frames = 0 101 | step = 0 102 | model_image_dir_episode = os.path.join(model_image_dir, 'episode_0') 103 | os.makedirs(model_image_dir_episode, exist_ok=True) 104 | 105 | num_episode = 200 106 | 107 | aux_loss = [] 108 | 109 | 110 | def plot_loss(aux_loss, image_dir): 111 | plt.plot(aux_loss) 112 | plt_file = os.path.join(image_dir, 'aux_loss.pdf') 113 |
plt.savefig(plt_file); plt.close() 114 | start_time = time.time() 115 | while True: 116 | time.sleep(args.pause) 117 | image = env.render("rgb_array") 118 | image = cv2.resize(image, dsize=(512, 512), interpolation=cv2.INTER_CUBIC) 119 | epi_file_name = 'episodes_' + str(episode) +'_step_' +str(step) + '.png' 120 | file_name = os.path.join(model_image_dir_episode, epi_file_name) 121 | 122 | 123 | cv2.imwrite(file_name, image[:,:,::-1]) 124 | 125 | obs_image = np.expand_dims(obs['image'], 0) 126 | mask = torch.ones(obs_image.shape).unsqueeze(0) 127 | obs_image = torch.from_numpy(obs_image).unsqueeze(0).permute(0,1,4,2,3) 128 | action, hidden, aux_nll = zf.generate_onestep(obs_image.float().cuda(), mask.cuda(), hidden) 129 | aux_loss.append(aux_nll) 130 | if done: 131 | obs = env.reset() 132 | plot_loss(aux_loss, model_image_dir_episode) 133 | print("Mission: {}".format(obs["mission"])) 134 | episode += 1 135 | obs = env.reset() 136 | done = False 137 | num_frames = 0 138 | returnn = 0 139 | step = 0 140 | hidden = zf.init_hidden(1) 141 | model_image_dir_episode = os.path.join(model_image_dir, 'episode_'+str(episode)) 142 | os.makedirs(model_image_dir_episode, exist_ok=True) 143 | aux_loss = [] 144 | action = agent.get_action(obs) 145 | obs, reward, done, _ = env.step(action) 146 | agent.analyze_feedback(reward, done) 147 | num_frames += 1 148 | step += 1 149 | returnn += reward 150 | if done: 151 | print("Reward:", reward) 152 | logs["num_frames_per_episode"].append(num_frames) 153 | logs["return_per_episode"].append(returnn) 154 | 155 | if episode > num_episode: 156 | break 157 | 158 | 159 | import datetime 160 | end_time = time.time() 161 | num_frames = sum(logs["num_frames_per_episode"]) 162 | fps = num_frames/(end_time - start_time) 163 | elapsed_time = int(end_time - start_time) 164 | duration = datetime.timedelta(seconds=elapsed_time) 165 | return_per_episode = utils.synthesize(logs["return_per_episode"]) 166 | num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"]) 167 | 168 | log_line = ("F {} | FPS {:.0f} | D {} | R:x̄σmM {:.2f} {:.2f} {:.2f} {:.2f} | F:x̄σmM {:.1f} {:.1f} {} {}".format(num_frames, fps, duration, *return_per_episode.values(), *num_frames_per_episode.values())) 169 | print(log_line) 170 | 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
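# The z-forcing rollout above hinges on one call per step. A condensed sketch
# (not part of this script) of that interface; shapes follow the calls in
# enjoy_zf.py, and ZForcing itself is defined in scripts/rl_zforcing.py:
def _example_zforcing_step(zf, obs, hidden):
    import numpy as np
    import torch
    obs_image = np.expand_dims(obs['image'], 0)          # (1, H, W, C)
    mask = torch.ones(obs_image.shape).unsqueeze(0)      # all-ones step mask
    x = torch.from_numpy(obs_image).unsqueeze(0).permute(0, 1, 4, 2, 3)
    # returns the action to take, the next recurrent state, and the auxiliary
    # NLL that enjoy_zf.py accumulates and plots per episode
    return zf.generate_onestep(x.float().cuda(), mask.cuda(), hidden)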
2 | #!/usr/bin/env python3 3 | 4 | import argparse 5 | import gym 6 | import time 7 | import datetime 8 | 9 | import babyai.utils as utils 10 | 11 | # Parse arguments 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--env", required=True, 15 | help="name of the environment to be run (REQUIRED)") 16 | parser.add_argument("--model", default=None, 17 | help="name of the trained model (REQUIRED or --demos-origin REQUIRED)") 18 | parser.add_argument("--demos-origin", default=None, 19 | help="origin of the demonstrations: human | agent (REQUIRED or --model REQUIRED)") 20 | parser.add_argument("--episodes", type=int, default=1000, 21 | help="number of episodes of evaluation (default: 1000)") 22 | parser.add_argument("--seed", type=int, default=None, 23 | help="random seed (default: 0 if model agent, 1 if demo agent)") 24 | parser.add_argument("--deterministic", action="store_true", default=False, 25 | help="action with highest probability is selected for model agent") 26 | 27 | 28 | def evaluate(agent, env, episodes): 29 | # Initialize logs 30 | logs = {"num_frames_per_episode": [], "return_per_episode": []} 31 | 32 | for _ in range(episodes): 33 | obs = env.reset() 34 | done = False 35 | 36 | num_frames = 0 37 | returnn = 0 38 | 39 | while not(done): 40 | action = agent.get_action(obs) 41 | obs, reward, done, _ = env.step(action) 42 | agent.analyze_feedback(reward, done) 43 | 44 | num_frames += 1 45 | returnn += reward 46 | logs["num_frames_per_episode"].append(num_frames) 47 | logs["return_per_episode"].append(returnn) 48 | 49 | return logs 50 | 51 | 52 | if __name__ == "__main__": 53 | args = parser.parse_args() 54 | 55 | assert args.model is not None or args.demos_origin is not None, "--model or --demos-origin must be specified." 56 | if args.seed is None: 57 | args.seed = 0 if args.model is not None else 1 58 | 59 | # Set seed for all randomness sources 60 | 61 | utils.seed(args.seed) 62 | 63 | # Generate environment 64 | 65 | env = gym.make(args.env) 66 | env.seed(args.seed) 67 | 68 | # Define agent 69 | 70 | agent = utils.load_agent(args, env) 71 | 72 | if args.model is None and args.episodes > len(agent.demos): 73 | # Set the number of episodes to be the number of demos 74 | 75 | args.episodes = len(agent.demos) 76 | 77 | # Run the agent 78 | 79 | start_time = time.time() 80 | 81 | logs = evaluate(agent, env, args.episodes) 82 | 83 | end_time = time.time() 84 | 85 | # Print logs 86 | 87 | num_frames = sum(logs["num_frames_per_episode"]) 88 | fps = num_frames/(end_time - start_time) 89 | ellapsed_time = int(end_time - start_time) 90 | duration = datetime.timedelta(seconds=ellapsed_time) 91 | return_per_episode = utils.synthesize(logs["return_per_episode"]) 92 | num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"]) 93 | 94 | print("F {} | FPS {:.0f} | D {} | R:x̄σmM {:.2f} {:.2f} {:.2f} {:.2f} | F:x̄σmM {:.1f} {:.1f} {} {}" 95 | .format(num_frames, fps, duration, 96 | *return_per_episode.values(), 97 | *num_frames_per_episode.values())) 98 | 99 | indexes = sorted(range(len(logs["return_per_episode"])), key=lambda k: logs["return_per_episode"][k]) 100 | n = 10 101 | print("{} worst episodes:".format(n)) 102 | for i in indexes[:n]: 103 | print("- episode {}: R={}, F={}".format(i, logs["return_per_episode"][i], logs["num_frames_per_episode"][i])) 104 | -------------------------------------------------------------------------------- /scripts/evaluate_all_demos.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | from subprocess import call 4 | import sys 5 | 6 | import babyai.utils as utils 7 | 8 | folder = os.path.join(utils.storage_dir(), "demos") 9 | for filename in sorted(os.listdir(folder)): 10 | if filename.endswith(".pkl"): 11 | env = filename.split("_")[0] 12 | demos_maker = filename.split("_")[-1].split('.')[0] 13 | print("> Env: {} - {}".format(env, demos_maker)) 14 | command = ["python evaluate.py --env {} --demos-origin {}".format(env, demos_maker)] + sys.argv[1:] 15 | call(" ".join(command), shell=True) -------------------------------------------------------------------------------- /scripts/evaluate_all_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | from subprocess import call 4 | import sys 5 | import argparse 6 | 7 | import babyai.utils as utils 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--episodes", type=int, default=1000, 11 | help="number of episodes of evaluation (default: 1000)") 12 | args = parser.parse_args() 13 | 14 | folder = os.path.join(utils.storage_dir(), "models") 15 | for model in sorted(os.listdir(folder)): 16 | if model.startswith('.'): 17 | continue 18 | env = model.split("_")[0] 19 | print("> Env: {}".format(env)) 20 | command = ["python evaluate.py --env {} --model {} --episodes {}".format(env, model, args.episodes)] + sys.argv[1:] 21 | call(" ".join(command), shell=True) -------------------------------------------------------------------------------- /scripts/gen_samples.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | #!/usr/bin/env python3 3 | 4 | import argparse 5 | import gym 6 | import time 7 | import datetime 8 | import pickle 9 | import babyai.utils as utils 10 | 11 | # Parse arguments 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--env", required=True, 15 | help="name of the environment to be run (REQUIRED)") 16 | parser.add_argument("--model", default=None, 17 | help="name of the trained model (REQUIRED or --demos-origin REQUIRED)") 18 | parser.add_argument("--demos-origin", default=None, 19 | help="origin of the demonstrations: human | agent (REQUIRED or --model REQUIRED)") 20 | parser.add_argument("--episodes", type=int, default=1000, 21 | help="number of episodes of evaluation (default: 1000)") 22 | parser.add_argument("--seed", type=int, default=None, 23 | help="random seed (default: 0 if model agent, 1 if demo agent)") 24 | parser.add_argument("--room", type=int, default=8, 25 | help="size of the room") 26 | parser.add_argument("--deterministic", action="store_true", default=False, 27 | help="action with highest probability is selected for model agent") 28 | 29 | def write_samples(all_samples_obs, all_samples_actions, filename): 30 | # write to pickle file 31 | 32 | all_data = list(zip(all_samples_obs, all_samples_actions)) 33 | output = open(filename, "wb") 34 | pickle.dump(all_data, output) 35 | output.close() 36 | return True 37 | 38 | def evaluate(agent, env, episodes): 39 | # Initialize logs 40 | logs = {"num_frames_per_episode": [], "return_per_episode": []} 41 | expert_actions = [] 42 | expert_obs = [] 43 | all_samples_obs = [] 44 | all_samples_actions = [] 45 | 46 | for ep in range(episodes): 47 | obs = env.reset() 48 | done = False 49 | expert_obs = [] 50 | expert_actions = [] 51 |
expert_obs.append(obs['image']) 52 | num_frames = 0 53 | returnn = 0 54 | 55 | while not(done): 56 | action = agent.get_action(obs) 57 | expert_actions.append(action) 58 | obs, reward, done, _ = env.step(action) 59 | 60 | expert_obs.append(obs['image']) 61 | agent.analyze_feedback(reward, done) 62 | 63 | num_frames += 1 64 | returnn += reward 65 | 66 | 67 | all_samples_obs.append(expert_obs) 68 | all_samples_actions.append(expert_actions) 69 | if (ep + 1) % 1000 == 0: 70 | print("iteration: " + str(ep)) 71 | logs["num_frames_per_episode"].append(num_frames) 72 | logs["return_per_episode"].append(returnn) 73 | import os; os.makedirs('data', exist_ok=True) 74 | filename = 'data/' + args.env + 'start_flag_room_'+ str(args.room)+ '_' + str(episodes) + '_samples.dat' 75 | write_samples(all_samples_obs, all_samples_actions, filename) 76 | return logs 77 | 78 | 79 | if __name__ == "__main__": 80 | args = parser.parse_args() 81 | 82 | assert args.model is not None or args.demos_origin is not None, "--model or --demos-origin must be specified." 83 | if args.seed is None: 84 | args.seed = 0 if args.model is not None else 1 85 | 86 | # Set seed for all randomness sources 87 | 88 | utils.seed(args.seed) 89 | 90 | # Generate environment 91 | 92 | env = gym.make(args.env) 93 | env.seed(args.seed) 94 | 95 | # Define agent 96 | 97 | agent = utils.load_agent(args, env) 98 | 99 | if args.model is None and args.episodes > len(agent.demos): 100 | # Set the number of episodes to be the number of demos 101 | 102 | args.episodes = len(agent.demos) 103 | 104 | # Run the agent 105 | 106 | start_time = time.time() 107 | 108 | logs = evaluate(agent, env, args.episodes) 109 | 110 | end_time = time.time() 111 | 112 | # Print logs 113 | 114 | num_frames = sum(logs["num_frames_per_episode"]) 115 | fps = num_frames/(end_time - start_time) 116 | elapsed_time = int(end_time - start_time) 117 | duration = datetime.timedelta(seconds=elapsed_time) 118 | return_per_episode = utils.synthesize(logs["return_per_episode"]) 119 | num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"]) 120 | 121 | print("F {} | FPS {:.0f} | D {} | R:x̄σmM {:.2f} {:.2f} {:.2f} {:.2f} | F:x̄σmM {:.1f} {:.1f} {} {}" 122 | .format(num_frames, fps, duration, 123 | *return_per_episode.values(), 124 | *num_frames_per_episode.values())) 125 | 126 | indexes = sorted(range(len(logs["return_per_episode"])), key=lambda k: logs["return_per_episode"][k]) 127 | n = 10 128 | print("{} worst episodes:".format(n)) 129 | for i in indexes[:n]: 130 | print("- episode {}: R={}, F={}".format(i, logs["return_per_episode"][i], logs["num_frames_per_episode"][i])) 131 | -------------------------------------------------------------------------------- /scripts/gen_samples_bidir.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
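# A samples file written by write_samples() above pairs each episode's
# observation images with its expert actions. A small sketch (not in the
# original script) of reading one back:
def _example_load_samples(filename):
    import pickle
    with open(filename, "rb") as f:
        all_data = pickle.load(f)    # list of (obs_list, action_list) pairs
    obs_list, action_list = all_data[0]
    # the initial observation is recorded before the first action, so each
    # episode holds one more observation than actions
    assert len(obs_list) == len(action_list) + 1
    return all_data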
2 | #!/usr/bin/env python3 3 | 4 | import argparse 5 | import gym 6 | import time 7 | import datetime 8 | import pickle 9 | import babyai.utils as utils 10 | 11 | # Parse arguments 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--env", required=True, 15 | help="name of the environment to be run (REQUIRED)") 16 | parser.add_argument("--model", default=None, 17 | help="name of the trained model (REQUIRED or --demos-origin REQUIRED)") 18 | parser.add_argument("--demos-origin", default=None, 19 | help="origin of the demonstrations: human | agent (REQUIRED or --model REQUIRED)") 20 | parser.add_argument("--episodes", type=int, default=1000, 21 | help="number of episodes of evaluation (default: 1000)") 22 | parser.add_argument("--seed", type=int, default=None, 23 | help="random seed (default: 0 if model agent, 1 if demo agent)") 24 | parser.add_argument("--room", type=int, default=8, 25 | help="size of the room") 26 | parser.add_argument("--deterministic", action="store_true", default=False, 27 | help="action with highest probability is selected for model agent") 28 | 29 | def write_samples(all_samples_obs, all_samples_actions, filename): 30 | # write to pickle file 31 | 32 | all_data = list(zip(all_samples_obs, all_samples_actions)) 33 | output = open(filename, "wb") 34 | pickle.dump(all_data, output) 35 | output.close() 36 | return True 37 | 38 | def evaluate(agent, env, episodes): 39 | # Initialize logs 40 | logs = {"num_frames_per_episode": [], "return_per_episode": []} 41 | expert_actions = [] 42 | expert_obs = [] 43 | expert_obs_bwd = [] 44 | all_samples_obs = [] 45 | all_samples_actions = [] 46 | 47 | for ep in range(episodes): 48 | obs = env.reset() 49 | done = False 50 | expert_obs = [] 51 | expert_actions = [] 52 | 53 | expert_obs.append(obs['image']) 54 | num_frames = 0 55 | returnn = 0 56 | 57 | while not(done): 58 | action = agent.get_action(obs) 59 | expert_actions.append(action) 60 | obs, reward, done, _ = env.step(action) 61 | 62 | expert_obs.append(obs['image']) 63 | agent.analyze_feedback(reward, done) 64 | 65 | num_frames += 1 66 | returnn += reward 67 | all_samples_obs.append(expert_obs) 68 | all_samples_actions.append(expert_actions) 69 | if (ep + 1) % 1000 == 0: 70 | print("iteration: " + str(ep)) 71 | logs["num_frames_per_episode"].append(num_frames) 72 | logs["return_per_episode"].append(returnn) 73 | 74 | filename = args.env + '_room_'+ str(args.room)+ '_' + str(episodes) + '_samples.dat' 75 | write_samples(all_samples_obs, all_samples_actions, filename) 76 | return logs 77 | 78 | 79 | if __name__ == "__main__": 80 | args = parser.parse_args() 81 | 82 | assert args.model is not None or args.demos_origin is not None, "--model or --demos-origin must be specified."
83 |     if args.seed is None:
84 |         args.seed = 0 if args.model is not None else 1
85 | 
86 |     # Set seed for all randomness sources
87 | 
88 |     utils.seed(args.seed)
89 | 
90 |     # Generate environment
91 | 
92 |     env = gym.make(args.env)
93 |     env.seed(args.seed)
94 | 
95 |     # Define agent
96 | 
97 |     agent = utils.load_agent(args, env)
98 | 
99 |     if args.model is None and args.episodes > len(agent.demos):
100 |         # Set the number of episodes to be the number of demos
101 | 
102 |         args.episodes = len(agent.demos)
103 | 
104 |     # Run the agent
105 | 
106 |     start_time = time.time()
107 | 
108 |     logs = evaluate(agent, env, args.episodes)
109 | 
110 |     end_time = time.time()
111 | 
112 |     # Print logs
113 | 
114 |     num_frames = sum(logs["num_frames_per_episode"])
115 |     fps = num_frames/(end_time - start_time)
116 |     elapsed_time = int(end_time - start_time)
117 |     duration = datetime.timedelta(seconds=elapsed_time)
118 |     return_per_episode = utils.synthesize(logs["return_per_episode"])
119 |     num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])
120 | 
121 |     print("F {} | FPS {:.0f} | D {} | R:x̄σmM {:.2f} {:.2f} {:.2f} {:.2f} | F:x̄σmM {:.1f} {:.1f} {} {}"
122 |           .format(num_frames, fps, duration,
123 |                   *return_per_episode.values(),
124 |                   *num_frames_per_episode.values()))
125 | 
126 |     indexes = sorted(range(len(logs["return_per_episode"])), key=lambda k: logs["return_per_episode"][k])
127 |     n = 10
128 |     print("{} worst episodes:".format(n))
129 |     for i in indexes[:n]:
130 |         print("- episode {}: R={}, F={}".format(i, logs["return_per_episode"][i], logs["num_frames_per_episode"][i]))
131 | 
--------------------------------------------------------------------------------
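Both sampling scripts are driven from the command line. A minimal invocation sketch (the environment name is hypothetical; --model should name a model saved under the babyai utils conventions):

python3 scripts/gen_samples_bidir.py --env BabyAI-GoToRedBall-v0 --model MY_MODEL --episodes 1000 --room 8

--------------------------------------------------------------------------------
/scripts/make_agent_demos.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.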
2 | #!/usr/bin/env python3 3 | 4 | import argparse 5 | import gym 6 | 7 | import babyai.utils as utils 8 | 9 | # Parse arguments 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--env", required=True, 13 | help="name of the environment to be run (REQUIRED)") 14 | parser.add_argument("--model", required=True, 15 | help="name of the trained model (REQUIRED)") 16 | parser.add_argument("--episodes", type=int, default=1000, 17 | help="number of episodes to generate demonstrations for (default: 1000)") 18 | parser.add_argument("--seed", type=int, default=1, 19 | help="random seed (default: 1)") 20 | parser.add_argument("--deterministic", action="store_true", default=False, 21 | help="action with highest probability is selected") 22 | parser.add_argument("--save-interval", type=int, default=0, 23 | help="interval between demonstrations saving (default: 0, 0 means only at the end)") 24 | args = parser.parse_args() 25 | 26 | # Set seed for all randomness sources 27 | 28 | utils.seed(args.seed) 29 | 30 | # Generate environment 31 | 32 | env = gym.make(args.env) 33 | env.seed(args.seed) 34 | 35 | # Define agent 36 | 37 | agent = utils.load_agent(args, env) 38 | 39 | # Load demonstrations 40 | 41 | demos = utils.load_demos(args.env, "agent") 42 | utils.synthesize_demos(demos) 43 | 44 | for i in range(1, args.episodes+1): 45 | # Run the expert for one episode 46 | 47 | done = False 48 | obs = env.reset() 49 | demo = [] 50 | 51 | while not(done): 52 | action = agent.get_action(obs) 53 | new_obs, reward, done, _ = env.step(action) 54 | agent.analyze_feedback(reward, done) 55 | 56 | demo.append((obs, action, reward, done)) 57 | obs = new_obs 58 | 59 | demos.append(demo) 60 | 61 | # Save demonstrations 62 | 63 | if args.save_interval > 0 and i < args.episodes and i % args.save_interval == 0: 64 | utils.save_demos(demos, args.env, "agent") 65 | utils.synthesize_demos(demos) 66 | 67 | # Save demonstrations 68 | 69 | utils.save_demos(demos, args.env, "agent") 70 | utils.synthesize_demos(demos) -------------------------------------------------------------------------------- /scripts/make_human_demos.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | #!/usr/bin/env python3 3 | 4 | import sys 5 | import copy 6 | import random 7 | import argparse 8 | import gym 9 | from PyQt5.QtCore import Qt 10 | from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QInputDialog 11 | from PyQt5.QtWidgets import QLabel, QTextEdit, QFrame 12 | from PyQt5.QtWidgets import QPushButton, QHBoxLayout, QVBoxLayout 13 | 14 | import babyai.utils as utils 15 | 16 | # Parse arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--env", required=True, 19 | help="name of the environment to be loaded (REQUIRED)") 20 | parser.add_argument("--seed", type=int, default=1, 21 | help="random seed (default: 1)") 22 | parser.add_argument("--shift", type=int, default=None, 23 | help="number of times the environment is reset at the beginning (default: NUM_DEMOS)") 24 | parser.add_argument("--full-view", action="store_true", default=False, 25 | help="show the full environment view") 26 | args = parser.parse_args() 27 | 28 | class ImgWidget(QLabel): 29 | """ 30 | Widget to intercept clicks on the full image view 31 | """ 32 | def __init__(self, window): 33 | super().__init__() 34 | self.window = window 35 | 36 | class AIGameWindow(QMainWindow): 37 | """Application window for the baby AI game""" 38 | 39 | def __init__(self, env): 40 | super().__init__() 41 | self.initUI() 42 | 43 | # By default, manual stepping only 44 | self.fpsLimit = 0 45 | 46 | self.env = env 47 | self.lastObs = None 48 | 49 | # Demonstrations 50 | self.demos = utils.load_demos(args.env, "human") 51 | utils.synthesize_demos(self.demos) 52 | self.current_demo = [] 53 | 54 | self.shift = len(self.demos) if args.shift is None else args.shift 55 | 56 | self.shiftEnv() 57 | 58 | # Pointing and naming data 59 | self.pointingData = [] 60 | 61 | def initUI(self): 62 | """Create and connect the UI elements""" 63 | 64 | self.resize(512, 512) 65 | self.setWindowTitle('Baby AI Game') 66 | 67 | # Full render view (large view) 68 | self.imgLabel = ImgWidget(self) 69 | self.imgLabel.setFrameStyle(QFrame.Panel | QFrame.Sunken) 70 | leftBox = QVBoxLayout() 71 | leftBox.addStretch(1) 72 | leftBox.addWidget(self.imgLabel) 73 | leftBox.addStretch(1) 74 | 75 | # Area on the right of the large view 76 | rightBox = self.createRightArea() 77 | 78 | # Arrange widgets horizontally 79 | hbox = QHBoxLayout() 80 | hbox.addLayout(leftBox) 81 | hbox.addLayout(rightBox) 82 | 83 | # Create a main widget for the window 84 | mainWidget = QWidget(self) 85 | self.setCentralWidget(mainWidget) 86 | mainWidget.setLayout(hbox) 87 | 88 | # Show the application window 89 | self.show() 90 | self.setFocus() 91 | 92 | def createRightArea(self): 93 | # Agent render view (partially observable) 94 | self.obsImgLabel = QLabel() 95 | self.obsImgLabel.setFrameStyle(QFrame.Panel | QFrame.Sunken) 96 | miniViewBox = QHBoxLayout() 97 | miniViewBox.addStretch(1) 98 | miniViewBox.addWidget(self.obsImgLabel) 99 | miniViewBox.addStretch(1) 100 | 101 | self.missionBox = QTextEdit() 102 | self.missionBox.setMinimumSize(500, 100) 103 | 104 | buttonBox = self.createButtons() 105 | 106 | self.stepsLabel = QLabel() 107 | self.stepsLabel.setFrameStyle(QFrame.Panel | QFrame.Sunken) 108 | self.stepsLabel.setAlignment(Qt.AlignCenter) 109 | self.stepsLabel.setMinimumSize(60, 10) 110 | restartBtn = QPushButton("Restart") 111 | restartBtn.clicked.connect(self.shiftEnv) 112 | stepsBox = QHBoxLayout() 113 | stepsBox.addStretch(1) 114 | stepsBox.addWidget(QLabel("Steps remaining")) 115 | stepsBox.addWidget(self.stepsLabel) 116 | 
stepsBox.addWidget(restartBtn)
117 |         stepsBox.addStretch(1)
118 |         stepsBox.addStretch(1)
119 | 
120 |         hline2 = QFrame()
121 |         hline2.setFrameShape(QFrame.HLine)
122 |         hline2.setFrameShadow(QFrame.Sunken)
123 | 
124 |         # Stack everything up in a vertical layout
125 |         vbox = QVBoxLayout()
126 |         vbox.addLayout(miniViewBox)
127 |         vbox.addLayout(stepsBox)
128 |         vbox.addWidget(hline2)
129 |         vbox.addWidget(QLabel(""))
130 |         vbox.addWidget(self.missionBox)
131 |         vbox.addLayout(buttonBox)
132 | 
133 |         return vbox
134 | 
135 |     def createButtons(self):
136 |         """Create the row of UI buttons"""
137 | 
138 |         # Assemble the buttons into a horizontal layout
139 |         hbox = QHBoxLayout()
140 |         hbox.addStretch(1)
141 |         hbox.addStretch(1)
142 | 
143 |         return hbox
144 | 
145 |     def keyPressEvent(self, e):
146 |         # Manual agent control
147 |         actions = self.env.unwrapped.actions
148 | 
149 |         if e.key() == Qt.Key_Left:
150 |             self.stepEnv(actions.left)
151 |         elif e.key() == Qt.Key_Right:
152 |             self.stepEnv(actions.right)
153 |         elif e.key() == Qt.Key_Up:
154 |             self.stepEnv(actions.forward)
155 | 
156 |         elif e.key() == Qt.Key_PageUp:
157 |             self.stepEnv(actions.pickup)
158 |         elif e.key() == Qt.Key_PageDown:
159 |             self.stepEnv(actions.drop)
160 |         elif e.key() == Qt.Key_Space:
161 |             self.stepEnv(actions.toggle)
162 | 
163 |         elif e.key() == Qt.Key_Backspace:
164 |             self.shiftEnv()
165 |         elif e.key() == Qt.Key_Escape:
166 |             self.close()
167 | 
168 |     def mousePressEvent(self, event):
169 |         """
170 |         Clear the focus of the text boxes and buttons if somewhere
171 |         else on the window is clicked
172 |         """
173 | 
174 |         # Set the focus on the full render image
175 |         self.imgLabel.setFocus()
176 | 
177 |         QMainWindow.mousePressEvent(self, event)
178 | 
179 |     def shiftEnv(self):
180 |         assert self.shift <= len(self.demos)
181 | 
182 |         self.env.seed(args.seed)
183 |         self.resetEnv()
184 |         for _ in range(self.shift):
185 |             self.resetEnv()
186 | 
187 |     def resetEnv(self):
188 |         self.current_demo = []
189 | 
190 |         obs = self.env.reset()
191 |         self.lastObs = obs
192 |         self.showEnv(obs)
193 | 
194 |         self.missionBox.setText(obs["mission"])
195 | 
196 |     def showEnv(self, obs):
197 |         unwrapped = self.env.unwrapped
198 | 
199 |         # Render and display the environment
200 |         if args.full_view:
201 |             pixmap = self.env.render(mode='pixmap')
202 |             self.imgLabel.setPixmap(pixmap)
203 | 
204 |         # Render and display the agent's view
205 |         image = obs['image']
206 |         obsPixmap = unwrapped.get_obs_render(image)
207 |         self.obsImgLabel.setPixmap(obsPixmap)
208 | 
209 |         # Set the steps remaining
210 |         stepsRem = unwrapped.steps_remaining
211 |         self.stepsLabel.setText(str(stepsRem))
212 | 
213 |     def stepEnv(self, action=None):
214 |         # If no manual action was specified by the user
215 |         if action is None:
216 |             action = random.randint(0, self.env.action_space.n - 1)
217 |         action = int(action)
218 | 
219 |         obs, reward, done, info = self.env.step(action)
220 | 
221 |         self.current_demo.append((self.lastObs, action, reward, done))
222 | 
223 |         self.showEnv(obs)
224 |         self.lastObs = obs
225 | 
226 |         if done:
227 |             if reward > 0:  # i.e. we did not lose
228 |                 if self.shift < len(self.demos):
229 |                     self.demos[self.shift] = self.current_demo
230 |                 else:
231 |                     self.demos.append(self.current_demo)
232 |                 utils.save_demos(self.demos, args.env, "human")
233 |                 self.missionBox.append('Demonstrations are saved.')
234 |                 utils.synthesize_demos(self.demos)
235 | 
236 |                 self.shift += 1
237 |                 self.resetEnv()
238 |             else:
239 |                 self.shiftEnv()
240 | 
241 | def main(argv):
242 |     # Generate environment
243 |     env = gym.make(args.env)
244 | 
245 |     # Create the application window
246 |     app = QApplication(sys.argv)
247 |     window = AIGameWindow(env)
248 | 
249 |     # Run the application
250 |     sys.exit(app.exec_())
251 | 
252 | if __name__ == '__main__':
253 |     main(sys.argv)
254 | 
--------------------------------------------------------------------------------
/scripts/rl_zforcing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.transforms as transforms
5 | from torch.nn import Parameter
6 | from torch.autograd import Variable
7 | import torchvision.datasets as dsets
8 | import time
9 | import click
10 | import numpy
11 | import numpy as np
12 | import os
13 | import random
14 | from itertools import chain
15 | import torch.nn.functional as F
16 | from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
17 | import math
18 | 
19 | 
20 | class View(nn.Module):
21 |     def __init__(self, *args):
22 |         super(View, self).__init__()
23 |         self._shape = args
24 |     def forward(self, x):
25 |         return x.view(*self._shape)
26 | 
27 | def log_prob_gaussian(x, mu, log_vars, mean=False):
28 |     lp = - 0.5 * math.log(2 * math.pi) \
29 |          - log_vars / 2 - (x - mu) ** 2 / (2 * torch.exp(log_vars))
30 |     if mean:
31 |         return torch.mean(lp, -1)
32 |     return torch.sum(lp, -1)
33 | 
34 | def log_prob_bernoulli(x, mu):
35 |     lp = x * torch.log(mu + 1e-5) + (1. - x) * torch.log(1. - mu + 1e-5)
36 |     return lp
37 | 
38 | 
39 | def gaussian_kld(mu_left, logvar_left, mu_right, logvar_right):
40 |     """
41 |     Compute KL divergence between a bunch of univariate Gaussian distributions
42 |     with the given means and log-variances.
43 |     We do KL(N(mu_left, logvar_left) || N(mu_right, logvar_right)).
44 |     """
45 |     gauss_klds = 0.5 * (logvar_right - logvar_left +
46 |                         (torch.exp(logvar_left) / torch.exp(logvar_right)) +
47 |                         ((mu_left - mu_right)**2.0 / torch.exp(logvar_right)) - 1.0)
48 |     assert len(gauss_klds.size()) == 2
49 |     return torch.sum(gauss_klds, 1)
50 | 
51 | 
52 | class LayerNorm(nn.Module):
53 |     def __init__(self, nb_features, eps=1e-5):
54 |         super(LayerNorm, self).__init__()
55 |         self.eps = eps
56 |         self.gain = nn.Parameter(torch.ones(1, nb_features))
57 |         self.bias = nn.Parameter(torch.zeros(1, nb_features))
58 | 
59 |     def forward(self, x, gain=None, bias=None):
60 |         assert len(x.size()) == 2
61 |         if gain is None:
62 |             gain = self.gain
63 |         if bias is None:
64 |             bias = self.bias
65 |         mean = torch.mean(x, dim=-1, keepdim=True)
66 |         std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + self.eps)
67 |         z = (x - mean.expand_as(x)) / std.expand_as(x)
68 |         return z * gain.expand_as(z) + bias.expand_as(z)
69 | 
70 | 
71 | class LSTMCell(nn.Module):
72 |     """A basic LSTM cell."""
73 | 
74 |     def __init__(self, input_size, hidden_size, use_layernorm=False):
75 |         """
76 |         Most parts are copied from torch.nn.LSTMCell.
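        With use_layernorm=True, the bias terms are dropped and the input/hidden
        pre-activations are layer-normalized instead (ln_ih/ln_hh below), so that
        externally supplied gains and biases can modulate the cell per step.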
77 | """ 78 | 79 | super(LSTMCell, self).__init__() 80 | self.input_size = input_size 81 | self.hidden_size = hidden_size 82 | self.use_layernorm = use_layernorm 83 | self.has_bias = not self.use_layernorm 84 | if self.use_layernorm: 85 | self.use_bias = False 86 | print("LSTMCell: use_layernorm=%s" % use_layernorm) 87 | self.weight_ih = nn.Parameter( 88 | torch.FloatTensor(input_size, 4 * hidden_size)) 89 | self.weight_hh = nn.Parameter( 90 | torch.FloatTensor(hidden_size, 4 * hidden_size)) 91 | if self.use_layernorm: 92 | self.ln_ih = LayerNorm(4 * hidden_size) 93 | self.ln_hh = LayerNorm(4 * hidden_size) 94 | else: 95 | self.bias_ih = Parameter(torch.FloatTensor(4 * hidden_size)) 96 | self.bias_hh = Parameter(torch.FloatTensor(4 * hidden_size)) 97 | self.init_weights() 98 | 99 | def init_weights(self): 100 | """ 101 | Initialize parameters following the way proposed in the paper. 102 | """ 103 | stdv = 1.0 / np.sqrt(self.hidden_size) 104 | self.weight_ih.data.uniform_(-stdv, stdv) 105 | nn.init.orthogonal(self.weight_hh.data) 106 | if self.has_bias: 107 | self.bias_ih.data.fill_(0) 108 | self.bias_hh.data.fill_(0) 109 | 110 | def forward(self, input_, hx, 111 | gain_ih=None, gain_hh=None, 112 | bias_ih=None, bias_hh=None): 113 | """ 114 | Args: 115 | input_: A (batch, input_size) tensor containing input 116 | features. 117 | hx: A tuple (h_0, c_0), which contains the initial hidden 118 | and cell state, where the size of both states is 119 | (batch, hidden_size). 120 | Returns: 121 | h_1, c_1: Tensors containing the next hidden and cell state. 122 | """ 123 | assert input_.is_cuda 124 | h_0, c_0 = hx 125 | igates = torch.mm(input_, self.weight_ih) 126 | hgates = torch.mm(h_0, self.weight_hh) 127 | state = fusedBackend.LSTMFused() 128 | if self.use_layernorm: 129 | igates = self.ln_ih(igates, gain=gain_ih, bias=bias_ih) 130 | hgates = self.ln_hh(hgates, gain=gain_hh, bias=bias_hh) 131 | return state.apply(igates, hgates, c_0) 132 | else: 133 | return state.apply(igates, hgates, c_0, 134 | self.bias_ih, self.bias_hh) 135 | 136 | def __repr__(self): 137 | s = '{name}({input_size}, {hidden_size})' 138 | return s.format(name=self.__class__.__name__, **self.__dict__) 139 | 140 | 141 | class LReLU(nn.Module): 142 | def __init__(self, c=1./3): 143 | super(LReLU, self).__init__() 144 | self.c = c 145 | 146 | def forward(self, x): 147 | return torch.clamp(F.leaky_relu(x, self.c), -3., 3.) 148 | 149 | 150 | class ZForcing(nn.Module): 151 | def __init__(self, emb_dim, rnn_dim, 152 | z_dim, mlp_dim, out_dim, bwd_out_dim=None, out_type="softmax", bwd_out_type = 'softmax', 153 | cond_ln=False, nlayers=1, z_force=False, dropout=0., 154 | use_l2=False, drop_grad=False, return_loss=False): 155 | super(ZForcing, self).__init__() 156 | assert not drop_grad, "drop_grad is not supported!" 
157 | self.emb_dim = emb_dim 158 | self.out_dim = out_dim 159 | self.rnn_dim = rnn_dim 160 | self.nlayers = nlayers 161 | self.z_dim = z_dim 162 | self.return_loss = return_loss 163 | self.dropout = dropout 164 | self.out_type = out_type 165 | self.bwd_out_type = bwd_out_type 166 | self.mlp_dim = mlp_dim 167 | self.cond_ln = cond_ln 168 | self.z_force = z_force 169 | self.use_l2 = use_l2 170 | self.drop_grad = drop_grad 171 | if bwd_out_dim is None: 172 | self.bwd_out_dim = self.out_dim 173 | else: 174 | self.bwd_out_dim = bwd_out_dim 175 | 176 | self.fwd_emb_mod = nn.Sequential( 177 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1), 178 | nn.LeakyReLU(), 179 | nn.MaxPool2d(kernel_size=3, stride=2), 180 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 181 | nn.LeakyReLU(), 182 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 183 | nn.LeakyReLU(), 184 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1), 185 | nn.LeakyReLU(), 186 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 187 | nn.LeakyReLU(), 188 | #nn.AvgPool2d(4), 189 | View(-1, 256), 190 | nn.Linear(256, emb_dim), 191 | nn.Dropout(dropout)) 192 | 193 | self.bwd_emb_mod = nn.Sequential( 194 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1), 195 | nn.LeakyReLU(), 196 | nn.MaxPool2d(kernel_size=3, stride=2), 197 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 198 | nn.LeakyReLU(), 199 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 200 | nn.LeakyReLU(), 201 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1), 202 | nn.LeakyReLU(), 203 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 204 | nn.LeakyReLU(), 205 | #nn.AvgPool2d(4), 206 | View(-1, 256), 207 | nn.Linear(256, emb_dim), 208 | nn.Dropout(dropout)) 209 | self.bwd_mod = nn.LSTM(emb_dim, rnn_dim, nlayers) 210 | nn.init.orthogonal(self.bwd_mod.weight_hh_l0.data) 211 | self.fwd_mod = LSTMCell( 212 | emb_dim if cond_ln else emb_dim + mlp_dim, 213 | rnn_dim, use_layernorm=cond_ln) 214 | self.pri_mod = nn.Sequential( 215 | nn.Linear(rnn_dim, mlp_dim), 216 | LReLU(), 217 | nn.Linear(mlp_dim, z_dim * 2)) 218 | self.inf_mod = nn.Sequential( 219 | nn.Linear(rnn_dim * 2, mlp_dim), 220 | LReLU(), 221 | nn.Linear(mlp_dim, z_dim * 2)) 222 | if cond_ln: 223 | self.gen_mod = nn.Sequential( 224 | nn.Linear(z_dim, mlp_dim), 225 | LReLU(), 226 | nn.Linear(mlp_dim, 8 * rnn_dim)) 227 | else: 228 | self.gen_mod = nn.Linear(z_dim, mlp_dim) 229 | self.aux_mod = nn.Sequential( 230 | nn.Linear(z_dim + rnn_dim, mlp_dim), 231 | LReLU(), 232 | nn.Linear(mlp_dim, 2 * rnn_dim)) 233 | self.fwd_out_mod = nn.Linear(rnn_dim, out_dim) 234 | self.bwd_out_mod = nn.Linear(rnn_dim, self.bwd_out_dim) 235 | 236 | def save(self, filename): 237 | state = { 238 | 'emb_dim': self.emb_dim, 239 | 'rnn_dim': self.rnn_dim, 240 | 'nlayers': self.nlayers, 241 | 'mlp_dim': self.mlp_dim, 242 | 'out_dim': self.out_dim, 243 | 'bwd_out_dim': self.bwd_out_dim, 244 | 'out_type': self.out_type, 245 | 'cond_ln': self.cond_ln, 246 | 'z_force': self.z_force, 247 | 'use_l2': self.use_l2, 248 | 'z_dim': self.z_dim, 249 | 'dropout': self.dropout, 250 | 'drop_grad': self.drop_grad, 251 | 'state_dict': self.state_dict() 252 | } 253 | torch.save(state, filename) 254 | 255 | @classmethod 256 | def load(cls, filename): 257 | state 
= torch.load(filename)
258 |         model = ZForcing(
259 |             state['emb_dim'], state['rnn_dim'],
260 |             state['z_dim'], state['mlp_dim'], state['out_dim'],
261 |             bwd_out_dim=state.get('bwd_out_dim'), nlayers=state['nlayers'],
262 |             cond_ln=state['cond_ln'], out_type=state['out_type'], z_force=state['z_force'],
263 |             use_l2=state.get('use_l2', False), drop_grad=state.get('drop_grad', False))
264 |         model.load_state_dict(state['state_dict'])
265 |         return model
266 | 
267 |     def reparametrize(self, mu, logvar, eps=None):
268 |         std = logvar.mul(0.5).exp_()
269 |         if eps is None:
270 |             eps = Variable(std.data.new(std.size()).normal_())
271 |         return eps.mul(std).add_(mu)
272 | 
273 |     def init_hidden(self, bsz):
274 |         weight = next(self.parameters()).data
275 |         return (Variable(weight.new(self.nlayers, bsz, self.rnn_dim).zero_()),
276 |                 Variable(weight.new(self.nlayers, bsz, self.rnn_dim).zero_()))
277 | 
278 |     def fwd_pass(self, x_fwd, hidden, bwd_states=None, z_step=None):
279 |         x_fwd_reshape = x_fwd.view(-1, *x_fwd.shape[2:])
280 |         x_emb = self.fwd_emb_mod(x_fwd_reshape)
281 |         x_fwd = x_emb.view(*x_fwd.shape[:2], self.emb_dim)
282 |         nsteps = x_fwd.size(0)
283 |         states = [(hidden[0][0], hidden[1][0])]
284 |         klds, zs, log_pz, log_qz, aux_cs = [], [], [], [], []
285 |         eps = Variable(next(self.parameters()).data.new(
286 |             nsteps, x_fwd.size(1), self.z_dim).normal_())
287 |         big = Variable(next(self.parameters()).data.new(x_fwd.size(1)).zero_()) + 0.5
288 |         big = torch.bernoulli(big).unsqueeze(1)
289 | 
290 |         assert (z_step is None) or (nsteps == 1)
291 |         for step in range(nsteps):
292 |             states_step = states[step]
293 |             x_step = x_fwd[step]
294 |             h_step, c_step = states_step[0], states_step[1]
295 |             r_step = eps[step]
296 | 
297 |             pri_params = self.pri_mod(h_step)
298 |             pri_params = torch.clamp(pri_params, -8., 8.)
299 |             pri_mu, pri_logvar = torch.chunk(pri_params, 2, 1)
300 | 
301 |             # inference phase
302 |             if bwd_states is not None:
303 |                 b_step = bwd_states[step]
304 |                 inf_params = self.inf_mod(torch.cat((h_step, b_step), 1))
305 |                 inf_params = torch.clamp(inf_params, -8., 8.)
306 |                 inf_mu, inf_logvar = torch.chunk(inf_params, 2, 1)
307 |                 kld = gaussian_kld(inf_mu, inf_logvar, pri_mu, pri_logvar)
308 |                 z_step = self.reparametrize(inf_mu, inf_logvar, eps=r_step)
309 |                 if self.z_force:
310 |                     h_step_ = h_step * 0.
311 |                 else:
312 |                     h_step_ = h_step
313 |                 aux_params = self.aux_mod(torch.cat((h_step_, z_step), 1))
314 |                 aux_params = torch.clamp(aux_params, -8., 8.)
315 |                 aux_mu, aux_logvar = torch.chunk(aux_params, 2, 1)
316 |                 # disconnect gradient here
317 |                 b_step_ = b_step.detach()
318 |                 if self.use_l2:
319 |                     aux_step = torch.sum((b_step_ - F.tanh(aux_mu)) ** 2.0, 1)
320 |                 else:
321 |                     aux_step = -log_prob_gaussian(
322 |                         b_step_, F.tanh(aux_mu), aux_logvar, mean=False)
323 |             # generation phase
324 |             else:
325 |                 # sample from the prior
326 |                 if z_step is None:
327 |                     z_step = self.reparametrize(pri_mu, pri_logvar, eps=r_step)
328 |                 aux_step = torch.sum(pri_mu * 0., -1)
329 |                 inf_mu, inf_logvar = pri_mu, pri_logvar
330 |                 kld = aux_step
331 | 
332 |             i_step = self.gen_mod(z_step)
333 |             if self.cond_ln:
334 |                 i_step = torch.clamp(i_step, -3, 3)
335 |                 gain_hh, bias_hh = torch.chunk(i_step, 2, 1)
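                # With cond_ln, the latent z modulates the forward LSTM through
                # layer-norm gains (shifted around 1) and biases, rather than
                # being concatenated to the input embedding.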
336 |                 gain_hh = 1. + gain_hh
337 |                 h_new, c_new = self.fwd_mod(x_step, (h_step, c_step),
338 |                                             gain_hh=gain_hh, bias_hh=bias_hh)
339 |             else:
340 |                 h_new, c_new = self.fwd_mod(torch.cat((i_step, x_step), 1),
341 |                                             (h_step, c_step))
342 |             states.append((h_new, c_new))
343 |             klds.append(kld)
344 |             zs.append(z_step)
345 |             aux_cs.append(aux_step)
346 |             log_pz.append(log_prob_gaussian(z_step, pri_mu, pri_logvar))
347 |             log_qz.append(log_prob_gaussian(z_step, inf_mu, inf_logvar))
348 | 
349 |         klds = torch.stack(klds, 0)
350 |         aux_cs = torch.stack(aux_cs, 0)
351 |         log_pz = torch.stack(log_pz, 0)
352 |         log_qz = torch.stack(log_qz, 0)
353 |         zs = torch.stack(zs, 0)
354 | 
355 |         outputs = [s[0] for s in states[1:]]
356 |         outputs = torch.stack(outputs, 0)
357 |         outputs = self.fwd_out_mod(outputs)
358 |         return outputs, states[1:], klds, aux_cs, zs, log_pz, log_qz
359 | 
360 |     def infer(self, x, hidden):
361 |         '''Infer latent variables for a given batch of sentences ``x''.
362 |         '''
363 |         x_ = x[:-1]
364 |         y_ = x[1:]
365 |         bwd_states, bwd_outputs = self.bwd_pass(x_, hidden)
366 |         fwd_outputs, fwd_states, klds, aux_nll, zs, log_pz, log_qz = self.fwd_pass(
367 |             x_, hidden, bwd_states=bwd_states)
368 |         return zs
369 | 
370 |     def bwd_pass(self, x, hidden):
371 |         idx = np.arange(x.size(0))[::-1].tolist()
372 |         idx = torch.LongTensor(idx)
373 |         idx = Variable(idx).cuda()
374 | 
375 |         # invert the targets and revert back
376 |         x_bwd = x.index_select(0, idx)
377 |         # x_bwd = torch.cat([x_bwd, x[:1]], 0)
378 |         x_bwd_reshape = x_bwd.view(-1, *x_bwd.shape[2:])
379 | 
380 |         x_emb = self.bwd_emb_mod(x_bwd_reshape)
381 |         x_bwd = x_emb.view(*x_bwd.shape[:2], self.emb_dim)
382 |         states, _ = self.bwd_mod(x_bwd, hidden)
383 |         outputs = self.bwd_out_mod(states)
384 |         states = states.index_select(0, idx)
385 |         outputs = outputs.index_select(0, idx)
386 |         return states, outputs
387 | 
388 |     def generate_onestep(self, x_fwd, x_mask, hidden):
389 |         nsteps, nbatch = x_fwd.size(0), x_fwd.size(1)
390 |         #bwd_states, bwd_outputs = self.bwd_pass(x_bwd, hidden)
391 |         fwd_outputs, fwd_states, klds, aux_nll, zs, log_pz, log_qz = self.fwd_pass(
392 |             x_fwd, hidden)
393 |         output_prob = F.softmax(fwd_outputs.squeeze(0))
394 |         sampled_output = torch.multinomial(output_prob, 1)
395 |         hidden = (fwd_states[0][0].unsqueeze(0), fwd_states[0][1].unsqueeze(0))
396 |         if self.return_loss:
397 |             return (sampled_output, hidden, aux_nll)
398 |         else:
399 |             return (sampled_output, hidden)
400 | 
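At generation time only the forward model is used, one frame at a time. A minimal sampling sketch (the hyperparameters and the 32x32 frame size are illustrative — the conv stack above reduces 32x32 inputs to the 256-dim feature its linear layer expects — and a CUDA device is assumed, since the fused LSTM cell requires it):

# Hypothetical usage sketch; not part of the original script.
import torch
from torch.autograd import Variable

model = ZForcing(emb_dim=256, rnn_dim=512, z_dim=64, mlp_dim=256, out_dim=7).cuda()
hidden = model.init_hidden(1)                             # batch of one trajectory
frame = Variable(torch.zeros(1, 1, 3, 32, 32).cuda())     # (nsteps=1, batch, C, H, W)
action, hidden = model.generate_onestep(frame, None, hidden)  # x_mask is unused in this path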
429 | 
430 | 
431 |     def forward(self, x_fwd, x_bwd, y, x_mask, hidden, y_bwd=None, return_stats=False, return_kld=False):
432 |         if y_bwd is None:
433 |             y_bwd = y
434 |         nsteps, nbatch = x_fwd.size(0), x_fwd.size(1)
435 |         bwd_states, bwd_outputs = self.bwd_pass(x_bwd, hidden)
436 |         fwd_outputs, fwd_states, klds, aux_nll, zs, log_pz, log_qz = self.fwd_pass(
437 |             x_fwd, hidden, bwd_states=bwd_states)
438 |         kld = (klds * x_mask).sum(0)
439 |         log_pz = (log_pz * x_mask).sum(0)
440 |         log_qz = (log_qz * x_mask).sum(0)
441 |         aux_nll = (aux_nll * x_mask).sum(0)
442 |         if self.out_type == 'gaussian':
443 |             out_mu, out_logvar = torch.chunk(fwd_outputs, 2, -1)
444 |             fwd_nll = -log_prob_gaussian(y, out_mu, out_logvar)
445 |             fwd_nll = (fwd_nll * x_mask).sum(0)
446 |         elif self.out_type == 'softmax':
447 |             fwd_out = fwd_outputs.view(nsteps * nbatch, self.out_dim)
448 |             fwd_out = F.log_softmax(fwd_out)
449 |             y = y.view(-1, 1)
450 |             fwd_nll = torch.gather(fwd_out, 1, y.long()).squeeze(1)
451 |             fwd_nll = fwd_nll.view(nsteps, nbatch)
452 |             fwd_nll = -(fwd_nll * x_mask).sum(0)
453 |         if self.bwd_out_type == 'softmax':
454 |             bwd_out = bwd_outputs.view(nsteps * nbatch, self.bwd_out_dim)
455 |             bwd_out = F.log_softmax(bwd_out)
456 |             y_bwd = y_bwd.view(-1, 1)
457 |             bwd_nll = torch.gather(bwd_out, 1, y_bwd.long()).squeeze(1)
458 |             bwd_nll = -bwd_nll.view(nsteps, nbatch)
459 |             bwd_nll = (bwd_nll * x_mask).sum(0)
460 |         elif self.bwd_out_type == 'gaussian':
461 |             out_mu, out_logvar = torch.chunk(bwd_outputs, 2, -1)
462 |             bwd_nll = -log_prob_gaussian(y_bwd, out_mu, out_logvar)
463 |             bwd_nll = (bwd_nll * x_mask).sum(0)
464 |         if return_kld:
465 |             return fwd_nll.mean(), bwd_nll.mean(), aux_nll.mean(), klds
466 |         if return_stats:
467 |             return fwd_nll, bwd_nll, aux_nll, kld, log_pz, log_qz
468 |         return fwd_nll.mean(), bwd_nll.mean(), aux_nll.mean(), kld.mean()
469 | 
--------------------------------------------------------------------------------
/scripts/rl_zforcing_dec.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.transforms as transforms
5 | from torch.nn import Parameter
6 | from torch.autograd import Variable
7 | import torchvision.datasets as dsets
8 | import time
9 | import click
10 | import numpy
11 | import numpy as np
12 | import os
13 | import random
14 | from itertools import chain
15 | import torch.nn.functional as F
16 | from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
17 | import math
18 | 
19 | 
20 | class View(nn.Module):
21 |     def __init__(self, *args):
22 |         super(View, self).__init__()
23 |         self._shape = args
24 |     def forward(self, x):
25 |         return x.view(*self._shape)
26 | 
27 | def log_prob_gaussian(x, mu, log_vars, mean=False):
28 |     lp = - 0.5 * math.log(2 * math.pi) \
29 |          - log_vars / 2 - (x - mu) ** 2 / (2 * torch.exp(log_vars))
30 |     if mean:
31 |         return torch.mean(lp, -1)
32 |     return torch.sum(lp, -1)
33 | 
34 | def log_prob_bernoulli(x, mu):
35 |     lp = x * torch.log(mu + 1e-5) + (1. - x) * torch.log(1. - mu + 1e-5)
36 |     return lp
37 | 
38 | 
39 | def gaussian_kld(mu_left, logvar_left, mu_right, logvar_right):
40 |     """
41 |     Compute KL divergence between a bunch of univariate Gaussian distributions
42 |     with the given means and log-variances.
43 | We do KL(N(mu_left, logvar_left) || N(mu_right, logvar_right)). 44 | """ 45 | gauss_klds = 0.5 * (logvar_right - logvar_left + 46 | (torch.exp(logvar_left) / torch.exp(logvar_right)) + 47 | ((mu_left - mu_right)**2.0 / torch.exp(logvar_right)) - 1.0) 48 | assert len(gauss_klds.size()) == 2 49 | return torch.sum(gauss_klds, 1) 50 | 51 | 52 | class LayerNorm(nn.Module): 53 | def __init__(self, nb_features, eps=1e-5): 54 | super(LayerNorm, self).__init__() 55 | self.eps = eps 56 | self.gain = nn.Parameter(torch.ones(1, nb_features)) 57 | self.bias = nn.Parameter(torch.zeros(1, nb_features)) 58 | 59 | def forward(self, x, gain=None, bias=None): 60 | assert len(x.size()) == 2 61 | if gain is None: 62 | gain = self.gain 63 | if bias is None: 64 | bias = self.bias 65 | mean = torch.mean(x, dim=-1, keepdim=True) 66 | std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + self.eps) 67 | z = (x - mean.expand_as(x)) / std.expand_as(x) 68 | return z * gain.expand_as(z) + bias.expand_as(z) 69 | 70 | 71 | class LSTMCell(nn.Module): 72 | """A basic LSTM cell.""" 73 | 74 | def __init__(self, input_size, hidden_size, use_layernorm=False): 75 | """ 76 | Most parts are copied from torch.nn.LSTMCell. 77 | """ 78 | 79 | super(LSTMCell, self).__init__() 80 | self.input_size = input_size 81 | self.hidden_size = hidden_size 82 | self.use_layernorm = use_layernorm 83 | self.has_bias = not self.use_layernorm 84 | if self.use_layernorm: 85 | self.use_bias = False 86 | print("LSTMCell: use_layernorm=%s" % use_layernorm) 87 | self.weight_ih = nn.Parameter( 88 | torch.FloatTensor(input_size, 4 * hidden_size)) 89 | self.weight_hh = nn.Parameter( 90 | torch.FloatTensor(hidden_size, 4 * hidden_size)) 91 | if self.use_layernorm: 92 | self.ln_ih = LayerNorm(4 * hidden_size) 93 | self.ln_hh = LayerNorm(4 * hidden_size) 94 | else: 95 | self.bias_ih = Parameter(torch.FloatTensor(4 * hidden_size)) 96 | self.bias_hh = Parameter(torch.FloatTensor(4 * hidden_size)) 97 | self.init_weights() 98 | 99 | def init_weights(self): 100 | """ 101 | Initialize parameters following the way proposed in the paper. 102 | """ 103 | stdv = 1.0 / np.sqrt(self.hidden_size) 104 | self.weight_ih.data.uniform_(-stdv, stdv) 105 | nn.init.orthogonal(self.weight_hh.data) 106 | if self.has_bias: 107 | self.bias_ih.data.fill_(0) 108 | self.bias_hh.data.fill_(0) 109 | 110 | def forward(self, input_, hx, 111 | gain_ih=None, gain_hh=None, 112 | bias_ih=None, bias_hh=None): 113 | """ 114 | Args: 115 | input_: A (batch, input_size) tensor containing input 116 | features. 117 | hx: A tuple (h_0, c_0), which contains the initial hidden 118 | and cell state, where the size of both states is 119 | (batch, hidden_size). 120 | Returns: 121 | h_1, c_1: Tensors containing the next hidden and cell state. 
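        Note: the fused pointwise kernel used below is CUDA-only, hence the
        assert on input_.is_cuda.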
122 | """ 123 | assert input_.is_cuda 124 | h_0, c_0 = hx 125 | igates = torch.mm(input_, self.weight_ih) 126 | hgates = torch.mm(h_0, self.weight_hh) 127 | state = fusedBackend.LSTMFused() 128 | if self.use_layernorm: 129 | igates = self.ln_ih(igates, gain=gain_ih, bias=bias_ih) 130 | hgates = self.ln_hh(hgates, gain=gain_hh, bias=bias_hh) 131 | return state.apply(igates, hgates, c_0) 132 | else: 133 | return state.apply(igates, hgates, c_0, 134 | self.bias_ih, self.bias_hh) 135 | 136 | def __repr__(self): 137 | s = '{name}({input_size}, {hidden_size})' 138 | return s.format(name=self.__class__.__name__, **self.__dict__) 139 | 140 | 141 | class LReLU(nn.Module): 142 | def __init__(self, c=1./3): 143 | super(LReLU, self).__init__() 144 | self.c = c 145 | 146 | def forward(self, x): 147 | return torch.clamp(F.leaky_relu(x, self.c), -3., 3.) 148 | 149 | 150 | class ZForcing(nn.Module): 151 | def __init__(self, emb_dim, rnn_dim, 152 | z_dim, mlp_dim, out_dim, bwd_out_dim=None, out_type="softmax", bwd_out_type = 'softmax', 153 | cond_ln=False, nlayers=1, z_force=False, dropout=0., 154 | use_l2=False, drop_grad=False, return_loss=False): 155 | super(ZForcing, self).__init__() 156 | assert not drop_grad, "drop_grad is not supported!" 157 | self.emb_dim = emb_dim 158 | self.out_dim = out_dim 159 | self.rnn_dim = rnn_dim 160 | self.nlayers = nlayers 161 | self.z_dim = z_dim 162 | self.return_loss = return_loss 163 | self.dropout = dropout 164 | self.out_type = out_type 165 | self.bwd_out_type = bwd_out_type 166 | self.mlp_dim = mlp_dim 167 | self.cond_ln = cond_ln 168 | self.z_force = z_force 169 | self.use_l2 = use_l2 170 | self.drop_grad = drop_grad 171 | if bwd_out_dim is None: 172 | self.bwd_out_dim = self.out_dim 173 | else: 174 | self.bwd_out_dim = bwd_out_dim 175 | 176 | '''self.fwd_emb_mod = nn.Sequential( 177 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1), 178 | nn.LeakyReLU(), 179 | nn.MaxPool2d(kernel_size=3, stride=2), 180 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 181 | nn.LeakyReLU(), 182 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 183 | nn.LeakyReLU(), 184 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1), 185 | nn.LeakyReLU(), 186 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 187 | nn.LeakyReLU(), 188 | #nn.AvgPool2d(4), 189 | View(-1, 256), 190 | nn.Linear(256, emb_dim), 191 | nn.Dropout(dropout)) 192 | ''' 193 | self.fwd_emb_mod = nn.Sequential( 194 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1), 195 | nn.LeakyReLU(), 196 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3,stride=2, padding=1), 197 | nn.LeakyReLU(), 198 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 199 | nn.LeakyReLU()) 200 | self.linear1 = nn.Linear(256, emb_dim) 201 | 202 | '''self.bwd_emb_mod = nn.Sequential( 203 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1), 204 | nn.LeakyReLU(), 205 | nn.MaxPool2d(kernel_size=3, stride=2), 206 | nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 207 | nn.LeakyReLU(), 208 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 209 | nn.LeakyReLU(), 210 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1), 211 | nn.LeakyReLU(), 212 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 213 | 
nn.LeakyReLU(),
214 |         #nn.AvgPool2d(4),
215 |         View(-1, 256),
216 |         nn.Linear(256, emb_dim),
217 |         nn.Dropout(dropout))'''
218 | 
219 |         self.bwd_emb_mod = nn.Sequential(
220 |             nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
221 |             nn.LeakyReLU(),
222 |             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
223 |             nn.LeakyReLU(),
224 |             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
225 |             nn.LeakyReLU(),
226 |             View(-1, 256),
227 |             nn.Linear(256, emb_dim))
228 | 
229 | 
230 |         self.dec_linear = nn.Linear(rnn_dim, 256)
231 |         self.l2_loss = nn.MSELoss()
232 |         self.l1_loss = nn.L1Loss()
233 |         self.ce_loss = nn.CrossEntropyLoss()
234 |         self.bwd_dec_mod = nn.Sequential(
235 |             nn.ConvTranspose2d(in_channels=64, out_channels=16, kernel_size=3, stride=1, padding=1),
236 |             nn.LeakyReLU(),
237 |             nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3, stride=2, padding=1),
238 |             nn.LeakyReLU(),
239 |             nn.ConvTranspose2d(in_channels=8, out_channels=3, kernel_size=5, stride=2, padding=1),
240 |             nn.Sigmoid()
241 | 
242 |         )
243 |         self.softmax = torch.nn.Softmax()
244 |         self.bwd_mod = nn.LSTM(emb_dim, rnn_dim, nlayers)
245 |         nn.init.orthogonal(self.bwd_mod.weight_hh_l0.data)
246 |         self.fwd_mod = LSTMCell(
247 |             emb_dim if cond_ln else emb_dim + mlp_dim,
248 |             rnn_dim, use_layernorm=cond_ln)
249 |         self.pri_mod = nn.Sequential(
250 |             nn.Linear(rnn_dim, mlp_dim),
251 |             LReLU(),
252 |             nn.Linear(mlp_dim, z_dim * 2))
253 |         self.inf_mod = nn.Sequential(
254 |             nn.Linear(rnn_dim * 2, mlp_dim),
255 |             LReLU(),
256 |             nn.Linear(mlp_dim, z_dim * 2))
257 |         if cond_ln:
258 |             self.gen_mod = nn.Sequential(
259 |                 nn.Linear(z_dim, mlp_dim),
260 |                 LReLU(),
261 |                 nn.Linear(mlp_dim, 8 * rnn_dim))
262 |         else:
263 |             self.gen_mod = nn.Linear(z_dim, mlp_dim)
264 |         self.aux_mod = nn.Sequential(
265 |             nn.Linear(z_dim + rnn_dim, mlp_dim),
266 |             LReLU(),
267 |             nn.Linear(mlp_dim, 2 * rnn_dim))
268 |         self.fwd_out_mod = nn.Linear(rnn_dim, out_dim)
269 |         self.bwd_out_mod = nn.Linear(rnn_dim, self.bwd_out_dim)
270 | 
271 |     def save(self, filename):
272 |         state = {
273 |             'emb_dim': self.emb_dim,
274 |             'rnn_dim': self.rnn_dim,
275 |             'nlayers': self.nlayers,
276 |             'mlp_dim': self.mlp_dim,
277 |             'out_dim': self.out_dim,
278 |             'bwd_out_dim': self.bwd_out_dim,
279 |             'out_type': self.out_type,
280 |             'cond_ln': self.cond_ln,
281 |             'z_force': self.z_force,
282 |             'use_l2': self.use_l2,
283 |             'z_dim': self.z_dim,
284 |             'dropout': self.dropout,
285 |             'drop_grad': self.drop_grad,
286 |             'state_dict': self.state_dict()
287 |         }
288 |         torch.save(state, filename)
289 | 
290 |     @classmethod
291 |     def load(cls, filename):
292 |         state = torch.load(filename)
293 |         model = ZForcing(
294 |             state['emb_dim'], state['rnn_dim'],
295 |             state['z_dim'], state['mlp_dim'], state['out_dim'],
296 |             bwd_out_dim=state.get('bwd_out_dim'), nlayers=state['nlayers'],
297 |             cond_ln=state['cond_ln'], out_type=state['out_type'], z_force=state['z_force'],
298 |             use_l2=state.get('use_l2', False), drop_grad=state.get('drop_grad', False))
299 |         model.load_state_dict(state['state_dict'])
300 |         return model
301 | 
302 |     def reparametrize(self, mu, logvar, eps=None):
303 |         std = logvar.mul(0.5).exp_()
304 |         if eps is None:
305 |             eps = Variable(std.data.new(std.size()).normal_())
306 |         return eps.mul(std).add_(mu)
307 | 
308 |     def init_hidden(self, bsz):
309 |         weight = next(self.parameters()).data
310 |         return (Variable(weight.new(self.nlayers, bsz, self.rnn_dim).zero_()),
311 |                 Variable(weight.new(self.nlayers,
bsz, self.rnn_dim).zero_())) 312 | 313 | def fwd_pass(self, x_fwd, hidden, bwd_states=None, z_step=None): 314 | x_fwd_reshape = x_fwd.view(-1, *x_fwd.shape[2:]) 315 | x_emb = self.fwd_emb_mod(x_fwd_reshape).view(-1, 256) 316 | x_emb = self.linear1(x_emb) 317 | 318 | x_fwd = x_emb.view(*x_fwd.shape[:2], self.emb_dim) 319 | 320 | nsteps = x_fwd.size(0) 321 | states = [(hidden[0][0], hidden[1][0])] 322 | klds, zs, log_pz, log_qz, aux_cs = [], [], [], [], [] 323 | eps = Variable(next(self.parameters()).data.new( 324 | nsteps, x_fwd.size(1), self.z_dim).normal_()) 325 | big = Variable(next(self.parameters()).data.new(x_fwd.size(1)).zero_()) + 0.5 326 | big = torch.bernoulli(big).unsqueeze(1) 327 | 328 | assert (z_step is None) or (nsteps == 1) 329 | for step in range(nsteps): 330 | states_step = states[step] 331 | x_step = x_fwd[step] 332 | h_step, c_step = states_step[0], states_step[1] 333 | r_step = eps[step] 334 | 335 | pri_params = self.pri_mod(h_step) 336 | pri_params = torch.clamp(pri_params, -8., 8.) 337 | pri_mu, pri_logvar = torch.chunk(pri_params, 2, 1) 338 | 339 | # inference phase 340 | if bwd_states is not None: 341 | b_step = bwd_states[step] 342 | b_step = b_step.detach() 343 | inf_params = self.inf_mod(torch.cat((h_step, b_step), 1)) 344 | inf_params = torch.clamp(inf_params, -8., 8.) 345 | inf_mu, inf_logvar = torch.chunk(inf_params, 2, 1) 346 | kld = gaussian_kld(inf_mu, inf_logvar, pri_mu, pri_logvar) 347 | z_step = self.reparametrize(inf_mu, inf_logvar, eps=r_step) 348 | if self.z_force: 349 | h_step_ = h_step * 0. 350 | else: 351 | h_step_ = h_step 352 | aux_params = self.aux_mod(torch.cat((h_step_, z_step), 1)) 353 | aux_params = torch.clamp(aux_params, -8., 8.) 354 | aux_mu, aux_logvar = torch.chunk(aux_params, 2, 1) 355 | # disconnect gradient here 356 | b_step_ = b_step.detach() 357 | if self.use_l2: 358 | aux_step = torch.sum((b_step_ - F.tanh(aux_mu)) ** 2.0, 1) 359 | else: 360 | aux_step = -log_prob_gaussian( 361 | b_step_, F.tanh(aux_mu), aux_logvar, mean=False) 362 | # generation phase 363 | else: 364 | # sample from the prior 365 | if z_step is None: 366 | z_step = self.reparametrize(pri_mu, pri_logvar, eps=r_step) 367 | aux_step = torch.sum(pri_mu * 0., -1) 368 | inf_mu, inf_logvar = pri_mu, pri_logvar 369 | kld = aux_step 370 | 371 | i_step = self.gen_mod(z_step) 372 | if self.cond_ln: 373 | i_step = torch.clamp(i_step, -3, 3) 374 | gain_hh, bias_hh = torch.chunk(i_step, 2, 1) 375 | gain_hh = 1. + gain_hh 376 | h_new, c_new = self.fwd_mod(x_step, (h_step, c_step), 377 | gain_hh=gain_hh, bias_hh=bias_hh) 378 | else: 379 | h_new, c_new = self.fwd_mod(torch.cat((i_step, x_step), 1), 380 | (h_step, c_step)) 381 | states.append((h_new, c_new)) 382 | klds.append(kld) 383 | zs.append(z_step) 384 | aux_cs.append(aux_step) 385 | log_pz.append(log_prob_gaussian(z_step, pri_mu, pri_logvar)) 386 | log_qz.append(log_prob_gaussian(z_step, inf_mu, inf_logvar)) 387 | 388 | klds = torch.stack(klds, 0) 389 | aux_cs = torch.stack(aux_cs, 0) 390 | log_pz = torch.stack(log_pz, 0) 391 | log_qz = torch.stack(log_qz, 0) 392 | zs = torch.stack(zs, 0) 393 | 394 | outputs = [s[0] for s in states[1:]] 395 | outputs = torch.stack(outputs, 0) 396 | outputs = self.fwd_out_mod(outputs) 397 | return outputs, states[1:], klds, aux_cs, zs, log_pz, log_qz 398 | 399 | def infer(self, x, hidden): 400 | '''Infer latent variables for a given batch of sentences ``x''. 
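        (Here ``x'' is in fact a sequence of image observations of shape
        (nsteps, batch, 3, H, W), embedded by the conv encoders above.)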
401 |         '''
402 |         x_ = x[:-1]
403 |         y_ = x[1:]
404 |         bwd_states, bwd_outputs, _ = self.bwd_pass(x_, y_, hidden)
405 |         fwd_outputs, fwd_states, klds, aux_nll, zs, log_pz, log_qz = self.fwd_pass(
406 |             x_, hidden, bwd_states=bwd_states)
407 |         return zs
408 | 
409 |     def bwd_pass(self, x, y, hidden):
410 |         idx = np.arange(x.size(0))[::-1].tolist()
411 |         idx = torch.LongTensor(idx)
412 |         idx = Variable(idx).cuda()
413 |         # invert the targets and revert back
414 |         x_bwd = x.index_select(0, idx)
415 | 
416 |         y_bwd = y.index_select(0, idx)
417 |         # x_bwd = torch.cat([x_bwd, x[:1]], 0)
418 |         x_bwd_reshape = x_bwd.view(-1, *x_bwd.shape[2:])
419 |         x_emb = self.bwd_emb_mod(x_bwd_reshape)
420 |         x_bwd = x_emb.view(*x_bwd.shape[:2], self.emb_dim)
421 |         states, _ = self.bwd_mod(x_bwd, hidden)
422 |         dec_states = self.dec_linear(states)
423 |         dec_states = dec_states.reshape(-1, 64, 2, 2)
424 |         dec_outputs = self.bwd_dec_mod(dec_states)
425 |         # reshape dec_outputs and compute the L2 reconstruction loss against the target frames
426 |         bwd_l2_loss = self.l2_loss(dec_outputs.view_as(y_bwd), y_bwd)
427 |         outputs = self.bwd_out_mod(states)
428 |         states = states.index_select(0, idx)
429 |         outputs = outputs.index_select(0, idx)
430 |         return states, outputs, bwd_l2_loss
431 | 
432 |     def generate_onestep(self, x_fwd, x_mask, hidden):
433 |         nsteps, nbatch = x_fwd.size(0), x_fwd.size(1)
434 |         #bwd_states, bwd_outputs = self.bwd_pass(x_bwd, hidden)
435 |         fwd_outputs, fwd_states, klds, aux_nll, zs, log_pz, log_qz = self.fwd_pass(
436 |             x_fwd, hidden)
437 |         output_prob = F.softmax(fwd_outputs.squeeze(0))
438 |         sampled_output = torch.multinomial(output_prob, 1)
439 |         hidden = (fwd_states[0][0].unsqueeze(0), fwd_states[0][1].unsqueeze(0))
440 |         if self.return_loss:
441 |             return (sampled_output, hidden, aux_nll)
442 |         else:
443 |             return (sampled_output, hidden)
444 | 
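Training calls the module on whole padded trajectories. A rough sketch of one loss computation (shapes and sizes are illustrative; 7x7 matches the frame size the conv encoder above reduces to a 256-dim feature, and a CUDA device is assumed):

# Hypothetical training-step sketch; not part of the original script.
import torch
from torch.autograd import Variable

T, B, n_actions = 20, 16, 7
model = ZForcing(emb_dim=256, rnn_dim=512, z_dim=64, mlp_dim=256, out_dim=n_actions).cuda()
x_fwd = Variable(torch.zeros(T, B, 3, 7, 7).cuda())    # frames for the forward RNN
x_bwd = Variable(torch.zeros(T, B, 3, 7, 7).cuda())    # frames for the backward RNN
y = Variable(torch.zeros(T, B).cuda())                 # expert actions (class indices)
dec_bwd = Variable(torch.zeros(T, B, 3, 7, 7).cuda())  # frames the backward decoder reconstructs
x_mask = Variable(torch.ones(T, B).cuda())             # 1 for valid steps, 0 for padding
hidden = model.init_hidden(B)

fwd_nll, bwd_nll, aux_nll, kld, bwd_l2 = model(x_fwd, x_bwd, y, dec_bwd, x_mask, hidden)
loss = fwd_nll + bwd_nll + aux_nll + kld + bwd_l2      # term weights omitted; weighting is a training choice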
475 |     def forward(self, x_fwd, x_bwd, y, dec_bwd, x_mask, hidden, y_bwd=None, return_stats=False, return_per_step=False):
476 |         if y_bwd is None:
477 |             y_bwd = y
478 |         nsteps, nbatch = x_fwd.size(0), x_fwd.size(1)
479 |         bwd_states, bwd_outputs, bwd_l2_loss = self.bwd_pass(x_bwd, dec_bwd, hidden)
480 |         fwd_outputs, fwd_states, klds, aux_nlls, zs, log_pz, log_qz = self.fwd_pass(
481 |             x_fwd, hidden, bwd_states=bwd_states)
482 |         kld = (klds * x_mask).sum(0)
483 |         log_pz = (log_pz * x_mask).sum(0)
484 |         log_qz = (log_qz * x_mask).sum(0)
485 |         aux_nll = (aux_nlls * x_mask).sum(0)
486 | 
487 | 
488 |         if self.out_type == 'gaussian':
489 |             out_mu, out_logvar = torch.chunk(fwd_outputs, 2, -1)
490 |             fwd_nll = -log_prob_gaussian(y, out_mu, out_logvar)
491 |             fwd_nll = (fwd_nll * x_mask).sum(0)
492 |         elif self.out_type == 'softmax':
493 |             fwd_out = fwd_outputs.view(nsteps * nbatch, self.out_dim)
494 |             fwd_out = F.log_softmax(fwd_out)
495 |             y = y.view(-1, 1)
496 |             fwd_nll = torch.gather(fwd_out, 1, y.long()).squeeze(1)
497 |             fwd_nll = fwd_nll.view(nsteps, nbatch)
498 |             fwd_nll = -(fwd_nll * x_mask).sum(0)
499 |         if self.bwd_out_type == 'softmax':
500 |             bwd_out = bwd_outputs.view(nsteps * nbatch, self.bwd_out_dim)
501 |             bwd_out = F.log_softmax(bwd_out)
502 |             y_bwd = y_bwd.view(-1, 1)
503 |             bwd_nll = torch.gather(bwd_out, 1, y_bwd.long()).squeeze(1)
504 |             bwd_nll = -bwd_nll.view(nsteps, nbatch)
505 |             bwd_nll = (bwd_nll * x_mask).sum(0)
506 |         elif self.bwd_out_type == 'gaussian':
507 |             out_mu, out_logvar = torch.chunk(bwd_outputs, 2, -1)
508 |             bwd_nll = -log_prob_gaussian(y_bwd, out_mu, out_logvar)
509 |             bwd_nll = (bwd_nll * x_mask).sum(0)
510 |         if return_per_step:
511 |             return fwd_nll, bwd_nll, aux_nlls, klds, log_pz, bwd_l2_loss
512 |         elif return_stats:
513 |             return fwd_nll, bwd_nll, aux_nll, kld, log_pz, log_qz, bwd_l2_loss
514 |         return fwd_nll.mean(), bwd_nll.mean(), aux_nll.mean(), kld.mean(), bwd_l2_loss
515 | 
--------------------------------------------------------------------------------
/scripts/train_curclm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #!/usr/bin/env python3
3 | # -*- coding: utf-8 -*-
4 | 
5 | import argparse
6 | import gym
7 | import time
8 | import datetime
9 | import torch
10 | import torch_rl
11 | 
12 | import babyai.utils as utils
13 | from babyai.model import ACModel
14 | 
15 | # Parse arguments
16 | 
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--algo", required=True,
19 |                     help="algorithm to use: a2c | ppo (REQUIRED)")
20 | parser.add_argument("--env", required=True,
21 |                     help="name of the environment to train on (REQUIRED)")
22 | parser.add_argument("--model", default=None,
23 |                     help="name of the model (default: ENV_ALGO_TIME)")
24 | parser.add_argument("--seed", type=int, default=1,
25 |                     help="random seed (default: 1)")
26 | parser.add_argument("--procs", type=int, default=16,
27 |                     help="number of processes (default: 16)")
28 | parser.add_argument("--frames", type=int, default=10**7,
29 |                     help="number of frames of training (default: 10**7)")
30 | parser.add_argument("--log-interval", type=int, default=1,
31 |                     help="number of updates between two logs (default: 1)")
32 | parser.add_argument("--save-interval", type=int, default=10,
33 |                     help="number of updates between two saves (default: 10, 0 means no saving)")
34 | parser.add_argument("--tb", action="store_true", default=False,
35 |                     help="log into Tensorboard")
36 | parser.add_argument("--frames-per-proc", type=int, default=None,
37 |                     help="number of frames per process before update (default: 5 for A2C and 128 for PPO)")
38 | parser.add_argument("--discount", type=float, default=0.99,
39 |                     help="discount factor (default: 0.99)")
40 | parser.add_argument("--lr", type=float, default=7e-4,
41 |                     help="learning rate (default: 7e-4)")
42 | parser.add_argument("--gae-tau", type=float, default=0.95,
43 |                     help="tau coefficient in GAE formula (default: 0.95, 1 means no gae)")
44 | parser.add_argument("--entropy-coef", type=float, default=0.01,
45 |                     help="entropy term coefficient (default: 0.01)")
46 | parser.add_argument("--value-loss-coef", type=float, default=0.5,
47 |                     help="value loss term coefficient (default: 0.5)")
48 | parser.add_argument("--max-grad-norm", type=float, default=0.5,
49 |                     help="maximum norm of gradient (default: 0.5)")
50 | parser.add_argument("--recurrence", type=int, default=1,
51 |                     help="number of timesteps gradient is backpropagated (default: 1)")
52 | parser.add_argument("--optim-eps", type=float, default=1e-5,
53 |                     help="Adam and RMSprop optimizer epsilon (default: 1e-5)")
54 | parser.add_argument("--optim-alpha", type=float, default=0.99,
55 |                     help="RMSprop optimizer alpha (default: 0.99)")
56 | parser.add_argument("--clip-eps", type=float, default=0.2,
57 |                     help="clipping epsilon for PPO (default: 0.2)")
58 | parser.add_argument("--epochs", type=int, default=4,
59 |                     help="number of epochs for PPO (default: 4)")
60 | parser.add_argument("--batch-size", type=int, default=256,
61 |                     help="batch size for PPO (default: 256)")
62 | parser.add_argument("--room-size", type=int, default=6,
63 |                     help="room size for env (default: 6)")
64 | parser.add_argument("--instr-model", default=None,
65 |                     help="model to encode instructions, None if not using instructions, possible values: gru, conv, bow")
66 | parser.add_argument("--no-mem", action="store_true", default=False,
67 |                     help="don't use memory in the model")
68 | parser.add_argument("--arch", default='cnn1',
69 |                     help="image embedding architecture")
70 | 
71 | args = parser.parse_args()
72 | 
73 | # Set seed for all randomness sources
74 | 
75 | utils.seed(args.seed)
76 | 
77 | # Generate environments
78 | 
79 | envs = []
80 | for i in range(args.procs):
81 |     env = gym.make(args.env)
82 |     env.set_roomsize(args.room_size)
83 |     env.seed(args.seed + i)
84 |     envs.append(env)
85 | 
86 | # Define model name
87 | 
88 | suffix = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
89 | instr = args.instr_model if args.instr_model else "noinstr"
90 | mem = "mem" if not args.no_mem else "nomem"
91 | default_model_name = "{}_{}_{}_{}_{}_room_{}_start_flag_new_seed{}_{}".format(args.env,
92 |                                                                               args.algo,
93 |                                                                               args.arch,
94 |                                                                               instr,
95 |                                                                               mem,
96 |                                                                               args.room_size,
97 |                                                                               args.seed,
98 |                                                                               suffix)
99 | model_name = args.model or default_model_name
100 | 
101 | # Define obss preprocessor
102 | 
103 | obss_preprocessor = utils.ObssPreprocessor(model_name, envs[0].observation_space)
104 | 
105 | # Define actor-critic model
106 | acmodel = utils.load_model(model_name, raise_not_found=False)
107 | if acmodel is None:
108 |     acmodel = ACModel(obss_preprocessor.obs_space, envs[0].action_space,
109 |                       args.instr_model, not args.no_mem, args.arch)
110 | if torch.cuda.is_available():
111 |     acmodel.cuda()
112 | 
113 | # Define actor-critic algo
114 | 
115 | if args.algo == "a2c":
116 |     algo = torch_rl.A2CAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr, args.gae_tau,
117 |                             args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
118 |                             args.optim_alpha, args.optim_eps, obss_preprocessor, utils.reshape_reward)
119 | elif args.algo == "ppo":
120 |     algo = torch_rl.PPOAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr, args.gae_tau,
121 |                             args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
122 |                             args.optim_eps, args.clip_eps, args.epochs, args.batch_size, obss_preprocessor,
123 |                             utils.reshape_reward)
124 | else:
125 |     raise ValueError("Incorrect algorithm name: {}".format(args.algo))
126 | 
127 | # Define logger and Tensorboard writer
128 | 
129 | logger = utils.get_logger(model_name)
130 | if args.tb:
131 |     from tensorboardX import SummaryWriter
132 |     writer = SummaryWriter(utils.get_log_dir(model_name))
133 | 
134 | # Log command, availability of CUDA and model
135 | 
136 | logger.info(args)
137 | logger.info("CUDA available: {}".format(torch.cuda.is_available()))
138 | logger.info(acmodel)
139 | 
140 | # Train model
141 | 
142 | num_frames = 0
143 | total_start_time = time.time()
144 | i = 0
145 | 
146 | while num_frames < args.frames:
147 |     # Update parameters
148 | 
149 |     update_start_time = time.time()
150 |     logs = algo.update_parameters()
151 |     update_end_time = time.time()
152 | 
153 |     num_frames += logs["num_frames"]
154 |     i += 1
155 | 
156 |     # Print logs
157 | 
158 |     if i % args.log_interval == 0:
159 |         total_elapsed_time = int(time.time() - total_start_time)
160 |         fps = logs["num_frames"]/(update_end_time - update_start_time)
161 |         duration = datetime.timedelta(seconds=total_elapsed_time)
162 |         return_per_episode = utils.synthesize(logs["return_per_episode"])
163 |         rreturn_per_episode = utils.synthesize(logs["reshaped_return_per_episode"])
164 |         num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])
165 | 
166 |         logger.info(
167 |             "U {} | F {:06} | FPS {:04.0f} | D {} | rR:x̄σmM {: .2f} {: .2f} {: .2f} {: .2f} | F:x̄σmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {: .3f} | vL {:.3f}"
168 |             .format(i, num_frames, fps, duration,
169 |                     *rreturn_per_episode.values(),
170 |                     *num_frames_per_episode.values(),
171 |                     logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"]))
172 |         if args.tb:
173 |             writer.add_scalar("frames", num_frames, i)
174 |             writer.add_scalar("FPS", fps, i)
175 |             writer.add_scalar("duration", total_elapsed_time, i)
176 |             for key, value in return_per_episode.items():
177 |                 writer.add_scalar("return_" + key, value, i)
178 |             for key, value in rreturn_per_episode.items():
179 |                 writer.add_scalar("rreturn_" + key, value, i)
180 |             for key, value in num_frames_per_episode.items():
181 |                 writer.add_scalar("num_frames_" + key, value, i)
182 |             writer.add_scalar("entropy", logs["entropy"], i)
183 |             writer.add_scalar("value", logs["value"], i)
184 |             writer.add_scalar("policy_loss", logs["policy_loss"], i)
185 |             writer.add_scalar("value_loss", logs["value_loss"], i)
186 | 
187 |     # Save obss preprocessor vocabulary and model
188 | 
189 |     if args.save_interval > 0 and i % args.save_interval == 0:
190 |         obss_preprocessor.vocab.save()
191 | 
192 |         if torch.cuda.is_available():
193 |             acmodel.cpu()
194 |         utils.save_model(acmodel, model_name)
195 |         logger.info("Model is saved.")
196 |         if torch.cuda.is_available():
197 |             acmodel.cuda()
198 | 
--------------------------------------------------------------------------------
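Training is launched from the command line; a minimal invocation sketch (the environment name is hypothetical, all other flags at their defaults):

python3 scripts/train_curclm.py --env BabyAI-GoToRedBall-v0 --algo ppo --room-size 6 --tb

--------------------------------------------------------------------------------
/scripts/train_rl.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.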
2 | #!/usr/bin/env python3
3 | # -*- coding: utf-8 -*-
4 |
5 | import argparse
6 | import gym
7 | import time
8 | import datetime
9 | import torch
10 | import torch_rl
11 |
12 | import babyai.utils as utils
13 | from babyai.model import ACModel
14 |
15 | # Parse arguments
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--algo", required=True,
19 | help="algorithm to use: a2c | ppo (REQUIRED)")
20 | parser.add_argument("--env", required=True,
21 | help="name of the environment to train on (REQUIRED)")
22 | parser.add_argument("--model", default=None,
23 | help="name of the model (default: ENV_ALGO_TIME)")
24 | parser.add_argument("--seed", type=int, default=1,
25 | help="random seed (default: 1)")
26 | parser.add_argument("--procs", type=int, default=16,
27 | help="number of processes (default: 16)")
28 | parser.add_argument("--frames", type=int, default=10**7,
29 | help="number of frames of training (default: 10**7)")
30 | parser.add_argument("--log-interval", type=int, default=1,
31 | help="number of updates between two logs (default: 1)")
32 | parser.add_argument("--save-interval", type=int, default=10,
33 | help="number of updates between two saves (default: 10, 0 means no saving)")
34 | parser.add_argument("--tb", action="store_true", default=False,
35 | help="log into TensorBoard")
36 | parser.add_argument("--frames-per-proc", type=int, default=None,
37 | help="number of frames per process before update (default: 5 for A2C and 128 for PPO)")
38 | parser.add_argument("--discount", type=float, default=0.99,
39 | help="discount factor (default: 0.99)")
40 | parser.add_argument("--lr", type=float, default=7e-4,
41 | help="learning rate (default: 7e-4)")
42 | parser.add_argument("--gae-tau", type=float, default=0.95,
43 | help="tau coefficient in GAE formula (default: 0.95, 1 means no gae)")
44 | parser.add_argument("--entropy-coef", type=float, default=0.01,
45 | help="entropy term coefficient (default: 0.01)")
46 | parser.add_argument("--value-loss-coef", type=float, default=0.5,
47 | help="value loss term coefficient (default: 0.5)")
48 | parser.add_argument("--max-grad-norm", type=float, default=0.5,
49 | help="maximum norm of gradient (default: 0.5)")
50 | parser.add_argument("--recurrence", type=int, default=1,
51 | help="number of timesteps gradient is backpropagated (default: 1)")
52 | parser.add_argument("--optim-eps", type=float, default=1e-5,
53 | help="Adam and RMSprop optimizer epsilon (default: 1e-5)")
54 | parser.add_argument("--optim-alpha", type=float, default=0.99,
55 | help="RMSprop optimizer alpha (default: 0.99)")
56 | parser.add_argument("--clip-eps", type=float, default=0.2,
57 | help="clipping epsilon for PPO (default: 0.2)")
58 | parser.add_argument("--epochs", type=int, default=4,
59 | help="number of epochs for PPO (default: 4)")
60 | parser.add_argument("--batch-size", type=int, default=256,
61 | help="batch size for PPO (default: 256)")
62 | parser.add_argument("--instr-model", default=None,
63 | help="model to encode instructions, None if not using instructions, possible values: gru, conv, bow")
64 | parser.add_argument("--no-mem", action="store_true", default=False,
65 | help="don't use memory in the model")
66 | parser.add_argument("--arch", default='cnn1',
67 | help="image embedding architecture")
68 | parser.add_argument("--room-size", type=int, default=6,
69 | help="room size for the env (default: 6)")
70 |
71 | args = parser.parse_args()
72 |
73 | # Set seed for all randomness sources
74 |
75 | utils.seed(args.seed)
76 |
77 | # Generate environments
78 |
79 | envs = []
80 | for i in range(args.procs):
81 | env = gym.make(args.env)
82 | env.seed(args.seed + i)
83 | envs.append(env)
84 |
85 | # Define model name
86 |
87 | suffix = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
88 | instr = args.instr_model if args.instr_model else "noinstr"
89 | mem = "mem" if not args.no_mem else "nomem"
90 | default_model_name = "{}_{}_{}_{}_{}_{}_start_marked_mark_seed{}_{}".format(args.env,
91 | args.algo,
92 | args.arch,
93 | instr,
94 | mem,
95 | args.room_size,
96 | args.seed,
97 | suffix)
98 | model_name = args.model or default_model_name
99 |
100 | # Define obss preprocessor
101 |
102 | obss_preprocessor = utils.ObssPreprocessor(model_name, envs[0].observation_space)
103 |
104 | # Define actor-critic model
105 |
106 | acmodel = utils.load_model(model_name, raise_not_found=False)
107 | if acmodel is None:
108 | acmodel = ACModel(obss_preprocessor.obs_space, envs[0].action_space,
109 | args.instr_model, not args.no_mem, args.arch)
110 | if torch.cuda.is_available():
111 | acmodel.cuda()
112 |
113 | # Define actor-critic algo
114 |
115 | if args.algo == "a2c":
116 | algo = torch_rl.A2CAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr, args.gae_tau,
117 | args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
118 | args.optim_alpha, args.optim_eps, obss_preprocessor, utils.reshape_reward)
119 | elif args.algo == "ppo":
120 | algo = torch_rl.PPOAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr, args.gae_tau,
121 | args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
122 | args.optim_eps, args.clip_eps, args.epochs, args.batch_size, obss_preprocessor,
123 | utils.reshape_reward)
124 | else:
125 | raise ValueError("Incorrect algorithm name: {}".format(args.algo))
126 |
127 | # Define logger and TensorBoard writer
128 |
129 | logger = utils.get_logger(model_name)
130 | if args.tb:
131 | from tensorboardX import SummaryWriter
132 | writer = SummaryWriter(utils.get_log_dir(model_name))
133 |
134 | # Log command, availability of CUDA and model
135 |
136 | logger.info(args)
137 | logger.info("CUDA available: {}".format(torch.cuda.is_available()))
138 | logger.info(acmodel)
139 |
140 | # Train model
141 |
142 | num_frames = 0
143 | total_start_time = time.time()
144 | i = 0
145 | while num_frames < args.frames:
146 | # Update parameters
147 |
148 | update_start_time = time.time()
149 | logs = algo.update_parameters()
150 | update_end_time = time.time()
151 |
152 | num_frames += logs["num_frames"]
153 | i += 1
154 |
155 | # Print logs
156 |
157 | if i % args.log_interval == 0:
158 | total_elapsed_time = int(time.time() - total_start_time)
159 | fps = logs["num_frames"]/(update_end_time - update_start_time)
160 | duration = datetime.timedelta(seconds=total_elapsed_time)
161 | return_per_episode = utils.synthesize(logs["return_per_episode"])
162 | rreturn_per_episode = utils.synthesize(logs["reshaped_return_per_episode"])
163 | num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])
164 |
165 | logger.info(
166 | "U {} | F {:06} | FPS {:04.0f} | D {} | rR:x̄σmM {: .2f} {: .2f} {: .2f} {: .2f} | F:x̄σmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {: .3f} | vL {:.3f}"
167 | .format(i, num_frames, fps, duration,
168 | *rreturn_per_episode.values(),
169 | *num_frames_per_episode.values(),
170 | logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"]))
171 | if args.tb:
172 | writer.add_scalar("frames", num_frames, i)
173 | writer.add_scalar("FPS", fps, i)
writer.add_scalar("FPS", fps, i) 174 | writer.add_scalar("duration", total_ellapsed_time, i) 175 | for key, value in return_per_episode.items(): 176 | writer.add_scalar("return_" + key, value, i) 177 | for key, value in rreturn_per_episode.items(): 178 | writer.add_scalar("rreturn_" + key, value, i) 179 | for key, value in num_frames_per_episode.items(): 180 | writer.add_scalar("num_frames_" + key, value, i) 181 | writer.add_scalar("entropy", logs["entropy"], i) 182 | writer.add_scalar("value", logs["value"], i) 183 | writer.add_scalar("policy_loss", logs["policy_loss"], i) 184 | writer.add_scalar("value_loss", logs["value_loss"], i) 185 | 186 | # Save obss preprocessor vocabulary and model 187 | 188 | if args.save_interval > 0 and i % args.save_interval == 0: 189 | obss_preprocessor.vocab.save() 190 | 191 | if torch.cuda.is_available(): 192 | acmodel.cpu() 193 | utils.save_model(acmodel, model_name) 194 | logger.info("Model is saved.") 195 | if torch.cuda.is_available(): 196 | acmodel.cuda() 197 | -------------------------------------------------------------------------------- /scripts/zforcing_main_state_dec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | #!/usr/bin/env python3 3 | 4 | import argparse 5 | import cv2 6 | import gym 7 | import time 8 | import datetime 9 | import pickle 10 | import babyai.utils as utils 11 | from itertools import count 12 | import scipy.optimize 13 | from scripts.rl_zforcing_dec import ZForcing 14 | import random 15 | import scipy.misc 16 | import torch 17 | import numpy as np 18 | # Parse arguments 19 | import os 20 | import matplotlib 21 | matplotlib.use('Agg') 22 | import matplotlib.pyplot as plt 23 | import matplotlib.cm as cm 24 | import numpy as np 25 | import re 26 | from pyvirtualdisplay import Display 27 | 28 | display_ = Display(visible=0, size=(550, 500)) 29 | display_.start() 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--env", required=True, 33 | help="name of the environment to be run (REQUIRED)") 34 | parser.add_argument("--eval-episodes", type=int, default=1000, 35 | help="number of episodes of evaluation (default: 1000)") 36 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 37 | parser.add_argument("--eval-interval", type=int, default=100, 38 | help="how often to evaluate the student model") 39 | parser.add_argument("--seed", type=int, default=None, 40 | help="random seed (default: 0 if model agent, 1 if demo agent)") 41 | parser.add_argument("--room", type=int, default=15, 42 | help="room size") 43 | parser.add_argument("--deterministic", action="store_true", default=False, 44 | help="action with highest probability is selected for model agent") 45 | parser.add_argument('--aux-weight-start', type=float, default=0., 46 | help='start weight for auxiliary loss') 47 | parser.add_argument('--l2-weight', type=float, default=1., 48 | help='weight for l2 loss') 49 | parser.add_argument('--aux-weight-end', type=float, default=0., 50 | help='end weight for auxiliary loss') 51 | parser.add_argument('--bwd-weight', type=float, default=0., 52 | help='weight for bwd teacher forcing loss') 53 | parser.add_argument('--kld-weight-start', type=float, default=0., 54 | help='start weight for kl divergence between prior and posterior z loss') 55 | parser.add_argument('--kld-step', type=float, default=1e-6, 56 | help='step size to anneal kld_weight per iteration') 57 | parser.add_argument('--aux-step', type=float, 
62 | def pad(array, length):
63 | return array + [np.zeros_like(array[-1])] * (length - len(array))
64 |
65 | def front_pad(array, length):
66 | return [np.zeros_like(array[-1])] * (length - len(array)) + array
67 |
68 | def max_length(arrays):
69 | return max([len(array) for array in arrays])
70 |
71 | def save_param(model, model_file_name):
72 | torch.save(model.state_dict(), model_file_name)
73 |
74 | def load_param(model, model_file_name):
75 | model.load_state_dict(torch.load(model_file_name))
76 | return model
77 |
78 | def write_samples(all_samples_obs, all_samples_actions, filename):
79 | # write to pickle file
80 | all_data = list(zip(all_samples_obs, all_samples_actions))
81 | output = open(filename, "wb")
82 | pickle.dump(all_data, output)
83 | output.close()
84 | return True
85 |
86 | def load_samples(filename):
87 | output = open(filename, "rb")
88 | all_data = pickle.load(output)
89 | return all_data
90 |
91 |
92 | def evaluate_student(agent, env, episodes):
93 | logs = {"num_frames_per_episode": [], "return_per_episode": []}
94 | reward_batch = []
95 | for _ in range(episodes):
96 | obs = env.reset()
97 | done = False
98 | num_frames = 0
99 | returnn = 0
100 | hidden = agent.init_hidden(1)
101 |
102 | while not done:
103 | #action = agent(obs['image'])
104 | image = np.expand_dims(obs['image'], 0)
105 | mask = torch.ones(image.shape).unsqueeze(0)
106 | image = torch.from_numpy(image).unsqueeze(0).permute(0,1,4,2,3)
107 | action, hidden = agent.generate_onestep(image.float().cuda(), mask.cuda(), hidden)
108 | obs, reward, done, _ = env.step(action)
109 | num_frames += 1
110 | returnn += reward
111 | logs["num_frames_per_episode"].append(num_frames)
112 | logs["return_per_episode"].append(returnn)
113 | reward_batch.append(returnn)
114 | #log_line = 'test reward is '+ str(np.asarray(reward_batch).mean()) +'\n'
115 | #log_line = 'test reward std is ' + str(np.asarray(reward_batch).std()) + '\n'
116 | #print (log_line)
117 | #log_line = 'test reward is '+ str(np.asarray(reward_batch).mean())
118 | #with open(log_file, 'a') as f:
119 | # f.write(log_line)
120 | return logs
121 |
122 | def analysis_zf(agent, env, episode, iteration, episodes):
123 | logs = {"num_frames_per_episode": [], "return_per_episode": []}
124 | print('analyzing model')
125 | reward_batch = []
126 |
127 | curr_dir = os.path.join(model_dir, 'episode_'+str(episode) + '_iter_' + str(iteration))
128 | os.mkdir(curr_dir)
129 |
130 | for iter_ in range(episodes):
131 | iter_dir = os.path.join(curr_dir, 'iter_'+str(iter_))
132 | os.mkdir(iter_dir)
133 | obs = env.reset()
134 | done = False
135 | num_frames = 0
136 | returnn = 0
137 | hidden = agent.init_hidden(1)
138 | test_images = []
139 | test_actions = []
140 | episode_images = []
141 | episode_actions = []
142 | images = []
143 | step = 0
144 | while not done:
145 | #action = agent(obs['image'])
146 | image = env.render("rgb_array")
147 | image = cv2.resize(image, dsize=(512, 512), interpolation=cv2.INTER_CUBIC)
148 | file_name = os.path.join(iter_dir, 'iter_' + str(iter_) +'_step_' +str(step) + '.png')
149 | cv2.imwrite(file_name, image[:,:,::-1])
150 |
151 |
152 | obs_image = np.expand_dims(obs['image'], 0)
153 |
154 | episode_images.append(obs_image)
155 | mask = torch.ones(obs_image.shape).unsqueeze(0)
156 |
157 | zf_image = torch.from_numpy(obs_image).unsqueeze(0).permute(0,1,4,2,3)
158 | action, hidden = agent.generate_onestep(zf_image.float().cuda(), mask.cuda(), hidden)
159 | episode_actions.append(action.item())
160 | obs, reward, done, _ = env.step(action)
161 | num_frames += 1
162 | step += 1
163 | returnn += reward
164 | # after gathering all observation images, run them through the ZForcing model and plot the per-step prediction cost
165 | obs_image = np.expand_dims(obs['image'], 0)
166 | episode_images.append(obs_image)
167 |
168 | test_images.append(episode_images)
169 | test_actions.append(episode_actions)
170 |
171 | images_max_len = max_length(test_images)
172 | actions_max_len = max_length(test_actions)
173 | images_mask = [[1] * (len(array) - 1) + [0] * (images_max_len - len(array))
174 | for array in test_images]
175 | fwd_images = [pad(array[:-1], images_max_len - 1) for array in test_images]
176 | bwd_images = [front_pad(array[1:], images_max_len - 1) for array in test_images]
177 | bwd_images_target = [front_pad(array[:-1], images_max_len - 1) for array in test_images]
178 | training_actions = [pad(array, actions_max_len) for array in test_actions]
179 |
180 | fwd_images = np.array(list(zip(*fwd_images)), dtype=np.float32)
181 | bwd_images = np.array(list(zip(*bwd_images)), dtype=np.float32)
182 | bwd_images_target = np.array(list(zip(*bwd_images_target)), dtype=np.float32)
183 | images_mask = np.array(list(zip(*images_mask)), dtype=np.float32)
184 | test_actions = np.array(list(zip(*test_actions)), dtype=np.float32)
185 | x_fwd = torch.from_numpy(fwd_images.squeeze(1)).permute(0,1,4,2,3).cuda()
186 | x_bwd = torch.from_numpy(bwd_images.squeeze(1)).permute(0,1,4,2,3).cuda()
187 | y_bwd = torch.from_numpy(bwd_images_target.squeeze(1)).permute(0,1,4,2,3).cuda()
188 | #y_bwd = torch.from_numpy(fwd_images.squeeze(1)).permute(0,1,4,2,3).cuda()
189 | y = torch.from_numpy(test_actions).cuda()
190 | x_mask = torch.from_numpy(images_mask).cuda()
191 |
192 | fwd_nll, bwd_nll, aux_nlls, klds, log_pz, bwd_l2_loss = agent(x_fwd, x_bwd, y, y_bwd, x_mask, hidden, return_per_step=True)
193 | aux_nlls = aux_nlls.data.cpu().numpy().reshape(-1)
194 | plt.plot(aux_nlls, label='auxiliary cost changes')
195 | plt.legend(loc='upper right')
196 | filename = os.path.join(iter_dir, 'aux_cost.pdf')
197 | plt.savefig(filename)
198 | plt.close()
199 | logs["num_frames_per_episode"].append(num_frames)
200 | logs["return_per_episode"].append(returnn)
201 | test_images = []
202 | test_actions = []
203 | reward_batch.append(returnn)
204 | return logs
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 | if __name__ == "__main__":
213 | args = parser.parse_args()
214 | lr = args.lr
215 |
216 | if args.seed is None:
217 | args.seed = 0 # if args.model is not None else 1
218 |
219 | model_name = 'zforce_2opt_room_' + str(args.room) + '_lr' + str(args.lr) + '_bwd_w_' + str(args.bwd_weight) + '_l2_w_' + str(args.l2_weight) + '_aux_w_' + str(args.aux_weight_start) + '_kld_w_' + str(args.kld_weight_start) + '_' + str(random.randint(1,1000))
220 |
221 | model_dir = os.path.join(args.env+'-model', model_name)
222 |
223 | os.mkdir(model_dir)
224 | zf_name = model_name + '.pkl'
225 | zf_file = os.path.join(model_dir, zf_name)
226 |
227 | log_name = model_name + '.log'
228 | log_file = os.path.join(model_dir, log_name)
229 | # Set seed for all randomness sources
230 |
231 | utils.seed(args.seed)
232 |
233 | # Generate environment
234 |
235 | env = gym.make(args.env)
236 | env.seed(args.seed)
237 |
238 |
239 | # Run the agent
240 |
241 | start_time = time.time()
242 |
243 | # load expert data samples
244 |
245 | end_time = time.time()
246 |
247 | # Print logs
248 |
249 | '''num_frames = sum(logs["num_frames_per_episode"])
250 | fps = num_frames/(end_time - start_time)
251 | elapsed_time = int(end_time - start_time)
252 | duration = datetime.timedelta(seconds=elapsed_time)
253 | return_per_episode = utils.synthesize(logs["return_per_episode"])
254 | num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])
255 |
256 | print("F {} | FPS {:.0f} | D {} | R:x̄σmM {:.2f} {:.2f} {:.2f} {:.2f} | F:x̄σmM {:.1f} {:.1f} {} {}"
257 | .format(num_frames, fps, duration,
258 | *return_per_episode.values(),
259 | *num_frames_per_episode.values()))
260 |
261 | indexes = sorted(range(len(logs["return_per_episode"])), key=lambda k: logs["return_per_episode"][k])
262 | n = 10
263 | print("{} worst episodes:".format(n))
264 | for i in indexes[:n]:
265 | print("- episode {}: R={}, F={}".format(i, logs["return_per_episode"][i], logs["num_frames_per_episode"][i]))
266 | '''
267 | # Train a student policy
268 | num_actions = 7
269 | zf = ZForcing(emb_dim=512, rnn_dim=512, z_dim=256,
270 | mlp_dim=256, out_dim=num_actions, z_force=False, cond_ln=False, use_l2=True)
271 | data_file = args.datafile #'data/BabyAI-UnlockPickup-v0start_flag_room_10_10000_samples.dat'
272 | all_data = load_samples(data_file)
273 | all_samples_obs, all_samples_actions = [list(t) for t in zip(*all_data)]
274 |
275 | fwd_param = []
276 | bwd_param = []
277 |
278 | hist_return_mean = 0.0
279 | # split the ZForcing parameters into backward ('bwd') and forward groups, one optimizer each
280 | for param_tuple in zf.named_parameters():
281 | name = param_tuple[0]
282 | param = param_tuple[1]
283 | if 'bwd' in name:
284 | bwd_param.append(param)
285 | else:
286 | fwd_param.append(param)
287 |
288 | zf_fwd_param = (n for n in fwd_param)
289 | zf_bwd_param = (n for n in bwd_param)
290 | fwd_opt = torch.optim.Adam(zf_fwd_param, lr=lr, eps=1e-5)
291 | bwd_opt = torch.optim.Adam(zf_bwd_param, lr=lr, eps=1e-5)
292 |
293 | kld_weight = args.kld_weight_start
294 | aux_weight = args.aux_weight_start
295 | bwd_weight = args.bwd_weight
296 | zf.float()
297 | zf.cuda()
298 |
299 | num_samples = len(all_samples_obs)
300 |
301 | batch_size = 32
302 |
303 | num_episodes = 50
304 |
305 | for episode in range(num_episodes):
306 | for i in range(num_samples // batch_size):
307 | training_images = all_samples_obs[i * batch_size : (i + 1) * batch_size]
308 | training_actions = all_samples_actions[i * batch_size : (i + 1) * batch_size]
309 | images_max_len = max_length(training_images)
310 | actions_max_len = max_length(training_actions)
311 | images_mask = [[1] * (len(array) - 1) + [0] * (images_max_len - len(array))
312 | for array in training_images]
313 |
314 | fwd_images = [pad(array[:-1], images_max_len - 1) for array in training_images]
315 |
316 |
317 | bwd_images = [front_pad(array[1:], images_max_len - 1) for array in training_images]
318 | bwd_images_target = [front_pad(array[:-1], images_max_len - 1) for array in training_images]
319 | training_actions = [pad(array, actions_max_len) for array in training_actions]
320 |
321 | fwd_images = np.array(list(zip(*fwd_images)), dtype=np.float32)
322 | bwd_images = np.array(list(zip(*bwd_images)), dtype=np.float32)
323 | bwd_images_target = np.array(list(zip(*bwd_images_target)), dtype=np.float32)
324 | images_mask = np.array(list(zip(*images_mask)), dtype=np.float32)
325 | training_actions = np.array(list(zip(*training_actions)), dtype=np.float32)
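# list(zip(*...)) above transposed the batch-major lists (one inner list per episode)
# into time-major arrays of shape (T, B, H, W, C); the permute(0, 1, 4, 2, 3) below
# then rearranges them to the (T, B, C, H, W) layout the ZForcing model presumably expects.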
326 | x_fwd = torch.from_numpy(fwd_images).permute(0,1,4,2,3).cuda()
327 | x_bwd = torch.from_numpy(bwd_images).permute(0,1,4,2,3).cuda()
328 | y_bwd = torch.from_numpy(bwd_images_target).permute(0,1,4,2,3).cuda()
329 | y = torch.from_numpy(training_actions).cuda()
330 | x_mask = torch.from_numpy(images_mask).cuda()
331 |
332 | zf.float().cuda()
333 | hidden = zf.init_hidden(batch_size)
334 |
335 | fwd_opt.zero_grad()
336 | bwd_opt.zero_grad()
337 |
338 | fwd_nll, bwd_nll, aux_nll, kld, bwd_l2_loss = zf(x_fwd, x_bwd, y, y_bwd, x_mask, hidden)
339 | #bwd_nll = (aux_weight > 0.) * (bwd_weight * bwd_nll)
340 | bwd_nll = bwd_weight * bwd_nll
341 | aux_nll = aux_weight * aux_nll
342 | all_loss = fwd_nll + bwd_nll + aux_nll + kld_weight * kld + args.l2_weight * bwd_l2_loss
343 | fwd_loss = fwd_nll + aux_nll + kld_weight * kld
344 | bwd_loss = args.l2_weight * bwd_l2_loss + 0.0 * bwd_nll
345 |
346 | kld_weight += args.kld_step
347 | kld_weight = min(kld_weight, 1.)
348 | if args.aux_weight_start < args.aux_weight_end:
349 | aux_weight += args.aux_step
350 | aux_weight = min(aux_weight, args.aux_weight_end)
351 | else:
352 | aux_weight -= args.aux_step
353 | aux_weight = max(aux_weight, args.aux_weight_end)
354 | log_line = 'Episode: %d, Iteration: %d, All loss is %.3f, forward loss is %.3f, backward loss is %.3f, l2 loss is %.3f, aux loss is %.3f, kld is %.3f' % (
355 | episode, i,
356 | all_loss.item(),
357 | fwd_nll.item(),
358 | bwd_nll.item(),
359 | bwd_l2_loss.item(),
360 | aux_nll.item(),
361 | kld.item()
362 | ) + '\n'
363 | #print(log_line)
364 | with open(log_file, 'a') as f:
365 | f.write(log_line)
366 |
367 | fwd_loss.backward()
368 | bwd_loss.backward()
369 |
370 | torch.nn.utils.clip_grad_norm_(zf.parameters(), 100.)
371 | #opt.step()
372 | fwd_opt.step()
373 | bwd_opt.step()
374 |
375 | if (i + 1) % (args.eval_interval) == 0:
376 | eval_start = time.time(); logs = evaluate_student(zf, env, args.eval_episodes)
377 | # Print logs (FPS over the evaluation itself, duration since script start)
378 | num_frames = sum(logs["num_frames_per_episode"])
379 | fps = num_frames/(time.time() - eval_start)
380 | elapsed_time = int(time.time() - start_time)
381 | duration = datetime.timedelta(seconds=elapsed_time)
382 | return_per_episode = utils.synthesize(logs["return_per_episode"])
383 | num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])
384 |
385 | log_line = ("F {} | FPS {:.0f} | D {} | R:x̄σmM {:.2f} {:.2f} {:.2f} {:.2f} | F:x̄σmM {:.1f} {:.1f} {} {}".format(num_frames, fps, duration, *return_per_episode.values(), *num_frames_per_episode.values())) + '\n'
386 | print("F {} | FPS {:.0f} | D {} | R:x̄σmM {:.2f} {:.2f} {:.2f} {:.2f} | F:x̄σmM {:.1f} {:.1f} {} {}"
387 | .format(num_frames, fps, duration,
388 | *return_per_episode.values(),
389 | *num_frames_per_episode.values()))
390 | with open(log_file, 'a') as f:
391 | f.write(log_line)
392 | if return_per_episode['mean'] > hist_return_mean:
393 | save_param(zf, zf_file)
394 | hist_return_mean = return_per_episode['mean']
395 | if (i + 1) % 200 == 0:
396 | analysis_zf(zf, env, episode, i, 10)
397 |
398 |
399 |
400 |
401 |
402 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
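# A common way to install this package for development is an editable install from the
# repository root (standard pip/setuptools behaviour, not something this file configures):
#
#     pip install -e .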
2 | from setuptools import setup
3 |
4 | setup(
5 | name='babyai',
6 | version='0.0.2',
7 | license='BSD 3-clause',
8 | keywords='memory, environment, agent, rl, openaigym, openai-gym, gym',
9 | packages=['babyai', 'babyai.levels', 'babyai.agents', 'babyai.utils'],
10 | install_requires=[
11 | 'gym>=0.9.6',
12 | 'numpy>=1.10.0'
13 | ],
14 | dependency_links=[
15 | ]
16 | )
17 |
--------------------------------------------------------------------------------