├── .gitignore ├── LICENSE.txt ├── README.md ├── config ├── geo_nmn.yml ├── geo_quant_nmn.yml ├── log.yml └── vqa_nmn.yml ├── extra ├── geo │ ├── fl.sps │ ├── ga.sps │ ├── mi.sps │ ├── nc.sps │ ├── ok.sps │ ├── pa.sps │ ├── sc.sps │ ├── tn.sps │ ├── va.sps │ └── wv.sps └── vqa │ ├── parse.old.py │ ├── parse.py │ ├── test-dev2015.sps2 │ ├── test2015.sps2 │ ├── train2014.sps2 │ └── val2014.sps2 ├── layers ├── __init__.py └── reinforce.py ├── main.py ├── misc ├── __init__.py ├── datum.py ├── indices.py ├── parse.py ├── util.py └── visualizer.py ├── models ├── __init__.py ├── att.py └── nmn.py ├── opt ├── __init__.py └── adadelta.py ├── run.sh └── tasks ├── __init__.py ├── geo.py └── vqa.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.lprof 4 | *.log 5 | logs 6 | vis 7 | data 8 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 |
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 |
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 |
176 |    END OF TERMS AND CONDITIONS
177 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural module networks
2 |
3 | **UPDATE 22 Jun 2017: Code for our end-to-end module network framework is
4 | available at https://github.com/ronghanghu/n2nmn. The n2nmn code works better
5 | and is easier to set up. Use it!**
6 |
7 | This library provides code for training and evaluating _neural module networks_
8 | (NMNs). An NMN is a neural network that is assembled dynamically by composing
9 | shallow network fragments called _modules_ into a deeper structure. These
10 | modules are jointly trained to be freely composable. For a general overview of
11 | the framework, refer to:
12 |
13 | > [Neural module networks](http://arxiv.org/abs/1511.02799).
14 | > Jacob Andreas, Marcus Rohrbach, Trevor Darrell and Dan Klein.
15 | > CVPR 2016.
16 |
17 |
18 | > [Learning to compose neural networks for question
19 | > answering](http://arxiv.org/abs/1601.01705).
20 | > Jacob Andreas, Marcus Rohrbach, Trevor Darrell and Dan Klein.
21 | > NAACL 2016.
22 |
23 | At present the code supports predicting network layouts from natural-language
24 | strings, with end-to-end training of modules. Various extensions should be
25 | straightforward to implement: alternative layout predictors, supervised
26 | training of specific modules, etc.
27 |
28 | Please cite the CVPR paper for the general NMN framework, and the NAACL paper
29 | for dynamic structure selection. Feel free to email me at
30 | [jda@cs.berkeley.edu](mailto:jda@cs.berkeley.edu) if you have questions. This
31 | code is released under the Apache 2 license, provided in `LICENSE.txt`.
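As a concrete (if toy) illustration of this composition, here is a schematic sketch in plain NumPy. This is not the actual implementation (the real modules are ApolloCaffe networks defined in `models/nmn.py`), and every module name, shape, and threshold below is invented for exposition:

    import numpy as np

    def find(features, concept_vec):
        # toy "find" module: attend to image regions that match a concept
        scores = np.einsum("chw,c->hw", features, concept_vec)
        weights = np.exp(scores - scores.max())
        return weights / weights.sum()

    def combine_and(att_a, att_b):
        # toy "and" module: intersect two attention maps
        return att_a * att_b

    def answer_exists(att, threshold=0.01):
        # toy "measure"-style module: reduce an attention map to yes/no
        return "yes" if att.max() > threshold else "no"

    features = np.random.rand(512, 14, 14)  # stand-in conv feature map
    modern = np.random.rand(512)            # stand-in concept embedding
    train = np.random.rand(512)
    # the layout (is (and modern train)) becomes a function composition:
    print(answer_exists(combine_and(find(features, modern),
                                    find(features, train))))

Because a module's parameters are shared across every layout it appears in, the fragments can be trained jointly and recombined freely at test time.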
32 |
33 | ## Installing dependencies
34 |
35 | You will need to build **my fork** of the excellent
36 | [ApolloCaffe](http://apollocaffe.com/) library. This fork may be found at
37 | [jacobandreas/apollocaffe](https://github.com/jacobandreas/apollocaffe), and
38 | provides support for a few Caffe layers that haven't made it into the main
39 | Apollo repository. Ordinary Caffe users: note that you will have to install the
40 | `runcython` Python module in addition to the usual Caffe dependencies.
41 |
42 | Once this is done, update `APOLLO_ROOT` at the top of `run.sh` to point to your
43 | ApolloCaffe installation.
44 |
45 | You will also need to install the following packages:
46 |
47 |     colorlogs, sexpdata
48 |
49 | ## Downloading data
50 |
51 | All experiment data should be placed in the `data` directory.
52 |
53 | #### VQA
54 |
55 | In `data`, create a subdirectory named `vqa`. Follow the [VQA setup
56 | instructions](https://github.com/VT-vision-lab/VQA/blob/master/README.md) to
57 | install the data into this directory. (It should have children `Annotations`,
58 | `Images`, etc.)
59 |
60 | We have modified the structure of the VQA `Images` directory slightly. `Images`
61 | should have two subdirectories, `raw` and `conv`. `raw` contains the original
62 | VQA images, while `conv` contains the result of preprocessing these images with
63 | a [16-layer VGGNet](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) as
64 | described in the paper. Every file in the `conv` directory should be of the form
65 | `COCO_{SETNAME}_{IMAGEID}.jpg.npz`, and contain a 512x14x14 image map in zipped
66 | numpy format. Here's a [gist](https://gist.github.com/jacobandreas/897987ac03f8d4b9ea4b9e44affa00e7)
67 | with the code I use for doing the extraction (a sketch of reading these files back appears in the appendix below).
68 |
69 | #### GeoQA
70 |
71 | Download the GeoQA dataset from the [LSP
72 | website](http://rtw.ml.cmu.edu/tacl2013_lsp/), and unpack it into `data/geo`.
73 |
74 | ## Parsing questions
75 |
76 | Every dataset fold should contain a file of parsed questions, one per line,
77 | formatted as S-expressions. If multiple parses are provided, they should be
78 | semicolon-delimited. As an example, for the question "is the train modern" we
79 | might have:
80 |
81 |     (is modern);(is train);(is (and modern train))
82 |
83 | For VQA, these files should be named `Questions/{train2014,val2014,...}.sps2`.
84 | For GeoQA, they should be named `environments/{fl,ga,...}/training.sps`. Parses
85 | used in our papers are provided in `extra` and should be installed in the
86 | appropriate location. The VQA parser script is also located under `extra/vqa`;
87 | instructions for running are provided in the body of the script (the appendix below shows how to split these parse lines).
88 |
89 | ## Running experiments
90 |
91 | You will first need to create directories `vis` and `logs` (which store
92 | visualization output and run logs, respectively).
93 |
94 | Different experiments can be run by providing an appropriate configuration file
95 | on the command line (see the last line of `run.sh`). Examples for VQA and GeoQA
96 | are provided in the `config` directory.
97 |
98 | Looking for SHAPES? I haven't finished integrating it with the rest of the
99 | codebase, but check out the `shapes` branch of this repository for data and
100 | code.
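## Appendix: format sketches

Two small illustrative snippets (sketches only, not part of the training code).
First, reading one of the preprocessed `conv` feature files described above,
assuming the 512x14x14 zipped-numpy format (the image ID and the array key
below are hypothetical; the key depends on how the extraction script saved the
file):

    import numpy as np

    # hypothetical image ID; real files live under data/vqa/Images/conv/
    path = "data/vqa/Images/conv/COCO_train2014_000000000009.jpg.npz"
    archive = np.load(path)
    feats = archive[archive.files[0]]  # assume one array per archive
    assert feats.shape == (512, 14, 14)
    archive.close()

Second, splitting a semicolon-delimited parse line into S-expressions using the
`sexpdata` package listed above:

    import sexpdata

    line = "(is modern);(is train);(is (and modern train))"
    parses = [sexpdata.loads(p) for p in line.split(";")]
    # parses[2] == [Symbol('is'), [Symbol('and'), Symbol('modern'), Symbol('train')]]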
101 |
102 | ## TODO
103 |
104 | - Configurable data location
105 | - Model checkpointing
106 |
--------------------------------------------------------------------------------
/config/geo_nmn.yml:
--------------------------------------------------------------------------------
1 | task:
2 |   name: geo
3 |   quant: false
4 |   fold: 1
5 |   k_best_parses: 8
6 |
7 | model:
8 |   name: nmn
9 |   lstm_hidden: 256
10 |   layout_hidden: 100
11 |
12 |   combine_question: false
13 |   att_normalization: local
14 |
15 | opt:
16 |   batch_size: 10
17 |   iters: 10000
18 |
19 |   rho: 0.95
20 |   eps: 0.000001
21 |   lr: 2
22 |   clip: 10.0
23 |
24 |   dropout: false
25 |   multiclass: true
26 |
27 |   log_preds: false
28 |
--------------------------------------------------------------------------------
/config/geo_quant_nmn.yml:
--------------------------------------------------------------------------------
1 | task:
2 |   name: geo
3 |   quant: true
4 |   fold: 1
5 |   k_best_parses: 2
6 |
7 | model:
8 |   name: nmn
9 |   lstm_hidden: 256
10 |   layout_hidden: 100
11 |
12 |   combine_question: false
13 |   att_normalization: local
14 |
15 | opt:
16 |   batch_size: 10
17 |   iters: 10000
18 |
19 |   rho: 0.95
20 |   eps: 0.000001
21 |   lr: 2
22 |   clip: 10.0
23 |
24 |   dropout: false
25 |   multiclass: true
26 |
27 |   log_preds: false
28 |
--------------------------------------------------------------------------------
/config/log.yml:
--------------------------------------------------------------------------------
1 | version: 1.0
2 |
3 | root:
4 |   handlers: [consoleHandler, fileHandler]
5 |   level: DEBUG
6 |
7 | handlers:
8 |   consoleHandler:
9 |     class: logging.StreamHandler
10 |     formatter: colorFormatter
11 |     stream: ext://sys.stdout
12 |
13 |   fileHandler:
14 |     class: logging.FileHandler
15 |     formatter: simpleFormatter
16 |
17 | formatters:
18 |   colorFormatter:
19 |     '()': 'colorlog.ColoredFormatter'
20 |     format: "%(asctime)s %(log_color)s%(levelname)-8s%(reset)s %(fg_cyan)s[%(name)s]%(reset)s %(message)s"
21 |
22 |   simpleFormatter:
23 |     format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
24 |
--------------------------------------------------------------------------------
/config/vqa_nmn.yml:
--------------------------------------------------------------------------------
1 | task:
2 |   name: vqa
3 |   #debug: 1013
4 |   chooser: cvpr # [null, cvpr, naacl]
5 |   answers: 2000
6 |
7 | model:
8 |   name: nmn
9 |
10 |   lstm_hidden: 1000
11 |   layout_hidden: 100
12 |   att_hidden: 100
13 |   pred_hidden: 500
14 |
15 |   combine_question: true
16 |   att_normalization: local
17 |
18 | opt:
19 |   batch_size: 100
20 |   iters: 30
21 |   dropout: true
22 |   multiclass: false
23 |
24 |   rho: 0.95
25 |   eps: 0.000001
26 |   lr: 1
27 |   clip: 10.0
28 |
--------------------------------------------------------------------------------
/extra/geo/fl.sps:
--------------------------------------------------------------------------------
1 | tallahassee;(exists tallahassee);(and tallahassee city);(exists (and tallahassee city))
2 | miami;(exists miami);(and miami city);(exists (and miami city))
3 | florida;(exists florida);(and florida state);(exists (and florida state))
4 | miami;(exists miami);(and miami state);(exists (and miami state))
5 | tallahassee;(exists tallahassee);(and tallahassee state);(exists (and tallahassee state))
6 | key_largo;(exists key_largo);(and key_largo island);(exists (and key_largo island))
7 | miami;(exists miami);(and miami island);(exists (and miami island))
8 | tallahassee;(exists tallahassee);(and tallahassee park);(exists (and tallahassee park))
9 |
tallahassee;(exists tallahassee);(and tallahassee city);(exists (and tallahassee city));(and tallahassee city (in florida));(exists (and tallahassee city (in florida))) 10 | miami;(exists miami);(and miami city);(exists (and miami city));(and miami city (in florida));(exists (and miami city (in florida))) 11 | island;(exists island);(and island (in florida));(exists (and island (in florida))) 12 | park;(exists park);(and park (in florida));(exists (and park (in florida))) 13 | park;(exists park);(and park (to_south_west_of miami));(exists (and park (to_south_west_of miami))) 14 | park;(exists park);(and park (to_north_west_of key_largo));(exists (and park (to_north_west_of key_largo))) 15 | city;(exists city);(and city (to_north_of key_largo));(exists (and city (to_north_of key_largo))) 16 | city;(exists city);(and city (to_south_east_of tallahassee));(exists (and city (to_south_east_of tallahassee))) 17 | city;(exists city);(and city (to_north_west_of daytona_beach));(exists (and city (to_north_west_of daytona_beach))) 18 | park;(exists park);(and park (to_south_east_of tallahassee));(exists (and park (to_south_east_of tallahassee)));(and park (to_south_east_of tallahassee) (north_west_of key_largo));(exists (and park (to_south_east_of tallahassee) (north_west_of key_largo))) 19 | city;(exists city);(and city (to_north_of key_largo));(exists (and city (to_north_of key_largo)));(and city (to_north_of key_largo) (east_of tallahassee));(exists (and city (to_north_of key_largo) (east_of tallahassee))) 20 | city;(exists city);(and city (in_between daytona_beach_key_largo));(exists (and city (in_between daytona_beach_key_largo))) 21 | park;(exists park);(and park (in florida));(exists (and park (in florida))) 22 | beach;(exists beach);(and beach city);(exists (and beach city));(and beach city (in florida));(exists (and beach city (in florida))) 23 | city;(exists city);(and city (in florida));(exists (and city (in florida))) 24 | capital;(exists capital);(and capital (of florida));(exists (and capital (of florida))) 25 | florida;(exists florida);(and florida city);(exists (and florida city));(and florida city (on peninsula));(exists (and florida city (on peninsula))) 26 | florida;(exists florida);(and florida city);(exists (and florida city));(and florida city (on peninsula));(exists (and florida city (on peninsula))) 27 | state;(exists state);(and state tallahassee);(exists (and state tallahassee));(and state tallahassee capital);(exists (and state tallahassee capital)) 28 | state;(exists state);(and state (in__inv miami));(exists (and state (in__inv miami))) 29 | park;(exists park);(and park (east_of miami));(exists (and park (east_of miami))) 30 | beach;(exists beach);(and beach (in florida));(exists (and beach (in florida))) 31 | beach;(exists beach);(and beach (on_east_of florida));(exists (and beach (on_east_of florida)));(and beach (on_east_of florida) (north_of miami));(exists (and beach (on_east_of florida) (north_of miami))) 32 | island;(exists island);(and island (south_of everglades_national_park));(exists (and island (south_of everglades_national_park))) 33 | major;(exists major);(and major city);(exists (and major city));(and major city (in florida));(exists (and major city (in florida))) 34 | beach;(exists beach);(and beach (in florida));(exists (and beach (in florida))) 35 | (south major);(exists (south major));(and (south major) city);(exists (and (south major) city));(and (south major) city (in florida));(exists (and (south major) city (in florida))) 36 | city;(exists 
city);(and city (between daytona_beach_key_largo));(exists (and city (between daytona_beach_key_largo))) 37 | state;(exists state);(and state (in__inv key_largo));(exists (and state (in__inv key_largo))) 38 | florida;(exists florida);(and florida park);(exists (and florida park));(and florida park (of peninsula));(exists (and florida park (of peninsula))) 39 | state;(exists state);(and state (in__inv daytona_beach));(exists (and state (in__inv daytona_beach))) 40 | key_largo;(exists key_largo);(and key_largo island);(exists (and key_largo island)) 41 | -------------------------------------------------------------------------------- /extra/geo/ga.sps: -------------------------------------------------------------------------------- 1 | state;(exists state) 2 | state;(exists state);(and state (west_of georgia));(exists (and state (west_of georgia))) 3 | state;(exists state);(and state (east_of alabama));(exists (and state (east_of alabama))) 4 | state;(exists state);(and state (border alabama));(exists (and state (border alabama))) 5 | state;(exists state);(and state (border georgia));(exists (and state (border georgia))) 6 | city;(exists city) 7 | city;(exists city);(and city (in alabama));(exists (and city (in alabama))) 8 | city;(exists city);(and city (in georgia));(exists (and city (in georgia))) 9 | state;(exists state);(and state (in__inv atlanta));(exists (and state (in__inv atlanta))) 10 | state;(exists state);(and state (in__inv macon));(exists (and state (in__inv macon))) 11 | state;(exists state);(and state (in__inv montgomery));(exists (and state (in__inv montgomery))) 12 | state;(exists state);(and state (in__inv birmingham));(exists (and state (in__inv birmingham))) 13 | city;(exists city);(and city (south_of birmingham));(exists (and city (south_of birmingham))) 14 | city;(exists city);(and city (south_of atlanta));(exists (and city (south_of atlanta))) 15 | city;(exists city);(and city (north_of macon));(exists (and city (north_of macon))) 16 | city;(exists city);(and city (north_of montgomery));(exists (and city (north_of montgomery))) 17 | southernmost;(exists southernmost);(and southernmost city);(exists (and southernmost city));(and southernmost city (in alabama));(exists (and southernmost city (in alabama))) 18 | southernmost;(exists southernmost);(and southernmost city);(exists (and southernmost city));(and southernmost city (in georgia));(exists (and southernmost city (in georgia))) 19 | city;(exists city);(and city (from birmingham));(exists (and city (from birmingham))) 20 | city;(exists city);(and city (from macon));(exists (and city (from macon))) 21 | city;(exists city);(and city (to birmingham));(exists (and city (to birmingham))) 22 | city;(exists city);(and city (to atlanta));(exists (and city (to atlanta))) 23 | city;(exists city);(and city (in state));(exists (and city (in state)));(and city (in state) (border alabama));(exists (and city (in state) (border alabama))) 24 | city;(exists city);(and city (in state));(exists (and city (in state)));(and city (in state) (border georgia));(exists (and city (in state) (border georgia))) 25 | city;(exists city);(and city (north_of montgomery));(exists (and city (north_of montgomery)));(and city (north_of montgomery) (in state));(exists (and city (north_of montgomery) (in state)));(and city (north_of montgomery) (in state) (border georgia));(exists (and city (north_of montgomery) (in state) (border georgia))) 26 | -------------------------------------------------------------------------------- /extra/geo/mi.sps: 
-------------------------------------------------------------------------------- 1 | (of lake);(exists (of lake)) 2 | city;(exists city) 3 | state;(exists state) 4 | city;(exists city);(and city (in michigan));(exists (and city (in michigan))) 5 | lake;(exists lake);(and lake (near michigan));(exists (and lake (near michigan))) 6 | lake;(exists lake);(and lake (near wisconsin));(exists (and lake (near wisconsin))) 7 | lake;(exists lake);(and lake (west_of michigan));(exists (and lake (west_of michigan))) 8 | state;(exists state);(and state (border_west_of michigan));(exists (and state (border_west_of michigan))) 9 | state;(exists state);(and state (east_of wisconsin));(exists (and state (east_of wisconsin))) 10 | city;(exists city);(and city (on lake_michigan));(exists (and city (on lake_michigan))) 11 | largest;(exists largest);(and largest city);(exists (and largest city));(and largest city (in michigan));(exists (and largest city (in michigan))) 12 | largest;(exists largest);(and largest city);(exists (and largest city));(and largest city (in michigan));(exists (and largest city (in michigan))) 13 | largest;(exists largest);(and largest city);(exists (and largest city));(and largest city (in wisconsin));(exists (and largest city (in wisconsin))) 14 | lake;(exists lake);(and lake (border michigan));(exists (and lake (border michigan))) 15 | city;(exists city);(and city (east_of grand_rapids));(exists (and city (east_of grand_rapids))) 16 | city;(exists city);(and city (west_of detroit));(exists (and city (west_of detroit))) 17 | city;(exists city);(and city (west_of grand_rapids));(exists (and city (west_of grand_rapids))) 18 | lake;(exists lake);(and lake (border wisconsin));(exists (and lake (border wisconsin))) 19 | lake;(exists lake);(and lake (east_of lake_michigan));(exists (and lake (east_of lake_michigan))) 20 | state;(exists state);(and state (on lake_michigan));(exists (and state (on lake_michigan))) 21 | (border michigan);(exists (border michigan)) 22 | -------------------------------------------------------------------------------- /extra/geo/nc.sps: -------------------------------------------------------------------------------- 1 | city;(exists city);(and city (to ocean));(exists (and city (to ocean))) 2 | capital;(exists capital);(and capital (of north_carolina));(exists (and capital (of north_carolina))) 3 | (north__inv state);(exists (north__inv state)) 4 | state;(exists state);(and state larger);(exists (and state larger)) 5 | city;(exists city) 6 | forest;(exists forest);(and forest (in north_carolina));(exists (and forest (in north_carolina))) 7 | northernmost;(exists northernmost);(and northernmost city);(exists (and northernmost city));(and northernmost city (in north_carolina));(exists (and northernmost city (in north_carolina))) 8 | city;(exists city);(and city (in north_carolina));(exists (and city (in north_carolina))) 9 | city;(exists city);(and city (east_of greensboro));(exists (and city (east_of greensboro)));(and city (east_of greensboro) (in north_carolina));(exists (and city (east_of greensboro) (in north_carolina))) 10 | city;(exists city);(and city (north_east_of charlotte));(exists (and city (north_east_of charlotte)));(and city (north_east_of charlotte) (in north_carolina));(exists (and city (north_east_of charlotte) (in north_carolina))) 11 | state;(exists state);(and state (south_of north_carolina));(exists (and state (south_of north_carolina))) 12 | city;(exists city);(and city (west_of raleigh));(exists (and city (west_of raleigh))) 13 | 
myrtle_beach;(exists myrtle_beach);(and myrtle_beach (near north_carolina));(exists (and myrtle_beach (near north_carolina))) 14 | city;(exists city);(and city (on ocean));(exists (and city (on ocean))) 15 | north_carolina;(exists north_carolina);(and north_carolina city);(exists (and north_carolina city));(and north_carolina city (near south_carolina));(exists (and north_carolina city (near south_carolina))) 16 | forest;(exists forest);(and forest (to charlotte));(exists (and forest (to charlotte))) 17 | (south_of greensboro);(exists (south_of greensboro)) 18 | (north_of uwharrie_national_forest);(exists (north_of uwharrie_national_forest)) 19 | (north_east_of uwharrie_national_forest);(exists (north_east_of uwharrie_national_forest)) 20 | -------------------------------------------------------------------------------- /extra/geo/ok.sps: -------------------------------------------------------------------------------- 1 | state;(exists state);(and state (border texas));(exists (and state (border texas))) 2 | state;(exists state);(and state (border louisiana));(exists (and state (border louisiana))) 3 | state;(exists state);(and state (border arkansas));(exists (and state (border arkansas))) 4 | state;(exists state);(and state (east_of oklahoma));(exists (and state (east_of oklahoma))) 5 | state;(exists state);(and state (east_of louisiana));(exists (and state (east_of louisiana))) 6 | state;(exists state);(and state (west_of louisiana));(exists (and state (west_of louisiana))) 7 | state;(exists state);(and state (west_of arkansas));(exists (and state (west_of arkansas))) 8 | state;(exists state);(and state (south_of arkansas));(exists (and state (south_of arkansas))) 9 | state;(exists state);(and state (south_of oklahoma));(exists (and state (south_of oklahoma))) 10 | state;(exists state);(and state (north_of louisiana));(exists (and state (north_of louisiana))) 11 | state;(exists state);(and state (north_of texas));(exists (and state (north_of texas))) 12 | state;(exists state);(and state (north_west_of louisiana));(exists (and state (north_west_of louisiana))) 13 | state;(exists state);(and state (north_east_of texas));(exists (and state (north_east_of texas))) 14 | state;(exists state);(and state (south_east_of oklahoma));(exists (and state (south_east_of oklahoma))) 15 | state;(exists state);(and state (south_west_of arkansas));(exists (and state (south_west_of arkansas))) 16 | state;(exists state);(and state (between texas_mississippi));(exists (and state (between texas_mississippi))) 17 | state;(exists state);(and state (border texas_arkansas));(exists (and state (border texas_arkansas))) 18 | state;(exists state);(and state (border louisiana_arkansas));(exists (and state (border louisiana_arkansas))) 19 | -------------------------------------------------------------------------------- /extra/geo/pa.sps: -------------------------------------------------------------------------------- 1 | city;(exists city) 2 | state;(exists state) 3 | city;(exists city);(and city (of pittsburgh));(exists (and city (of pittsburgh))) 4 | city;(exists city);(and city harrisburg);(exists (and city harrisburg)) 5 | city;(exists city);(and city (of pittsburgh));(exists (and city (of pittsburgh))) 6 | city;(exists city);(and city (of harrisburg));(exists (and city (of harrisburg))) 7 | city;(exists city);(and city (in new_jersey));(exists (and city (in new_jersey))) 8 | (of_in pennsylvania);(exists (of_in pennsylvania)) 9 | city;(exists city);(and city (in pennsylvania));(exists (and city (in 
pennsylvania))) 10 | pittsburgh;(exists pittsburgh);(and pittsburgh (west_of harrisburg));(exists (and pittsburgh (west_of harrisburg))) 11 | city;(exists city);(and city (west_of harrisburg));(exists (and city (west_of harrisburg))) 12 | city;(exists city);(and city (west_of newark));(exists (and city (west_of newark))) 13 | city;(exists city);(and city (west_of trenton));(exists (and city (west_of trenton))) 14 | newark;(exists newark);(and newark (east_of harrisburg));(exists (and newark (east_of harrisburg))) 15 | newark;(exists newark);(and newark (east_of pittsburgh));(exists (and newark (east_of pittsburgh))) 16 | new_jersey;(exists new_jersey);(and new_jersey (east_of pennsylvania));(exists (and new_jersey (east_of pennsylvania))) 17 | state;(exists state);(and state (east_of pennsylvania));(exists (and state (east_of pennsylvania))) 18 | capital;(exists capital);(and capital (of pennsylvania));(exists (and capital (of pennsylvania))) 19 | capital;(exists capital);(and capital (of new_jersey));(exists (and capital (of new_jersey))) 20 | city;(exists city);(and city (in pennsylvania));(exists (and city (in pennsylvania)));(and city (in pennsylvania) (west_of harrisburg));(exists (and city (in pennsylvania) (west_of harrisburg))) 21 | city;(exists city);(and city (in pennsylvania));(exists (and city (in pennsylvania)));(and city (in pennsylvania) (east_of pittsburgh));(exists (and city (in pennsylvania) (east_of pittsburgh))) 22 | pittsburgh;(exists pittsburgh);(and pittsburgh city);(exists (and pittsburgh city)) 23 | newark;(exists newark);(and newark city);(exists (and newark city)) 24 | pennsylvania;(exists pennsylvania);(and pennsylvania city);(exists (and pennsylvania city)) 25 | pennsylvania;(exists pennsylvania);(and pennsylvania state);(exists (and pennsylvania state)) 26 | pittsburgh;(exists pittsburgh);(and pittsburgh state);(exists (and pittsburgh state)) 27 | newark;(exists newark);(and newark state);(exists (and newark state)) 28 | newark;(exists newark);(and newark city);(exists (and newark city)) 29 | pittsburgh;(exists pittsburgh);(and pittsburgh city);(exists (and pittsburgh city)) 30 | harrisburg;(exists harrisburg);(and harrisburg city);(exists (and harrisburg city)) 31 | newark;(exists newark);(and newark (east_of pittsburgh));(exists (and newark (east_of pittsburgh))) 32 | newark;(exists newark);(and newark (east_of pittsburgh));(exists (and newark (east_of pittsburgh))) 33 | pittsburgh;(exists pittsburgh);(and pittsburgh (east_of newark));(exists (and pittsburgh (east_of newark))) 34 | new_jersey;(exists new_jersey);(and new_jersey state);(exists (and new_jersey state)) 35 | new_jersey;(exists new_jersey);(and new_jersey (east_of pittsburgh));(exists (and new_jersey (east_of pittsburgh))) 36 | new_jersey;(exists new_jersey);(and new_jersey (east_of pittsburgh));(exists (and new_jersey (east_of pittsburgh))) 37 | pittsburgh;(exists pittsburgh);(and pittsburgh (east_of new_jersey));(exists (and pittsburgh (east_of new_jersey))) 38 | pittsburgh;(exists pittsburgh);(and pittsburgh (east_of new_jersey));(exists (and pittsburgh (east_of new_jersey))) 39 | pittsburgh;(exists pittsburgh);(and pittsburgh (west_of new_jersey));(exists (and pittsburgh (west_of new_jersey))) 40 | pittsburgh;(exists pittsburgh);(and pittsburgh (west_of new_jersey));(exists (and pittsburgh (west_of new_jersey))) 41 | pittsburgh;(exists pittsburgh);(and pittsburgh (west_of harrisburg));(exists (and pittsburgh (west_of harrisburg))) 42 | harrisburg;(exists harrisburg);(and harrisburg (east_of 
pittsburgh));(exists (and harrisburg (east_of pittsburgh))) 43 | harrisburg;(exists harrisburg);(and harrisburg (west_of pittsburgh));(exists (and harrisburg (west_of pittsburgh))) 44 | -------------------------------------------------------------------------------- /extra/geo/sc.sps: -------------------------------------------------------------------------------- 1 | city;(exists city) 2 | lake;(exists lake) 3 | beach;(exists beach) 4 | (of water);(exists (of water));(and (of water) (in south_carolina));(exists (and (of water) (in south_carolina))) 5 | ocean;(exists ocean) 6 | forest;(exists forest) 7 | park;(exists park) 8 | ocean;(exists ocean);(and ocean (border south_carolina));(exists (and ocean (border south_carolina))) 9 | (in south_carolina);(exists (in south_carolina)) 10 | (surrounded_by water);(exists (surrounded_by water)) 11 | (of water);(exists (of water)) 12 | (of water);(exists (of water)) 13 | (on atlantic_ocean);(exists (on atlantic_ocean)) 14 | state;(exists state);(and state (north_of south_carolina));(exists (and state (north_of south_carolina))) 15 | state;(exists state);(and state (border south_carolina));(exists (and state (border south_carolina))) 16 | state;(exists state);(and state (border north_carolina));(exists (and state (border north_carolina))) 17 | city;(exists city);(and city (of charleston));(exists (and city (of charleston))) 18 | city;(exists city);(and city (of charleston));(exists (and city (of charleston))) 19 | major;(exists major);(and major city);(exists (and major city));(and major city (of greenville));(exists (and major city (of greenville))) 20 | major;(exists major);(and major city);(exists (and major city));(and major city (of myrtle_beach));(exists (and major city (of myrtle_beach))) 21 | island;(exists island);(and island (of charleston));(exists (and island (of charleston))) 22 | state;(exists state);(and state (on ocean));(exists (and state (on ocean))) 23 | city;(exists city);(and city (north_of myrtle_beach));(exists (and city (north_of myrtle_beach))) 24 | capital;(exists capital);(and capital (of north_carolina));(exists (and capital (of north_carolina))) 25 | island;(exists island);(and island (near charleston));(exists (and island (near charleston))) 26 | beach;(exists beach);(and beach city);(exists (and beach city));(and beach city (of charleston));(exists (and beach city (of charleston)));(and beach city (of charleston) (near north_carolina));(exists (and beach city (of charleston) (near north_carolina))) 27 | (near charleston);(exists (near charleston)) 28 | lake;(exists lake);(and lake (to francis_marion_national_forest));(exists (and lake (to francis_marion_national_forest))) 29 | hilton_head_island;(exists hilton_head_island);(and hilton_head_island (in south_carolina));(exists (and hilton_head_island (in south_carolina))) 30 | -------------------------------------------------------------------------------- /extra/geo/tn.sps: -------------------------------------------------------------------------------- 1 | city;(exists city);(and city (in tennessee));(exists (and city (in tennessee))) 2 | state;(exists state) 3 | park;(exists park) 4 | (in state);(exists (in state));(and (in state) knoxville);(exists (and (in state) knoxville)) 5 | city;(exists city);(and city (in tennessee));(exists (and city (in tennessee))) 6 | city;(exists city);(and city (in tennessee));(exists (and city (in tennessee))) 7 | city;(exists city);(and city (in tennessee));(exists (and city (in tennessee))) 8 | city;(exists city);(and city (west_of 
nashville));(exists (and city (west_of nashville))) 9 | park;(exists park);(and park (near knoxville));(exists (and park (near knoxville))) 10 | state;(exists state);(and state (south_of tennessee));(exists (and state (south_of tennessee))) 11 | city;(exists city);(and city (in tennessee));(exists (and city (in tennessee)));(and city (in tennessee) (east_of nashville));(exists (and city (in tennessee) (east_of nashville))) 12 | city;(exists city);(and city (in tennessee));(exists (and city (in tennessee)));(and city (in tennessee) (west_of nashville));(exists (and city (in tennessee) (west_of nashville))) 13 | park;(exists park);(and park (in tennessee));(exists (and park (in tennessee))) 14 | major;(exists major);(and major city);(exists (and major city));(and major city (in tennessee));(exists (and major city (in tennessee)));(and major city (in tennessee) (to alabama));(exists (and major city (in tennessee) (to alabama))) 15 | tennessee;(exists tennessee);(and tennessee bigger);(exists (and tennessee bigger));(and tennessee bigger alabama);(exists (and tennessee bigger alabama)) 16 | alabama;(exists alabama);(and alabama (border tennessee));(exists (and alabama (border tennessee))) 17 | city;(exists city);(and city (in tennessee));(exists (and city (in tennessee)));(and city (in tennessee) (west_of knoxville));(exists (and city (in tennessee) (west_of knoxville))) 18 | city;(exists city);(and city (in tennessee));(exists (and city (in tennessee)));(and city (in tennessee) (east_of memphis));(exists (and city (in tennessee) (east_of memphis))) 19 | major;(exists major);(and major city);(exists (and major city));(and major city park);(exists (and major city park));(and major city park (east_of nashville));(exists (and major city park (east_of nashville))) 20 | city;(exists city);(and city (in tennessee));(exists (and city (in tennessee)));(and city (in tennessee) (between knoxville_memphis));(exists (and city (in tennessee) (between knoxville_memphis))) 21 | -------------------------------------------------------------------------------- /extra/geo/va.sps: -------------------------------------------------------------------------------- 1 | state;(exists state) 2 | capital;(exists capital);(and capital (of virginia));(exists (and capital (of virginia))) 3 | city;(exists city);(and city (in virginia));(exists (and city (in virginia))) 4 | city;(exists city);(and city (on ocean));(exists (and city (on ocean))) 5 | ocean;(exists ocean) 6 | capital;(exists capital);(and capital (of virginia));(exists (and capital (of virginia))) 7 | city;(exists city);(and city (to virginia_beach));(exists (and city (to virginia_beach))) 8 | west_virginia;(exists west_virginia);(and west_virginia (border atlantic_ocean));(exists (and west_virginia (border atlantic_ocean))) 9 | state;(exists state);(and state (south_of virginia));(exists (and state (south_of virginia))) 10 | state;(exists state);(and state (east_of west_virginia));(exists (and state (east_of west_virginia))) 11 | city;(exists city);(and city (in virginia));(exists (and city (in virginia))) 12 | city;(exists city);(and city (in west_virginia));(exists (and city (in west_virginia))) 13 | city;(exists city);(and city northernmost);(exists (and city northernmost));(and city northernmost (in virginia));(exists (and city northernmost (in virginia))) 14 | southernmost;(exists southernmost);(and southernmost city);(exists (and southernmost city));(and southernmost city (in west_virginia));(exists (and southernmost city (in west_virginia))) 15 | 
state;(exists state);(and state (border atlantic_ocean));(exists (and state (border atlantic_ocean))) 16 | city;(exists city);(and city along);(exists (and city along));(and city along atlantic_ocean);(exists (and city along atlantic_ocean)) 17 | city;(exists city);(and city (near richmond));(exists (and city (near richmond))) 18 | state;(exists state);(and state (west_of virginia));(exists (and state (west_of virginia))) 19 | state;(exists state);(and state (south_of west_virginia));(exists (and state (south_of west_virginia))) 20 | city;(exists city);(and city (near richmond));(exists (and city (near richmond))) 21 | city;(exists city);(and city (south_east_of richmond));(exists (and city (south_east_of richmond))) 22 | state;(exists state);(and state (on atlantic_ocean));(exists (and state (on atlantic_ocean))) 23 | state;(exists state);(and state (in_inv virginia_beach));(exists (and state (in_inv virginia_beach))) 24 | ocean;(exists ocean);(and ocean (border virginia));(exists (and ocean (border virginia))) 25 | ocean;(exists ocean);(and ocean (border west_virginia));(exists (and ocean (border west_virginia))) 26 | capital;(exists capital);(and capital (of virginia));(exists (and capital (of virginia))) 27 | state;(exists state);(and state (in__inv_east_of west_virginia));(exists (and state (in__inv_east_of west_virginia))) 28 | state;(exists state);(and state (in__inv_west_of virginia_virginia));(exists (and state (in__inv_west_of virginia_virginia))) 29 | richmond;(exists richmond);(and richmond capital);(exists (and richmond capital));(and richmond capital (of west_virginia));(exists (and richmond capital (of west_virginia))) 30 | virginia;(exists virginia);(and virginia (of west_virginia));(exists (and virginia (of west_virginia))) 31 | virginia_beach;(exists virginia_beach);(and virginia_beach (of west_virginia));(exists (and virginia_beach (of west_virginia))) 32 | state;(exists state);(and state (in__inv virginia));(exists (and state (in__inv virginia))) 33 | richmond;(exists richmond);(and richmond capital);(exists (and richmond capital));(and richmond capital (of virginia));(exists (and richmond capital (of virginia))) 34 | virginia_beach;(exists virginia_beach);(and virginia_beach (in atlantic_ocean));(exists (and virginia_beach (in atlantic_ocean))) 35 | atlantic_ocean;(exists atlantic_ocean);(and atlantic_ocean (in virginia));(exists (and atlantic_ocean (in virginia))) 36 | richmond;(exists richmond);(and richmond (in atlantic_ocean));(exists (and richmond (in atlantic_ocean))) 37 | atlantic_ocean;(exists atlantic_ocean);(and atlantic_ocean (in west_virginia));(exists (and atlantic_ocean (in west_virginia))) 38 | atlantic_ocean;(exists atlantic_ocean);(and atlantic_ocean (between virginia_west_virginia));(exists (and atlantic_ocean (between virginia_west_virginia))) 39 | -------------------------------------------------------------------------------- /extra/geo/wv.sps: -------------------------------------------------------------------------------- 1 | state;(exists state) 2 | park;(exists park) 3 | lake;(exists lake) 4 | capital;(exists capital);(and capital (of virginia));(exists (and capital (of virginia))) 5 | city;(exists city);(and city (in virginia));(exists (and city (in virginia))) 6 | state;(exists state);(and state (west_of virginia));(exists (and state (west_of virginia))) 7 | state;(exists state);(and state (in__inv monongahela_national_forest));(exists (and state (in__inv monongahela_national_forest))) 8 | park;(exists park);(and park (in 
west_virginia));(exists (and park (in west_virginia)))
9 | richmond;(exists richmond);(and richmond (east_of west_virginia));(exists (and richmond (east_of west_virginia)))
10 | virginia;(exists virginia);(and virginia (east_of west_virginia));(exists (and virginia (east_of west_virginia)))
11 |
--------------------------------------------------------------------------------
/extra/vqa/parse.old.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 |
3 | import re
4 | import sys
5 |
6 | def name_to_match(name, used_names):
7 |     if isinstance(name, int) and name not in used_names:
8 |         r_name = "(?P<g%s>[/\\w-]+)" % name
9 |         used_names.add(name)
10 |     elif isinstance(name, int):
11 |         r_name = "(?P=g%s)" % name
12 |     else:
13 |         r_name = name
14 |     return r_name
15 |
16 | def render(spec, match):
17 |     if isinstance(spec, str):
18 |         return spec
19 |     if isinstance(spec, int):
20 |         rel = match.group("rel")
21 |         token = match.group("g%d" % spec)
22 |         word = token.split("/")[0]
23 |         if "subj" in rel:
24 |             #return "subj__" + word
25 |             return word
26 |         elif "obj" in rel:
27 |             #return "obj__" + word
28 |             return word
29 |         else:
30 |             return word
31 |
32 |     return "(%s)" % " ".join([render(p, match) for p in spec])
33 |
34 | def match_simple_query(query, edges):
35 |     edge_spec, query_spec = query
36 |     used_names = set()
37 |     query_re = ".*"
38 |     for rel, head, tail in edge_spec:
39 |         r_head = name_to_match(head, used_names)
40 |         r_tail = name_to_match(tail, used_names)
41 |         query_re += "(?P<rel>%s)\(%s, %s\).*" % (rel, r_head, r_tail)
42 |
43 |     m = re.match(query_re, edges.lower())
44 |     if m is None:
45 |         return None
46 |
47 |     return render(query_spec, m)
48 |
49 | verb_query_1 = (
50 |     [
51 |         ("(nsubj|dobj|det|dep|nsubjpass|acl:relcl)", 0, r"what/w[a-z]+-\d+"),
52 |     ],
53 |     ("what", 0)
54 | )
55 |
56 | verb_query_2 = (
57 |     [
58 |         ("(nsubj|dobj|det|dep|nsubjpass|acl:relcl)", r"what/w[a-z]+-\d+", 0),
59 |     ],
60 |     ("what", 0)
61 | )
62 |
63 | SIMPLE_QUERIES = [
64 |     verb_query_1,
65 |     verb_query_2
66 | ]
67 | def make_simple_query(edges):
68 |     for query in SIMPLE_QUERIES:
69 |         m = match_simple_query(query, edges)
70 |         if m: break
71 |     return m
72 |
73 | #QUERY_RE = r"(.*)\((.*)/\w+-([\d\']+), (.*)/\w+-([\d\']+)\)"
74 | #FORBIDDEN_RELS = ["acl:relcl"]
75 | def make_what_query(query_lines):
76 |     joined_lines = "".join(query_lines)
77 |     q = make_simple_query(joined_lines)
78 |     return q
79 |
80 | CONTENT_RE = r"([^/]+)/(NN|VB|JJ)"
81 |
82 | if __name__ == "__main__":
83 |     queries = []
84 |     query_lines = []
85 |     got_question = False
86 |     question = None
87 |     for line in sys.stdin:
88 |         sline = line.strip()
89 |         if sline == "" and not got_question:
90 |             got_question = True
91 |             question = query_lines[0].lower()
92 |             query_lines = []
93 |         elif sline == "":
94 |             got_question = False
95 |
96 |             #print
97 |             #print question
98 |
99 |             #if "what is the color of the" in question:
100 |             #    obj = re.search(r", (.*)/", query_lines[6]).group(1)
101 |             #    print "(color %s)" % obj
102 |             m = re.match(r"^what (\w+) (is|are)", question)
103 |             if m is not None:
104 |                 wh = m.group(1)
105 |                 content = query_lines[3:]
106 |                 content = [w.split()[1] for w in content]
107 |                 content = [re.match(CONTENT_RE, w) for w in content]
108 |                 content = [m for m in content if m]
109 |                 if len(content) == 0:
110 |                     target = "object"
111 |                 else:
112 |                     target = content[0].group(1).lower()
113 |
114 |                 print "(%s %s)" % (wh, target)
115 |                 query_lines = []
116 |                 continue
117 |
118 |             m = re.match(r"^what (is|are) the (\w+) of", question)
119 |             if m is not None:
120 |                 wh = m.group(1)
121 |                 content = query_lines[5:]
122 |                 content = [w.split()[1] for w in content]
123 |                 content = [re.match(CONTENT_RE, w) for w in content]
124 |                 content = [m for m in content if m]
125 |                 if len(content) == 0:
126 |                     target = "object"
127 |                 else:
128 |                     target = content[0].group(1).lower()
129 |
130 |                 print "(%s %s)" % (wh, target)
131 |                 query_lines = []
132 |                 continue
133 |
134 |             m = re.match(r"^what (\w+) of", question)
135 |             if m is not None:
136 |                 wh = m.group(1)
137 |                 content = query_lines[3:]
138 |                 content = [w.split()[1] for w in content]
139 |                 content = [re.match(CONTENT_RE, w) for w in content]
140 |                 content = [m for m in content if m]
141 |                 if len(content) == 0:
142 |                     target = "object"
143 |                 else:
144 |                     target = content[0].group(1).lower()
145 |
146 |                 print "(%s %s)" % (wh, target)
147 |                 query_lines = []
148 |                 continue
149 |
150 |             m = re.match(r"^what", question)
151 |             if m is not None:
152 |                 content = query_lines[1:]
153 |                 content = [w.split()[1] for w in content]
154 |                 content = [re.match(CONTENT_RE, w) for w in content]
155 |                 content = [m for m in content if m]
156 |                 content = [m for m in content if m.group(1) not in ("be", "do", "have")]
157 |                 if len(content) == 0:
158 |                     target = "object"
159 |                 else:
160 |                     target = content[0].group(1).lower()
161 |
162 |                 print "(what %s)" % target
163 |                 query_lines = []
164 |                 continue
165 |
166 |             m = re.match(r"^(is|are|has|have|were|did|does|do|was) ", question)
167 |             if m is not None:
168 |                 wh = m.group(1)
169 |                 content = query_lines[1:]
170 |                 content = [w.split()[1] for w in content]
171 |                 content = [re.match(CONTENT_RE, w) for w in content]
172 |                 content = [m for m in content if m]
173 |                 if len(content) == 0:
174 |                     print "none"
175 |                 elif len(content) == 1:
176 |                     print "(is1 %s)" % content[0].group(1).lower()
177 |                 else:
178 |                     print "(is2 %s %s)" % (content[0].group(1).lower(), content[1].group(1).lower())
179 |
180 |                 query_lines = []
181 |                 continue
182 |
183 |             m = re.match(r"^how many", question)
184 |             if m is not None:
185 |                 content = query_lines[2:]
186 |                 content = [w.split()[1] for w in content]
187 |                 content = [re.match(CONTENT_RE, w) for w in content]
188 |                 content = [m for m in content if m]
189 |                 if len(content) == 0:
190 |                     target = "object"
191 |                 else:
192 |                     target = content[0].group(1).lower()
193 |
194 |                 print "(how_many %s)" % target
195 |                 query_lines = []
196 |                 continue
197 |
198 |             m = re.match(r"^where", question)
199 |             if m is not None:
200 |                 content = query_lines[2:]
201 |                 content = [w.split()[1] for w in content]
202 |                 content = [re.match(CONTENT_RE, w) for w in content]
203 |                 content = [m for m in content if m]
204 |                 if len(content) == 0:
205 |                     target = "object"
206 |                 else:
207 |                     target = content[0].group(1).lower()
208 |
209 |                 print "(where %s)" % target
210 |                 query_lines = []
211 |                 continue
212 |
213 |             m = re.match(r"^(can|could) ", question)
214 |             if m is not None:
215 |                 wh = m.group(1)
216 |                 content = query_lines[1:]
217 |                 content = [w.split()[1] for w in content]
218 |                 content = [re.match(CONTENT_RE, w) for w in content]
219 |                 content = [m for m in content if m]
220 |                 if len(content) == 0:
221 |                     print "(what object)"
222 |                 elif len(content) == 1:
223 |                     print "(can1 %s)" % content[0].group(1).lower()
224 |                 else:
225 |                     print "(can2 %s %s)" % (content[0].group(1).lower(), content[1].group(1).lower())
226 |
227 |                 query_lines = []
228 |                 continue
229 |
230 |             print "(what object)"
231 |             query_lines = []
232 |             continue
233 |
234 |
235 |             #####
236 |
237 |             if "how many" in question:
"how many" in question: 238 | mtch = None 239 | idx = 2 240 | while mtch is None and idx < len(query_lines): 241 | mtch = re.search(r", (.*)/N", query_lines[idx]) 242 | idx += 1 243 | if mtch is None: 244 | mtch = re.search(r", (.*)/", query_lines[2]) 245 | assert mtch is not None 246 | obj = mtch.group(1) 247 | print "(count %s)" % obj 248 | 249 | elif "where" in question: 250 | mtch = None 251 | idx = 2 252 | while mtch is None and idx < len(query_lines): 253 | mtch = re.search(r", (.*)/N", query_lines[idx]) 254 | idx += 1 255 | if mtch is None and len(query_lines) >= 3: 256 | mtch = re.search(r", (.*)/", query_lines[2]) 257 | if mtch is not None: 258 | obj = mtch.group(1) 259 | else: 260 | obj = "object" 261 | print "(where %s)" % obj 262 | 263 | else: 264 | query = make_what_query(query_lines) 265 | if query is None: 266 | #print "\n".join(query_lines) 267 | #print "warning: null query" 268 | query = "(_what _thing)" 269 | print query 270 | 271 | #query = convert_to_query(query_lines) 272 | #queries.append(query) 273 | 274 | #print question 275 | #print "\n".join(query_lines) 276 | #if query is None: 277 | # print "none" 278 | #else: 279 | # print query 280 | #print 281 | 282 | query_lines = [] 283 | else: 284 | query_lines.append(sline) 285 | 286 | #print len(queries) 287 | -------------------------------------------------------------------------------- /extra/vqa/parse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | from collections import namedtuple 4 | import itertools 5 | import re 6 | import sys 7 | 8 | Node = namedtuple("Node", ["word", "tag", "parent", "rel", "path"]) 9 | 10 | BE_FORMS="is|are|was|were|has|have|had|does|do|did|be" 11 | 12 | WH_RES = [ 13 | r"^what (\w+) (is|are)", 14 | r"^what (is|are) the (\w+) of", 15 | r"^what (\w+) of", 16 | r"^(what|which|where)", 17 | r"(%s)" % BE_FORMS, 18 | r"^(how many)", 19 | r"^(can|could)" 20 | ] 21 | 22 | EDGE_RE = re.compile(r"([^()]+)\((.+)-(\d+), (.+)-(\d+)\)") 23 | CONTENT_RE = re.compile(r"NN*|VB*|JJ*") 24 | #CONTENT_RE = re.compile(r"NN|VB|JJ") 25 | 26 | REL_PRECEDENCE = ["root", "nsubj", "dobj", "nsubjpass", "dep", "xcomp"] 27 | 28 | def precedence(rel): 29 | if "nmod:" in rel: 30 | return len(REL_PRECEDENCE) 31 | if "conj:" in rel: 32 | return len(REL_PRECEDENCE) + 1 33 | if "acl:" in rel: 34 | return len(REL_PRECEDENCE) + 1 35 | return REL_PRECEDENCE.index(rel) 36 | 37 | class LfParser(object): 38 | def __init__(self, use_relations, max_leaves, max_conjuncts): 39 | self.use_relations = use_relations 40 | self.max_leaves = max_leaves 41 | self.max_conjuncts = max_conjuncts 42 | 43 | def extract_nodes(self, content): 44 | nodes = {} 45 | for edge in content: 46 | rel, w1, i1, w2, i2 = EDGE_RE.match(edge.replace("'", "")).groups() 47 | i1 = int(i1) 48 | i2 = int(i2) 49 | w2, t2 = w2.rsplit("/", 1) 50 | node = Node(w2.lower(), t2, i1, rel, []) 51 | if i2 in nodes: 52 | if precedence(node.rel) < precedence(nodes[i2].rel): 53 | nodes[i2] = node 54 | else: 55 | nodes[i2] = node 56 | return nodes 57 | 58 | def annotate_paths(self, nodes): 59 | for i, node in nodes.items(): 60 | path = node.path 61 | at = node 62 | hit = {i} 63 | while at.parent in nodes: 64 | if "nmod:" in at.rel: 65 | path.append(at.rel.split(":")[1]) 66 | if at.parent in hit: 67 | break 68 | hit.add(at.parent) 69 | at = nodes[at.parent] 70 | 71 | def extract_predicates(self, nodes): 72 | preds = [] 73 | for i, node in sorted(nodes.items()): 74 | if not CONTENT_RE.match(node.tag): 75 | 
continue 76 | if re.match(BE_FORMS, node.word): 77 | continue 78 | pred = node.word 79 | if len(node.path) > 0 and self.use_relations: 80 | pred = "(%s %s)" % (node.path[0], pred) 81 | preds.append(pred) 82 | return list(set(preds)) 83 | 84 | def make_lfs(self, wh, content): 85 | nodes = self.extract_nodes(content) 86 | self.annotate_paths(nodes) 87 | predicates = self.extract_predicates(nodes) 88 | if self.max_leaves is not None: 89 | predicates = predicates[:self.max_leaves] 90 | out = [] 91 | for i in range(1, max(self.max_conjuncts + 1, len(predicates))): 92 | comb = itertools.combinations(predicates, i) 93 | for pred_comb in comb: 94 | if len(pred_comb) == 1: 95 | out.append("(%s %s)" % (wh, pred_comb[0])) 96 | else: 97 | out.append("(%s (and %s))" % (wh, " ".join(pred_comb))) 98 | return out 99 | 100 | def parse_all(self, stream): 101 | queries = [] 102 | query_lines = [] 103 | got_question = False 104 | question = None 105 | for line in stream: 106 | sline = line.strip() 107 | if sline == "" and not got_question: 108 | got_question = True 109 | question = query_lines[0].lower() 110 | query_lines = [] 111 | elif sline == "": 112 | got_question = False 113 | 114 | queries = None 115 | for expr in WH_RES: 116 | m = re.match(expr, question) 117 | if m is None: 118 | continue 119 | wh = m.group(1).replace(" ", "_") 120 | if re.match(BE_FORMS, wh): 121 | wh = "is" 122 | n_expr_words = len(expr.split(" ")) 123 | content = query_lines[n_expr_words:] 124 | queries = self.make_lfs(wh, content) 125 | success = True 126 | break 127 | 128 | if not queries: 129 | queries = ["(_what _thing)"] 130 | 131 | yield queries 132 | query_lines = [] 133 | 134 | else: 135 | query_lines.append(sline) 136 | 137 | """ 138 | This script consumes output from the Stanford parser on stdin. I run the parser as 139 | 140 | java -mx150m -cp "$scriptdir/*:" edu.stanford.nlp.parser.lexparser.LexicalizedParser \ 141 | -outputFormat "words,typedDependencies" -outputFormatOptions "stem,collapsedDependencies,includeTags" \ 142 | -sentences newline \ 143 | edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz \ 144 | $* 145 | """ 146 | 147 | if __name__ == "__main__": 148 | #parser = LfParser(use_relations=True, max_conjuncts=2, max_leaves=None) 149 | parser = LfParser(use_relations=False, max_conjuncts=2, max_leaves=2) 150 | for parses in parser.parse_all(sys.stdin): 151 | print ";".join(parses) 152 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobandreas/nmn2/7e42dd98420f9580fd34185ba670490a5d86fb04/layers/__init__.py -------------------------------------------------------------------------------- /layers/reinforce.py: -------------------------------------------------------------------------------- 1 | from apollocaffe.layers.layer_headers import Layer, PyLayer, LossLayer 2 | from IPython import embed 3 | 4 | import numpy as np 5 | 6 | class Index(PyLayer): 7 | def forward(self, bottom, top): 8 | data, indices = bottom 9 | index_data = indices.data.astype(int) 10 | top[0].reshape(indices.shape) 11 | top[0].data[...] = data.data[range(indices.shape[0]), index_data] 12 | 13 | def backward(self, top, bottom): 14 | data, indices = bottom 15 | index_data = indices.data.astype(int) 16 | data.diff[...] 
= 0
17 |         data.diff[range(indices.shape[0]), index_data] = top[0].diff
18 | 
19 | class AsLoss(PyLayer):
20 |     def __init__(self, name, **kwargs):
21 |         PyLayer.__init__(self, name, dict(), **kwargs)
22 | 
23 |     def forward(self, bottom, top):
24 |         top[0].reshape(bottom[0].shape)
25 |         top[0].data[...] = bottom[0].data
26 | 
27 |     def backward(self, top, bottom):
28 |         bottom[0].diff[...] = top[0].data
29 | 
30 | class Reduction(LossLayer):
31 |     def __init__(self, name, axis, **kwargs):
32 |         kwargs["axis"] = axis
33 |         super(Reduction, self).__init__(name, kwargs)
34 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | 
3 | # check profiler
4 | if not isinstance(__builtins__, dict) or "profile" not in __builtins__:
5 |     __builtins__.__dict__["profile"] = lambda x: x
6 | 
7 | from misc import util
8 | from misc.indices import QUESTION_INDEX, ANSWER_INDEX, MODULE_INDEX, MODULE_TYPE_INDEX, \
9 |         NULL, NULL_ID, UNK_ID
10 | from misc.visualizer import visualizer
11 | import models
12 | from models.nmn import MLPFindModule, MultiplicativeFindModule
13 | import tasks
14 | 
15 | import apollocaffe
16 | import argparse
17 | import json
18 | import logging.config
19 | import random
20 | import numpy as np
21 | import yaml
22 | 
23 | def main():
24 |     config = configure()
25 |     task = tasks.load_task(config)
26 |     model = models.build_model(config.model, config.opt)
27 | 
28 |     for i_epoch in range(config.opt.iters):
29 | 
30 |         train_loss, train_acc, _ = \
31 |                 do_iter(task.train, model, config, train=True)
32 |         val_loss, val_acc, val_predictions = \
33 |                 do_iter(task.val, model, config, vis=True)
34 |         test_loss, test_acc, test_predictions = \
35 |                 do_iter(task.test, model, config)
36 | 
37 |         logging.info(
38 |                 "%5d | %8.3f %8.3f %8.3f | %8.3f %8.3f %8.3f",
39 |                 i_epoch,
40 |                 train_loss, val_loss, test_loss,
41 |                 train_acc, val_acc, test_acc)
42 | 
43 |         with open("logs/val_predictions_%d.json" % i_epoch, "w") as pred_f:
44 |             print >>pred_f, json.dumps(val_predictions)
45 | 
46 |         #with open("logs/test_predictions_%d.json" % i_epoch, "w") as pred_f:
47 |         #    print >>pred_f, json.dumps(test_predictions)
48 | 
49 | def configure():
50 |     apollocaffe.set_random_seed(0)
51 |     np.random.seed(0)
52 |     random.seed(0)
53 | 
54 |     arg_parser = argparse.ArgumentParser()
55 |     arg_parser.add_argument(
56 |             "-c", "--config", dest="config", required=True,
57 |             help="model configuration file")
58 |     arg_parser.add_argument(
59 |             "-l", "--log-config", dest="log_config", default="config/log.yml",
60 |             help="log configuration file")
61 | 
62 |     args = arg_parser.parse_args()
63 |     config_name = args.config.split("/")[-1].split(".")[0]
64 | 
65 |     with open(args.log_config) as log_config_f:
66 |         log_filename = "logs/%s.log" % config_name
67 |         log_config = yaml.load(log_config_f)
68 |         log_config["handlers"]["fileHandler"]["filename"] = log_filename
69 |         logging.config.dictConfig(log_config)
70 | 
71 |     with open(args.config) as config_f:
72 |         config = util.Struct(**yaml.load(config_f))
73 | 
74 |     assert not hasattr(config, "name")
75 |     config.name = config_name
76 | 
77 |     return config
78 | 
79 | def do_iter(task_set, model, config, train=False, vis=False):
80 |     loss = 0.0
81 |     acc = 0.0
82 |     predictions = []
83 |     n_batches = 0
84 | 
85 |     # sort first to guarantee deterministic behavior with fixed seed
86 |     data = list(sorted(task_set.data))
87 |     np.random.shuffle(data)
88 | 
89 |     if vis:
90 |
visualizer.begin(config.name, 100) 91 | 92 | for batch_start in range(0, len(data), config.opt.batch_size): 93 | batch_end = batch_start + config.opt.batch_size 94 | batch_data = data[batch_start:batch_end] 95 | 96 | batch_loss, batch_acc, batch_preds = do_batch( 97 | batch_data, model, config, train, vis) 98 | 99 | loss += batch_loss 100 | acc += batch_acc 101 | predictions += batch_preds 102 | n_batches += 1 103 | 104 | if vis: 105 | visualize(batch_data, model) 106 | 107 | if vis: 108 | visualizer.end() 109 | 110 | if n_batches == 0: 111 | return 0, 0, dict() 112 | assert len(predictions) == len(data) 113 | loss /= n_batches 114 | acc /= n_batches 115 | return loss, acc, predictions 116 | 117 | def do_batch(data, model, config, train, vis): 118 | predictions = forward(data, model, config, train, vis) 119 | answer_loss = backward(data, model, config, train, vis) 120 | acc = compute_acc(predictions, data, config) 121 | 122 | return answer_loss, acc, predictions 123 | 124 | # TODO this is ugly and belongs somewhere else 125 | def featurize_layouts(datum, max_layouts): 126 | # TODO pre-fill module type index 127 | layout_reprs = np.zeros((max_layouts, len(MODULE_INDEX) + 7)) 128 | for i_layout in range(len(datum.layouts)): 129 | layout = datum.layouts[i_layout] 130 | labels = util.flatten(layout.labels) 131 | modules = util.flatten(layout.modules) 132 | for i_mod in range(len(modules)): 133 | if isinstance(modules[i_mod], MLPFindModule) or isinstance(modules[i_mod], MultiplicativeFindModule): 134 | layout_reprs[i_layout, labels[i_mod]] += 1 135 | mt = MODULE_TYPE_INDEX.index(modules[i_mod]) 136 | layout_reprs[i_layout, len(MODULE_INDEX) + mt] += 1 137 | return layout_reprs 138 | 139 | def forward(data, model, config, train, vis): 140 | model.reset() 141 | 142 | # load batch data 143 | max_len = max(len(d.question) for d in data) 144 | max_layouts = max(len(d.layouts) for d in data) 145 | channels, size, trailing = data[0].load_features().shape 146 | assert trailing == 1 147 | has_rel_features = data[0].load_rel_features() is not None 148 | if has_rel_features: 149 | rel_channels, size_1, size_2 = data[0].load_rel_features().shape 150 | assert size_1 == size_2 == size 151 | questions = np.ones((config.opt.batch_size, max_len)) * NULL_ID 152 | features = np.zeros((config.opt.batch_size, channels, size, 1)) 153 | if has_rel_features: 154 | rel_features = np.zeros((config.opt.batch_size, rel_channels, size, size)) 155 | else: 156 | rel_features = None 157 | layout_reprs = np.zeros( 158 | (config.opt.batch_size, max_layouts, len(MODULE_INDEX) + 7)) 159 | for i, datum in enumerate(data): 160 | questions[i, max_len-len(datum.question):] = datum.question 161 | features[i, ...] = datum.load_features() 162 | if has_rel_features: 163 | rel_features[i, ...] = datum.load_rel_features() 164 | layout_reprs[i, ...] 
= featurize_layouts(datum, max_layouts)
165 |     layouts = [d.layouts for d in data]
166 | 
167 |     # apply model
168 |     model.forward(
169 |         layouts, layout_reprs, questions, features, rel_features,
170 |         dropout=(train and config.opt.dropout), deterministic=not train)
171 | 
172 |     # extract predictions
173 |     if config.opt.multiclass:
174 |         pred_words = []
175 |         for i in range(model.prediction_data.shape[0]):
176 |             preds = model.prediction_data[i, :]
177 |             chosen = np.where(preds > 0.5)[0]
178 |             pred_words.append(set(ANSWER_INDEX.get(w) for w in chosen))
179 |     else:
180 |         pred_ids = np.argmax(model.prediction_data, axis=1)
181 |         pred_words = [ANSWER_INDEX.get(w) for w in pred_ids]
182 |     predictions = list()
183 |     for i in range(len(data)):
184 |         qid = data[i].id
185 |         answer = pred_words[i]
186 |         predictions.append({"question_id": qid, "answer": answer})
187 | 
188 |     return predictions
189 | 
190 | def backward(data, model, config, train, vis):
191 |     n_answers = len(data[0].answers)
192 |     loss = 0
193 | 
194 |     for i in range(n_answers):
195 |         if config.opt.multiclass:
196 |             output_i = np.zeros((config.opt.batch_size, len(ANSWER_INDEX)))
197 |             for i_datum, datum in enumerate(data):
198 |                 for answer in datum.answers[i]:
199 |                     output_i[i_datum, answer] = 1
200 |         else:
201 |             output_i = UNK_ID * np.ones(config.opt.batch_size)
202 |             output_i[:len(data)] = \
203 |                 np.asarray([d.answers[i] for d in data])
204 |         loss += model.loss(output_i, multiclass=config.opt.multiclass)
205 | 
206 |     if train:
207 |         model.train()
208 | 
209 |     return loss
210 | 
211 | def visualize(batch_data, model):
212 |     i_datum = 0
213 |     #mod_layout_choice = model.module_layout_choices[i_datum]
214 |     #print model.apollo_net.blobs.keys()
215 |     #att_blob_name = "Find_%d_softmax" % (mod_layout_choice * 100 + 1)
216 |     #
217 |     datum = batch_data[i_datum]
218 |     question = " ".join([QUESTION_INDEX.get(w) for w in datum.question[1:-1]])
219 |     preds = model.prediction_data[i_datum,:]
220 |     top = np.argsort(preds)[-5:]
221 |     top_answers = reversed([ANSWER_INDEX.get(p) for p in top])
222 |     #att_data = model.apollo_net.blobs[att_blob_name].data[i_datum,...]
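# NOTE (annotation, not part of the original file): the att_data blob lookup
# above is commented out, so the attention panel written into the HTML table
# below is a 14x14 zero placeholder; re-enabling the commented lines would
# render the chosen layout's Find-module softmax for this datum instead.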
223 |     #att_data = att_data.reshape((14, 14))
224 |     att_data = np.zeros((14, 14))
225 |     chosen_parse = datum.parses[model.layout_ids[i_datum]]
226 | 
227 |     fields = [
228 |         question,
229 |         str(chosen_parse),
230 |         "<img src='%s'>" % datum.image_path,
231 |         att_data,
232 |         ", ".join(top_answers),
233 |         ", ".join([ANSWER_INDEX.get(a) for a in datum.answers])
234 |     ]
235 |     visualizer.show(fields)
236 | 
237 | def compute_acc(predictions, data, config):
238 |     score = 0.0
239 |     for prediction, datum in zip(predictions, data):
240 |         pred_answer = prediction["answer"]
241 |         if config.opt.multiclass:
242 |             answers = [set(ANSWER_INDEX.get(aa) for aa in a) for a in datum.answers]
243 |         else:
244 |             answers = [ANSWER_INDEX.get(a) for a in datum.answers]
245 | 
246 |         matching_answers = [a for a in answers if a == pred_answer]
247 |         if len(answers) == 1:
248 |             score += len(matching_answers)
249 |         else:
250 |             score += min(len(matching_answers) / 3.0, 1.0)
251 |     score /= len(data)
252 |     return score
253 | 
254 | if __name__ == "__main__":
255 |     main()
256 | 
--------------------------------------------------------------------------------
/misc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobandreas/nmn2/7e42dd98420f9580fd34185ba670490a5d86fb04/misc/__init__.py
--------------------------------------------------------------------------------
/misc/datum.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | 
3 | from collections import namedtuple
4 | 
5 | class Layout:
6 |     def __init__(self, modules, labels):
7 |         #assert isinstance(modules, tuple)
8 |         #assert isinstance(labels, tuple)
9 |         self.modules = modules
10 |         self.labels = labels
11 | 
12 |     def __eq__(self, other):
13 |         return isinstance(other, Layout) and \
14 |                 other.modules == self.modules and \
15 |                 other.labels == self.labels
16 | 
17 |     def __hash__(self):
18 |         return hash(self.modules) + 3 * hash(self.labels)
19 | 
20 |     def __str__(self):
21 |         return self.__str_helper(self.modules, self.labels)
22 | 
23 |     def __str_helper(self, modules, labels):
24 |         if isinstance(modules, tuple):
25 |             mhead, mtail = modules[0], modules[1:]
26 |             ihead, itail = labels[0], labels[1:]
27 |             mod_name = str(mhead) # mhead.__name__
28 |             below = [self.__str_helper(m, i) for m, i in zip(mtail, itail)]
29 |             return "(%s[%s] %s)" % (mod_name, ihead, " ".join(below))
30 | 
31 |         mod_name = str(modules)
32 |         return "%s[%s]" % (mod_name, labels)
33 | 
34 | class Datum:
35 |     def __init__(self):
36 |         self.id = None
37 |         self.string = None
38 |         self.outputs = None
39 | 
40 |     def load_input(self):
41 |         raise NotImplementedError()
42 | 
--------------------------------------------------------------------------------
/misc/indices.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | 
3 | from util import Index
4 | 
5 | UNK = "*unknown*"
6 | NULL = "*null*"
7 | 
8 | QUESTION_INDEX = Index()
9 | MODULE_INDEX = Index()
10 | MODULE_TYPE_INDEX = Index()
11 | ANSWER_INDEX = Index()
12 | 
13 | UNK_ID = QUESTION_INDEX.index(UNK)
14 | MODULE_INDEX.index(UNK)
15 | ANSWER_INDEX.index(UNK)
16 | 
17 | NULL_ID = QUESTION_INDEX.index(NULL)
18 | #MODULE_INDEX.index(NULL)
19 | #ANSWER_INDEX.index(NULL)
20 | 
--------------------------------------------------------------------------------
/misc/parse.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | 
3 | import
sexpdata 4 | 5 | def parse_tree(p): 6 | if "'" in p: 7 | p = "none" 8 | parsed = sexpdata.loads(p) 9 | extracted = extract_parse(parsed) 10 | return extracted 11 | 12 | def extract_parse(p): 13 | if isinstance(p, sexpdata.Symbol): 14 | return p.value() 15 | elif isinstance(p, int): 16 | return str(p) 17 | elif isinstance(p, bool): 18 | return str(p).lower() 19 | elif isinstance(p, float): 20 | return str(p).lower() 21 | return tuple(extract_parse(q) for q in p) 22 | -------------------------------------------------------------------------------- /misc/util.py: -------------------------------------------------------------------------------- 1 | class Struct: 2 | def __init__(self, **entries): 3 | rec_entries = {} 4 | for k, v in entries.items(): 5 | if isinstance(v, dict): 6 | rv = Struct(**v) 7 | elif isinstance(v, list): 8 | rv = [] 9 | for item in v: 10 | if isinstance(item, dict): 11 | rv.append(Struct(**item)) 12 | else: 13 | rv.append(item) 14 | else: 15 | rv = v 16 | rec_entries[k] = rv 17 | self.__dict__.update(rec_entries) 18 | 19 | def __str_helper(self, depth): 20 | lines = [] 21 | for k, v in self.__dict__.items(): 22 | if isinstance(v, Struct): 23 | v_str = v.__str_helper(depth + 1) 24 | lines.append("%s:\n%s" % (k, v_str)) 25 | else: 26 | lines.append("%s: %r" % (k, v)) 27 | indented_lines = [" " * depth + l for l in lines] 28 | return "\n".join(indented_lines) 29 | 30 | def __str__(self): 31 | return "struct {\n%s\n}" % self.__str_helper(1) 32 | 33 | def __repr__(self): 34 | return "Struct(%r)" % self.__dict__ 35 | 36 | class Index: 37 | def __init__(self): 38 | self.contents = dict() 39 | self.ordered_contents = [] 40 | self.reverse_contents = dict() 41 | 42 | def __getitem__(self, item): 43 | if item not in self.contents: 44 | return None 45 | return self.contents[item] 46 | 47 | def index(self, item): 48 | if item not in self.contents: 49 | idx = len(self.contents) + 1 50 | self.ordered_contents.append(item) 51 | self.contents[item] = idx 52 | self.reverse_contents[idx] = item 53 | idx = self[item] 54 | assert idx != 0 55 | return idx 56 | 57 | def get(self, idx): 58 | if idx == 0: 59 | return "*invalid*" 60 | return self.reverse_contents[idx] 61 | 62 | def __len__(self): 63 | return len(self.contents) + 1 64 | 65 | def __iter__(self): 66 | return iter(self.ordered_contents) 67 | 68 | def flatten(lol): 69 | if isinstance(lol, tuple) or isinstance(lol, list): 70 | return sum([flatten(l) for l in lol], []) 71 | else: 72 | return [lol] 73 | 74 | def postorder(tree): 75 | if isinstance(tree, tuple): 76 | for subtree in tree[1:]: 77 | for node in postorder(subtree): 78 | yield node 79 | yield tree[0] 80 | else: 81 | yield tree 82 | 83 | def tree_map(function, tree): 84 | if isinstance(tree, tuple): 85 | head = function(tree) 86 | tail = tuple(tree_map(function, subtree) for subtree in tree[1:]) 87 | return (head,) + tail 88 | return function(tree) 89 | 90 | def tree_zip(*trees): 91 | if isinstance(trees[0], tuple): 92 | zipped_children = [[t[i] for t in trees] for i in range(len(trees[0]))] 93 | zipped_children_rec = [tree_zip(*z) for z in zipped_children] 94 | return tuple(zipped_children_rec) 95 | return trees 96 | -------------------------------------------------------------------------------- /misc/visualizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | import numpy as np 4 | import os 5 | import scipy 6 | 7 | VIS_DIR = "vis" 8 | 9 | class Visualizer: 10 | def __init__(self): 11 | self.active = False 
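# NOTE (annotation, not part of the original file): a minimal sketch of the
# intended call pattern, mirroring do_iter/visualize in main.py -- begin()
# creates vis/<dest>/, each show() appends one table row (numpy arrays are
# saved as images, other fields are stringified), and end() writes the
# accumulated rows to vis/<dest>/index.html:
#
#     visualizer.begin(config.name, 100)  # keep at most 100 rows
#     visualizer.show([question, str(parse), att_array, answer_string])
#     visualizer.end()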
12 | 
13 |     def begin(self, dest, max_entries):
14 |         self.lines = []
15 |         self.active = True
16 |         self.max_entries = max_entries
17 |         self.next_entry = 0
18 | 
19 |         self.dest_dir = os.path.join(VIS_DIR, dest)
20 |         if not os.path.exists(self.dest_dir):
21 |             os.mkdir(self.dest_dir)
22 | 
23 |     def reset(self):
24 |         self.next_entry = 0
25 |         self.active = True
26 | 
27 |     def end(self):
28 |         self.active = False
29 | 
30 |         with open(os.path.join(self.dest_dir, "index.html"), "w") as vis_file:
31 |             #print >>vis_file, ""
32 |             print >>vis_file, "<html>"
33 |             print >>vis_file, "<body>"
34 |             print >>vis_file, "<table>"
35 |             for line in self.lines:
36 |                 print >>vis_file, "  <tr>"
37 |                 for field in line:
38 |                     print >>vis_file, "    <td>",
39 |                     print >>vis_file, field,
40 |                     print >>vis_file, "</td>"
41 |                 print >>vis_file, "  </tr>"
42 |             print >>vis_file, "</table>"
43 | 
44 |     def show(self, data):
45 |         if not self.active:
46 |             return
47 |         table_data = []
48 |         for i_field, field in enumerate(data):
49 |             if isinstance(field, np.ndarray):
50 |                 filename = "%d_%d.jpg" % (self.next_entry, i_field)
51 |                 filepath = os.path.join(self.dest_dir, filename)
52 |                 scipy.misc.imsave(filepath, field)
53 |                 table_data.append("<img src='%s'>" % filename)
54 |             else:
55 |                 table_data.append(str(field))
56 | 
57 |         self.lines.append(table_data)
58 |         self.next_entry += 1
59 |         if self.next_entry >= self.max_entries:
60 |             self.active = False
61 | 
62 | visualizer = Visualizer()
63 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | 
3 | import apollocaffe
4 | import caffe
5 | 
6 | 
7 | apollocaffe.set_device(0)
8 | #apollocaffe.set_random_seed(0)
9 | #apollocaffe.set_cpp_loglevel(1)
10 | 
11 | def build_model(config, opt_config):
12 |     #elif config.name == "monolithic":
13 |     #    return MonolithicNMNModel(config, opt_config)
14 |     #elif config.name == "lstm":
15 |     #    return LSTMModel(config, opt_config)
16 |     #elif config.name == "ensemble":
17 |     #    return EnsembleModel(config, opt_config)
18 |     #else:
19 |     if config.name == "nmn":
20 |         from nmn import NmnModel
21 |         return NmnModel(config, opt_config)
22 |     if config.name == "att":
23 |         from att import AttModel
24 |         return AttModel(config, opt_config)
25 |     if config.name == "sp":
26 |         from sp import StoredProgramModel
27 |         return StoredProgramModel(config, opt_config)
28 |     else:
29 |         raise NotImplementedError(
30 |                 "Don't know how to build a %s model" % config.name)
31 | 
--------------------------------------------------------------------------------
/models/att.py:
--------------------------------------------------------------------------------
1 | from misc.indices import QUESTION_INDEX, ANSWER_INDEX, UNK_ID
2 | from opt import adadelta
3 | 
4 | import apollocaffe
5 | from apollocaffe.layers import *
6 | import numpy as np
7 | 
8 | class AttModel:
9 |     def __init__(self, config, opt_config):
10 |         self.config = config
11 |         self.opt_config = opt_config
12 |         self.opt_state = adadelta.State()
13 |         self.apollo_net = apollocaffe.ApolloNet()
14 | 
15 |     def train(self):
16 |         self.apollo_net.backward()
17 |         adadelta.update(self.apollo_net, self.opt_state, self.opt_config)
18 |         #self.apollo_net.update(
19 |         #    lr=0.1, momentum=0.9, clip_gradients=10.0)
20 | 
21 |     def reset(self):
22 |         self.loss_counter = 0
23 |         self.att_counter = 0
24 |         self.apollo_net.clear_forward()
25 | 
26 |     @profile
27 |     def forward(self, layout_type, batch_layout_labels, question_data,
28 |             image_data, dropout):
29 |         images = self.forward_image_data(image_data, dropout)
30 |         question_hidden = self.forward_lstm(question_data, dropout)
31 |         att1_hidden = self.forward_att(question_hidden, images)
32 |         #att2_hidden = self.forward_att(att1_hidden)
33 |         self.pred_layer = self.forward_pred(att1_hidden)
34 | 
35 |         self.prediction_data = self.apollo_net.blobs[self.pred_layer].data
36 |         self.att_data = self.apollo_net.blobs["att_softmax_0"].data.reshape((-1, 14, 14))
37 |         #self.att_data = np.zeros((100, 14, 14))
38 | 
39 |     @profile
40 |     def forward_image_data(self, image_data, dropout):
41 | 
42 |         self.batch_size = image_data.shape[0]
43 |         self.channels = image_data.shape[1]
44 |         self.image_size = image_data.shape[2] * image_data.shape[3]
45 | 
46 |         image_data_rs = image_data.reshape(
47 |                 (self.batch_size, self.channels, self.image_size, 1))
48 | 
49
| net = self.apollo_net 50 | 51 | images = "data_images" 52 | images_dropout = "data_images_dropout" 53 | 54 | if images not in net.blobs: 55 | net.f(DummyData(images, image_data_rs.shape)) 56 | net.blobs[images].data[...] = image_data_rs 57 | 58 | if dropout: 59 | net.f(Dropout(images_dropout, 0.5, bottoms=[images])) 60 | return images_dropout 61 | else: 62 | return images 63 | 64 | @profile 65 | def forward_lstm(self, question_data, dropout): 66 | net = self.apollo_net 67 | batch_size, length = question_data.shape 68 | assert batch_size == self.batch_size 69 | 70 | wordvec_param = "wordvec_param" 71 | input_value_param = "input_value_param" 72 | input_gate_param = "input_gate_param" 73 | forget_gate_param = "forget_gate_param" 74 | output_gate_param = "output_gate_param" 75 | 76 | seed = "lstm_seed" 77 | final_hidden = "lstm_final_hidden" 78 | 79 | prev_hidden = seed 80 | prev_mem = seed 81 | 82 | net.f(NumpyData(seed, np.zeros((batch_size, self.config.lstm_hidden)))) 83 | 84 | for t in range(length): 85 | word = "lstm_word_%d" % t 86 | wordvec = "lstm_wordvec_%d" % t 87 | concat = "lstm_concat_%d" % t 88 | lstm = "lstm_unit_%d" % t 89 | 90 | hidden = "lstm_hidden_%d" % t 91 | mem = "lstm_mem_%d" % t 92 | 93 | net.f(NumpyData(word, question_data[:,t])) 94 | 95 | net.f(Wordvec( 96 | wordvec, self.config.lstm_hidden, len(QUESTION_INDEX), 97 | bottoms=[word], param_names=[wordvec_param])) 98 | 99 | net.f(Concat(concat, bottoms=[prev_hidden, wordvec])) 100 | 101 | net.f(LstmUnit( 102 | lstm, self.config.lstm_hidden, bottoms=[concat, prev_mem], 103 | param_names=[input_value_param, input_gate_param, 104 | forget_gate_param, output_gate_param], 105 | tops=[hidden, mem])) 106 | 107 | prev_hidden = hidden 108 | prev_mem = mem 109 | 110 | if dropout: 111 | net.f(Dropout("lstm_dropout", 0.5, bottoms=[prev_hidden])) 112 | net.f(InnerProduct( 113 | final_hidden, self.config.att_hidden, bottoms=["lstm_dropout"])) 114 | 115 | else: 116 | net.f(InnerProduct( 117 | final_hidden, self.config.att_hidden, bottoms=[prev_hidden])) 118 | 119 | return final_hidden 120 | 121 | @profile 122 | def forward_att(self, last_hidden, images): 123 | net = self.apollo_net 124 | 125 | proj_image = "att_proj_image_%d" % self.att_counter 126 | tile_hidden = "att_broadcast_hidden_%d" % self.att_counter 127 | add = "att_sum_%d" % self.att_counter 128 | relu = "att_relu_%d" % self.att_counter 129 | mask = "att_mask_%d" % self.att_counter 130 | softmax = "att_softmax_%d" % self.att_counter 131 | 132 | tile_mask = "att_tile_mask_%d" % self.att_counter 133 | weight = "att_weight_%d" % self.att_counter 134 | reduction = "att_reduction_%d" % self.att_counter 135 | ip = "att_ip_%d" % self.att_counter 136 | 137 | comb = "att_comb_%d" % self.att_counter 138 | 139 | # compute attention mask 140 | 141 | net.f(Convolution( 142 | proj_image, (1,1), self.config.att_hidden, 143 | bottoms=[images])) 144 | 145 | net.blobs[last_hidden].reshape( 146 | (self.batch_size, self.config.att_hidden, 1, 1)) 147 | 148 | net.f(Tile( 149 | tile_hidden, axis=2, tiles=self.image_size, 150 | bottoms=[last_hidden])) 151 | 152 | net.f(Eltwise(add, "SUM", bottoms=[proj_image, tile_hidden])) 153 | 154 | net.f(ReLU(relu, bottoms=[add])) 155 | 156 | net.f(Convolution(mask, (1, 1), 1, bottoms=[relu])) 157 | 158 | net.blobs[mask].reshape((self.batch_size, self.image_size)) 159 | 160 | net.f(Softmax(softmax, bottoms=[mask])) 161 | 162 | # TODO WTF 163 | net.f(Power("copy_softmax", bottoms=[softmax])) 164 | net.blobs["copy_softmax"].reshape((self.batch_size, 1, 
self.image_size, 1)) 165 | 166 | # compute average features 167 | 168 | net.f(Tile(tile_mask, axis=1, tiles=self.channels, bottoms=["copy_softmax"])) 169 | 170 | net.f(Eltwise(weight, "PROD", bottoms=[tile_mask, images])) 171 | 172 | # reduction 173 | net.f(InnerProduct( 174 | reduction, 1, axis=2, bottoms=[weight], 175 | weight_filler=Filler("constant", 1), 176 | bias_filler=Filler("constant", 0), param_lr_mults=[0, 0])) 177 | 178 | net.f(InnerProduct( 179 | ip, self.config.att_hidden, bottoms=[reduction])) 180 | #weight_filler=Filler("uniform", 0.01))) 181 | #param_lr_mults=[0.1, 0.1])) 182 | 183 | net.blobs[ip].reshape( 184 | (self.batch_size, self.config.att_hidden, 1, 1)) 185 | 186 | net.f(Power( 187 | "scale", scale=0.1, bottoms=[ip])) 188 | 189 | #print np.squeeze(net.blobs[ip].data[0,...]) 190 | #print np.linalg.norm(net.blobs[ip].data) 191 | #print np.linalg.norm(net.blobs[last_hidden].data) 192 | #print 193 | 194 | # combine with previous hidden 195 | 196 | net.f(Eltwise(comb, "SUM", bottoms=["scale", last_hidden])) 197 | #net.f(Eltwise(comb, "SUM", bottoms=[ip, last_hidden])) 198 | 199 | return comb 200 | 201 | def forward_pred(self, last_hidden): 202 | net = self.apollo_net 203 | 204 | hidden_relu = "answer_hidden_relu" 205 | ip = "answer_ip" 206 | 207 | net.f(ReLU(hidden_relu, bottoms=[last_hidden])) 208 | net.f(InnerProduct(ip, len(ANSWER_INDEX), bottoms=[hidden_relu])) 209 | 210 | return ip 211 | 212 | def loss(self, answers): 213 | net = self.apollo_net 214 | 215 | loss_data = "loss_data_%d" % self.loss_counter 216 | loss_score = "loss_score_%d" % self.loss_counter 217 | 218 | #print net.blobs[self.pred_layer].data[:10,:10] 219 | #print answers[:10] 220 | #exit() 221 | 222 | net.f(NumpyData(loss_data, answers)) 223 | 224 | self.loss_counter += 1 225 | 226 | loss = net.f(SoftmaxWithLoss( 227 | loss_score, bottoms=[self.pred_layer, loss_data], 228 | ignore_label=UNK_ID)) 229 | 230 | return loss 231 | -------------------------------------------------------------------------------- /models/nmn.py: -------------------------------------------------------------------------------- 1 | from layers.reinforce import Index, AsLoss 2 | from misc.indices import QUESTION_INDEX, MODULE_INDEX, ANSWER_INDEX, UNK_ID 3 | from misc import util 4 | from opt import adadelta 5 | 6 | import apollocaffe 7 | from apollocaffe.layers import * 8 | import numpy as np 9 | 10 | class Module(object): 11 | def __init__(self, config): 12 | self.config = config 13 | 14 | if hasattr(config, "pred_hidden"): 15 | self.pred_size = config.pred_hidden 16 | else: 17 | self.pred_size = len(ANSWER_INDEX) 18 | 19 | def forward(self, index, label_data, bottoms, features, rel_features, dropout, apollo_net): 20 | raise NotImplementedError() 21 | 22 | def __str__(self): 23 | return self.__class__.__name__ 24 | 25 | class AnswerAdaptor(Module): 26 | def __init__(self, config): 27 | super(AnswerAdaptor, self).__init__(config) 28 | self.mappings = dict() 29 | self.indices = dict() 30 | self.loaded_weights = False 31 | 32 | def register(self, key, mapping): 33 | if key in self.mappings: 34 | assert mapping == self.mappings[key] 35 | return self.indices[key] 36 | self.mappings[key] = mapping 37 | self.indices[key] = len(self.indices) 38 | return self.indices[key] 39 | 40 | def forward(self, index, label_data, bottoms, features, rel_features, dropout, apollo_net): 41 | assert len(bottoms) == 1 42 | input = bottoms[0] 43 | net = apollo_net 44 | assert self.config.att_normalization == "local" 45 | 46 | data = 
"AnswerAdaptor_%d_data" % index 47 | weights = "AnswerAdaptor_%d_weights" % index 48 | copy_input = "AnswerAdaptor_%d_copy_input" % index 49 | scalar = "AnswerAdaptor_%d_scalar" % index 50 | reduce = "AnswerAdaptor_%d_reduce" % index 51 | 52 | weight_param = "AnswerAdaptor_weight_param" 53 | reduce_param_w = "AnswerAdaptor_reduce_param_w" 54 | reduce_param_b = "AnswerAdaptor_reduce_param_b" 55 | 56 | batch_size = net.blobs[features].shape[0] 57 | # TODO correct 58 | n_mappings = 10 59 | n_inputs = net.blobs[features].shape[2] 60 | n_outputs = len(ANSWER_INDEX) 61 | 62 | net.f(NumpyData(data, label_data)) 63 | net.f(Wordvec( 64 | weights, n_inputs * n_outputs, n_mappings, bottoms=[data], 65 | param_names=[weight_param], weight_filler=Filler("constant", 0), 66 | param_lr_mults=[0])) 67 | net.blobs[weights].reshape((batch_size, n_inputs, n_outputs)) 68 | 69 | net.f(Power(copy_input, bottoms=[input])) 70 | net.blobs[copy_input].reshape((batch_size, n_inputs)) 71 | net.f(Scalar(scalar, 0, bottoms=[weights, copy_input])) 72 | 73 | net.blobs[scalar].reshape((batch_size, n_inputs, n_outputs, 1)) 74 | net.f(Convolution( 75 | reduce, (1, 1), 1, bottoms=[scalar], 76 | param_names=[reduce_param_w, reduce_param_b], 77 | weight_filler=Filler("constant", 1), 78 | bias_filler=Filler("constant", 0), 79 | param_lr_mults=[0,0])) 80 | net.blobs[reduce].reshape((batch_size, n_outputs)) 81 | 82 | if not self.loaded_weights: 83 | for key, mapping in self.mappings.items(): 84 | index = self.indices[key] 85 | for inp, out in mapping.items(): 86 | net.params[weight_param].data[0, index, 0, inp * n_outputs + out] = 1 87 | self.loaded_weights = True 88 | 89 | return reduce 90 | 91 | class LookupModule(Module): 92 | def forward(self, index, label_data, bottoms, features, rel_features, dropout, apollo_net): 93 | assert len(bottoms) == 0 94 | net = apollo_net 95 | batch_size, channels, image_size, trailing = net.blobs[features].shape 96 | assert trailing == 1 97 | 98 | assert self.config.att_normalization == "local" 99 | data = np.zeros((batch_size, 1, image_size, trailing)) 100 | for i in range(len(label_data)): 101 | data[i, :, label_data[i], ...] 
= 1 102 | 103 | lookup = "Lookup_%d_data" % index 104 | 105 | net.f(NumpyData(lookup, data)) 106 | 107 | return lookup 108 | 109 | class MLPFindModule(Module): 110 | def forward(self, index, label_data, bottoms, image, rel_features, dropout, 111 | apollo_net): 112 | assert len(bottoms) == 0 113 | 114 | net = apollo_net 115 | 116 | batch_size, channels, image_size, trailing = net.blobs[image].shape 117 | assert trailing == 1 118 | 119 | proj_image = "Find_%d_proj_image" % index 120 | label = "Find_%d_label" % index 121 | label_vec = "Find_%d_label_vec" % index 122 | label_vec_dropout = "Find_%d_label_vec_dropout" % index 123 | tile = "Find_%d_tile" % index 124 | sum = "Find_%d_sum" % index 125 | relu = "Find_%d_relu" % index 126 | mask = "Find_%d_mask" % index 127 | softmax = "Find_%d_softmax" % index 128 | sigmoid = "Find_%d_sigmoid" % index 129 | copy = "Find_%d_copy" % index 130 | 131 | proj_image_param_weight = "Find_proj_image_param_weight" 132 | proj_image_param_bias = "Find_proj_image_param_bias" 133 | label_vec_param = "Find_label_vec_param" 134 | mask_param_weight = "Find_mask_param_weight" 135 | mask_param_bias = "Find_mask_param_bias" 136 | 137 | # compute attention mask 138 | 139 | net.f(Convolution( 140 | proj_image, (1, 1), self.config.att_hidden, bottoms=[image], 141 | param_names=[proj_image_param_weight, proj_image_param_bias])) 142 | 143 | net.f(NumpyData(label, label_data)) 144 | net.f(Wordvec( 145 | label_vec, self.config.att_hidden, len(MODULE_INDEX), 146 | bottoms=[label], param_names=[label_vec_param])) 147 | net.blobs[label_vec].reshape((batch_size, self.config.att_hidden, 1, 1)) 148 | if dropout: 149 | net.f(Dropout(label_vec_dropout, 0.5, bottoms=[label_vec])) 150 | label_vec_final = label_vec_dropout 151 | else: 152 | label_vec_final = label_vec 153 | 154 | net.f(Tile(tile, axis=2, tiles=image_size, bottoms=[label_vec_final])) 155 | net.f(Eltwise(sum, "SUM", bottoms=[proj_image, tile])) 156 | net.f(ReLU(relu, bottoms=[sum])) 157 | net.f(Convolution(mask, (1, 1), 1, bottoms=[relu], 158 | param_names=[mask_param_weight, mask_param_bias])) 159 | 160 | # TODO or defer? 
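# NOTE (annotation, not part of the original file): att_normalization picks
# how the single-channel mask computed above is squashed. "local" applies an
# elementwise Sigmoid, scoring each of the image_size cells independently in
# [0, 1] (several cells can be active at once); "global" applies a Softmax
# across cells, giving one attention distribution over positions that sums
# to 1.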
161 | if self.config.att_normalization == "local": 162 | net.f(Sigmoid(sigmoid, bottoms=[mask])) 163 | prev = sigmoid 164 | elif self.config.att_normalization == "global": 165 | net.f(Softmax(softmax, bottoms=[mask])) 166 | prev = softmax 167 | net.f(Power(copy, bottoms=[prev])) 168 | net.blobs[copy].reshape((batch_size, 1, image_size, 1)) 169 | 170 | return copy 171 | 172 | class MultiplicativeFindModule(Module): 173 | def forward(self, index, label_data, bottoms, features, rel_features, dropout, apollo_net): 174 | assert len(bottoms) == 0 175 | 176 | net = apollo_net 177 | 178 | batch_size, channels, image_size, trailing = net.blobs[features].shape 179 | assert trailing == 1 180 | 181 | proj_image = "Find_%d_proj_image" % index 182 | label = "Find_%d_label" % index 183 | label_vec = "Find_%d_label_vec" % index 184 | label_vec_dropout = "Find_%d_label_vec_dropout" % index 185 | tile = "Find_%d_tile" % index 186 | prod = "Find_%d_prod" % index 187 | mask = "Find_%d_mask" % index 188 | sigmoid = "Find_%d_sigmoid" % index 189 | softmax = "Find_%d_softmax" % index 190 | copy = "Find_%d_copy" % index 191 | 192 | proj_image_param_weight = "Find_proj_image_param_weight" 193 | proj_image_param_bias = "Find_proj_image_param_bias" 194 | label_vec_param = "Find_label_vec_param" 195 | mask_param_weight = "Find_mask_param_weight" 196 | mask_param_bias = "Find_mask_param_bias" 197 | 198 | # compute attention mask 199 | 200 | net.f(NumpyData(label, label_data)) 201 | net.f(Wordvec( 202 | label_vec, channels, len(MODULE_INDEX), 203 | bottoms=[label], param_names=[label_vec_param])) 204 | net.blobs[label_vec].reshape((batch_size, channels, 1, 1)) 205 | if dropout: 206 | net.f(Dropout(label_vec_dropout, 0.5, bottoms=[label_vec])) 207 | label_vec_final = label_vec_dropout 208 | else: 209 | label_vec_final = label_vec 210 | 211 | net.f(Tile(tile, axis=2, tiles=image_size, bottoms=[label_vec_final])) 212 | 213 | net.f(Eltwise(prod, "PROD", bottoms=[features, tile])) 214 | net.f(Convolution(mask, (1, 1), 1, bottoms=[prod], 215 | param_names=[mask_param_weight, mask_param_bias], 216 | weight_filler=Filler("constant", 1), 217 | bias_filler=Filler("constant", 0), 218 | param_lr_mults=[0, 0])) 219 | 220 | if self.config.att_normalization == "local": 221 | net.f(Sigmoid(sigmoid, bottoms=[mask])) 222 | prev = sigmoid 223 | elif self.config.att_normalization == "global": 224 | net.f(Softmax(softmax, bottoms=[mask])) 225 | prev = softmax 226 | # TODO still WTF 227 | net.f(Power(copy, bottoms=[prev])) 228 | 229 | net.blobs[copy].reshape((batch_size, 1, image_size, 1)) 230 | 231 | return copy 232 | 233 | class RelateModule(Module): 234 | def forward(self, index, label_data, bottoms, features, rel_features, 235 | dropout, apollo_net): 236 | net = apollo_net 237 | batch_size, rel_channels, image_size, _ = net.blobs[rel_features].shape 238 | assert len(bottoms) == 1 239 | mask = bottoms[0] 240 | 241 | tile_mask_feats = "ReAttend_%d_tile_mask_feats" % index 242 | tile_mask_neighbors = "ReAttend_%d_tile_mask_neighbors" % index 243 | weight = "ReAttend_%d_weight" % index 244 | reduce = "ReAttend_%d_reduce" % index 245 | labels = "ReAttend_%d_labels" % index 246 | param = "ReAttend_%d_param" % index 247 | tile_param = "ReAttend_%d_tile_param" % index 248 | prod = "ReAttend_%d_prod" % index 249 | reduce2 = "ReAttend_%d_reduce2" % index 250 | sigmoid = "ReAttend_%d_sigmoid" % index 251 | softmax = "ReAttend_%d_softmax" % index 252 | copy = "ReAttend_%d_copy" % index 253 | 254 | reduce_param_weight = 
"ReAttend_reduce_param_weight" 255 | reduce_param_bias = "ReAttend_reduce_param_bias" 256 | wordvec_param = "ReAttend_wordvec_param" 257 | reduce2_param_weight = "ReAttend_reduce2_param_weight" 258 | reduce2_param_bias = "ReAttend_reduce2_param_bias" 259 | 260 | net.blobs[mask].reshape((batch_size, 1, 1, image_size)) 261 | net.f(Tile( 262 | tile_mask_feats, axis=1, tiles=rel_channels, bottoms=[mask])) 263 | net.f(Tile( 264 | tile_mask_neighbors, axis=2, tiles=image_size, 265 | bottoms=[tile_mask_feats])) 266 | net.f(Eltwise( 267 | weight, "PROD", bottoms=[tile_mask_neighbors, rel_features])) 268 | net.f(InnerProduct( 269 | reduce, 1, axis=3, bottoms=[weight], 270 | weight_filler=Filler("constant", 1), 271 | bias_filler=Filler("constant", 0), 272 | param_lr_mults=[0, 0], 273 | param_names=[reduce_param_weight, reduce_param_bias])) 274 | 275 | net.f(NumpyData(labels, label_data)) 276 | net.f(Wordvec( 277 | param, rel_channels, len(MODULE_INDEX), bottoms=[labels], 278 | param_names=[wordvec_param])) 279 | net.blobs[param].reshape((batch_size, rel_channels, 1, 1)) 280 | net.f(Tile(tile_param, axis=2, tiles=image_size, bottoms=[param])) 281 | net.f(Eltwise(prod, "PROD", bottoms=[tile_param, reduce])) 282 | 283 | net.f(Convolution( 284 | reduce2, (1, 1), 1, bottoms=[prod], 285 | param_names=[reduce2_param_weight, reduce2_param_bias], 286 | weight_filler=Filler("constant", 1), 287 | bias_filler=Filler("constant", 0), 288 | param_lr_mults=[0, 0])) 289 | 290 | if self.config.att_normalization == "local": 291 | net.f(Sigmoid(sigmoid, bottoms=[reduce2])) 292 | prev = sigmoid 293 | elif self.config.att_normalization == "global": 294 | net.f(Softmax(softmax, bottoms=[reduce2])) 295 | prev = softmax 296 | # TODO still WTF 297 | net.f(Power(copy, bottoms=[prev])) 298 | 299 | return copy 300 | 301 | class AndModule(Module): 302 | def forward(self, index, label_data, bottoms, features, rel_features, dropout, apollo_net): 303 | net = apollo_net 304 | assert len(bottoms) >= 2 305 | assert all(net.blobs[l].shape[1] == 1 for l in bottoms) 306 | 307 | prod = "And_%d_prod" % index 308 | 309 | net.f(Eltwise(prod, "PROD", bottoms=bottoms)) 310 | 311 | return prod 312 | 313 | class DescribeModule(Module): 314 | def forward(self, index, label_data, bottoms, features, rel_features, dropout, apollo_net): 315 | assert len(bottoms) == 1 316 | mask = bottoms[0] 317 | #assert self.config.att_normalization == "global" 318 | 319 | net = apollo_net 320 | 321 | batch_size, channels, image_size, _ = net.blobs[features].shape 322 | 323 | tile_mask = "Describe_%d_tile_mask" % index 324 | weight = "Describe_%d_weight" % index 325 | reduction = "Describe_%d_reduction" % index 326 | ip = "Describe_%d_ip" % index 327 | scale = "Describe_%d_scale" % index 328 | 329 | reduction_param_weight = "Describe_reduction_param_weight" 330 | reduction_param_bias = "Describe_reduction_param_bias" 331 | ip_param_weight = "Describe_ip_param_weight" 332 | ip_param_bias = "Describe_ip_param_bias" 333 | 334 | net.f(Tile(tile_mask, axis=1, tiles=channels, bottoms=[mask])) 335 | net.f(Eltwise(weight, "PROD", bottoms=[tile_mask, features])) 336 | net.f(InnerProduct( 337 | reduction, 1, axis=2, bottoms=[weight], 338 | weight_filler=Filler("constant", 1), 339 | bias_filler=Filler("constant", 0), 340 | param_lr_mults=[0, 0], 341 | param_names=[reduction_param_weight, reduction_param_bias])) 342 | net.f(InnerProduct( 343 | ip, self.pred_size, bottoms=[reduction], 344 | param_names=[ip_param_weight, ip_param_bias])) 345 | net.f(Power(scale, scale=0.001, 
bottoms=[ip])) 346 | 347 | return scale 348 | 349 | class ExistsModule(Module): 350 | def forward(self, index, label_data, bottoms, features, rel_features, dropout, apollo_net): 351 | assert len(bottoms) == 1 352 | mask = bottoms[0] 353 | 354 | net = apollo_net 355 | 356 | reduce = "Exists_%d_reduce" % index 357 | ip = "Exists_%d_ip" % index 358 | 359 | reduce_param_weight = "Exists_reduce_param_weight" 360 | reduce_param_bias = "Exists_reduce_param_bias" 361 | ip_param_weight = "Exists_ip_param_weight" 362 | ip_param_bias = "Exists_ip_param_bias" 363 | 364 | net.f(Pooling(reduce, kernel_h=10, kernel_w=1, bottoms=[mask])) 365 | 366 | net.f(InnerProduct( 367 | ip, len(ANSWER_INDEX), bottoms=[reduce], 368 | param_names=[ip_param_weight, ip_param_bias], 369 | weight_filler=Filler("constant", 0), 370 | bias_filler=Filler("constant", 0), 371 | param_lr_mults=[0, 0])) 372 | net.params[ip_param_weight].data[ANSWER_INDEX["yes"],0] = 1 373 | net.params[ip_param_weight].data[ANSWER_INDEX["no"],0] = -1 374 | net.params[ip_param_bias].data[ANSWER_INDEX["no"]] = 1 375 | 376 | return ip 377 | 378 | class Nmn: 379 | def __init__(self, index, modules, apollo_net): 380 | self.index = index 381 | self.apollo_net = apollo_net 382 | 383 | # TODO eww 384 | counter = [0] 385 | def number(tree): 386 | r = counter[0] 387 | counter[0] += 1 388 | return r 389 | numbered = util.tree_map(number, modules) 390 | assert util.flatten(numbered) == range(len(util.flatten(numbered))) 391 | 392 | def children(tree): 393 | if not isinstance(tree, tuple): 394 | return str(()) 395 | # TODO nasty hack to make flatten behave right 396 | return str(tuple(c[0] if isinstance(c, tuple) else c for c in tree[1:])) 397 | child_annotated = util.tree_map(children, numbered) 398 | 399 | self.modules = util.flatten(modules) 400 | self.children = [eval(c) for c in util.flatten(child_annotated)] 401 | 402 | def forward(self, label_data, features, rel_features, dropout): 403 | 404 | flat_data = [util.flatten(d) for d in label_data] 405 | flat_data = np.asarray(flat_data) 406 | outputs = [None for i in range(len(self.modules))] 407 | for i in reversed(range(len(self.modules))): 408 | bottoms = [outputs[j] for j in self.children[i]] 409 | assert None not in bottoms 410 | mod_index = self.index * 100 + i 411 | output = self.modules[i].forward( 412 | mod_index, flat_data[:,i], bottoms, features, rel_features, 413 | dropout, self.apollo_net) 414 | outputs[i] = output 415 | 416 | self.outputs = outputs 417 | return outputs[0] 418 | 419 | class NmnModel: 420 | def __init__(self, config, opt_config): 421 | self.config = config 422 | self.opt_config = opt_config 423 | self.opt_state = adadelta.State() 424 | 425 | self.nmns = dict() 426 | 427 | self.apollo_net = apollocaffe.ApolloNet() 428 | 429 | def get_nmn(self, modules): 430 | if modules not in self.nmns: 431 | self.nmns[modules] = Nmn(len(self.nmns), modules, self.apollo_net) 432 | return self.nmns[modules] 433 | 434 | 435 | def forward(self, layouts, layout_data, question_data, features_data, 436 | rel_features_data, dropout, deterministic): 437 | 438 | # predict layout 439 | 440 | question_hidden = self.forward_question(question_data, dropout) 441 | layout_ids, layout_probs = \ 442 | self.forward_layout(question_hidden, layouts, layout_data, 443 | deterministic) 444 | 445 | self.layout_ids = layout_ids 446 | self.layout_probs = layout_probs 447 | 448 | chosen_layouts = [ll[i] for ll, i in zip(layouts, layout_ids)] 449 | 450 | # prepare layout data 451 | 452 | module_layouts = list(set(l.modules 
for l in chosen_layouts)) 453 | module_layout_choices = [] 454 | default_labels = [None for i in range(len(module_layouts))] 455 | for layout in chosen_layouts: 456 | choice = module_layouts.index(layout.modules) 457 | module_layout_choices.append(choice) 458 | if default_labels[choice] is None: 459 | default_labels[choice] = layout.labels 460 | layout_label_data = [] 461 | layout_mask = [] 462 | for layout, choice in zip(chosen_layouts, module_layout_choices): 463 | labels_here = list(default_labels) 464 | labels_here[choice] = layout.labels 465 | layout_label_data.append(labels_here) 466 | mask_here = [0 for i in range(len(module_layouts))] 467 | mask_here[choice] = 1 468 | layout_mask.append(mask_here) 469 | layout_mask = np.asarray(layout_mask) 470 | 471 | self.module_layout_choices = module_layout_choices 472 | 473 | # predict answer 474 | 475 | features = self.forward_features(0, features_data, dropout) 476 | if rel_features_data is not None: 477 | rel_features = self.forward_features(1, rel_features_data, dropout) 478 | else: 479 | rel_features = None 480 | 481 | nmn_hiddens = [] 482 | nmn_outputs = [] 483 | for i in range(len(module_layouts)): 484 | module_layout = module_layouts[i] 485 | label_data = [lld[i] for lld in layout_label_data] 486 | nmn = self.get_nmn(module_layout) 487 | nmn_hidden = nmn.forward(label_data, features, rel_features, dropout) 488 | nmn_hiddens.append(nmn_hidden) 489 | nmn_outputs.append(nmn.outputs) 490 | 491 | chosen_hidden = self.forward_choice(module_layouts, layout_mask, nmn_hiddens) 492 | self.prediction = self.forward_pred(question_hidden, chosen_hidden) 493 | 494 | batch_size = self.apollo_net.blobs[nmn_hiddens[0]].shape[0] 495 | self.prediction_data = self.apollo_net.blobs[self.prediction].data 496 | self.att_data = np.zeros((batch_size, 14, 14)) 497 | 498 | def forward_choice(self, module_layouts, layout_mask, nmn_hiddens): 499 | net = self.apollo_net 500 | 501 | concat = "CHOOSE_concat" 502 | mask = "CHOOSE_mask" 503 | prod = "CHOOSE_prod" 504 | tile_mask = "CHOOSE_tile_mask" 505 | sum = "CHOOSE%d_sum" % len(module_layouts) 506 | 507 | batch_size = self.apollo_net.blobs[nmn_hiddens[0]].shape[0] 508 | 509 | output_hidden = self.config.pred_hidden \ 510 | if hasattr(self.config, "pred_hidden") \ 511 | else len(ANSWER_INDEX) 512 | for h in nmn_hiddens: 513 | self.apollo_net.blobs[h].reshape((batch_size, output_hidden, 1)) 514 | if len(nmn_hiddens) == 1: 515 | concat_layer = nmn_hiddens[0] 516 | else: 517 | self.apollo_net.f(Concat(concat, axis=2, bottoms=nmn_hiddens)) 518 | concat_layer = concat 519 | 520 | self.apollo_net.f(NumpyData(mask, layout_mask)) 521 | self.apollo_net.blobs[mask].reshape( 522 | (batch_size, 1, len(module_layouts))) 523 | self.apollo_net.f(Tile( 524 | tile_mask, axis=1, tiles=output_hidden, bottoms=[mask])) 525 | self.apollo_net.f(Eltwise( 526 | prod, "PROD", bottoms=[tile_mask, concat_layer])) 527 | self.apollo_net.f(InnerProduct( 528 | sum, 1, axis=2, bottoms=[prod], 529 | weight_filler=Filler("constant", 1), 530 | bias_filler=Filler("constant", 0), 531 | param_lr_mults=[0, 0])) 532 | self.apollo_net.blobs[sum].reshape((batch_size, output_hidden)) 533 | 534 | return sum 535 | 536 | def forward_layout(self, question_hidden, layouts, layout_data, 537 | deterministic=False): 538 | net = self.apollo_net 539 | batch_size, n_layouts, n_features = layout_data.shape 540 | 541 | proj_question = "LAYOUT_proj_question" 542 | tile_question = "LAYOUT_tile%d_question" % n_layouts 543 | layout_feats = "LAYOUT_feats_%d" 544 | proj_layout 
= "LAYOUT_proj_layout_%d" 545 | concat = "LAYOUT_concat" 546 | sum = "LAYOUT_sum" 547 | relu = "LAYOUT_relu" 548 | pred = "LAYOUT_pred" 549 | softmax = "LAYOUT_softmax" 550 | 551 | net.f(InnerProduct( 552 | proj_question, self.config.layout_hidden, 553 | bottoms=[question_hidden])) 554 | net.blobs[proj_question].reshape( 555 | (batch_size, 1, self.config.layout_hidden)) 556 | net.f(Tile(tile_question, axis=1, tiles=n_layouts, 557 | bottoms=[proj_question])) 558 | 559 | concat_bottoms = [] 560 | for i in range(n_layouts): 561 | # TODO normalize these? 562 | net.f(NumpyData(layout_feats % i, layout_data[:,i,:])) 563 | net.f(InnerProduct(proj_layout % i, self.config.layout_hidden, 564 | bottoms=[layout_feats % i])) 565 | net.blobs[proj_layout % i].reshape( 566 | (batch_size, 1, self.config.layout_hidden)) 567 | concat_bottoms.append(proj_layout % i) 568 | 569 | if n_layouts > 1: 570 | net.f(Concat(concat, axis=1, bottoms=concat_bottoms)) 571 | concat_layer = concat 572 | else: 573 | concat_layer = concat_bottoms[0] 574 | 575 | net.f(Eltwise(sum, "SUM", bottoms=[tile_question, concat_layer])) 576 | net.f(ReLU(relu, bottoms=[sum])) 577 | net.f(InnerProduct(pred, 1, axis=2, bottoms=[relu])) 578 | net.blobs[pred].reshape((batch_size, n_layouts)) 579 | net.f(Softmax(softmax, bottoms=[pred])) 580 | 581 | probs = net.blobs[softmax].data 582 | 583 | layout_choices = [] 584 | for i in range(len(layouts)): 585 | pr_here = probs[i,:len(layouts[i])].astype(np.float) 586 | pr_here /= np.sum(pr_here) 587 | if deterministic: 588 | choice = np.argmax(pr_here) 589 | else: 590 | choice = np.random.choice(pr_here.size, p=pr_here) 591 | layout_choices.append(choice) 592 | for i in range(batch_size - len(layouts)): 593 | layout_choices.append(0) 594 | 595 | return layout_choices, softmax 596 | 597 | def forward_features(self, index, feature_data, dropout): 598 | batch_size = feature_data.shape[0] 599 | channels = feature_data.shape[1] 600 | image_size = feature_data.shape[2] * feature_data.shape[3] 601 | 602 | net = self.apollo_net 603 | 604 | features = "FEATURES_%d_data" % index 605 | features_dropout = "FEATURES_%d_dropout" % index 606 | 607 | if features not in net.blobs: 608 | net.f(DummyData(features, feature_data.shape)) 609 | net.blobs[features].data[...] 
= feature_data 610 | 611 | if dropout: 612 | net.f(Dropout(features_dropout, 0.5, bottoms=[features])) 613 | return features_dropout 614 | else: 615 | return features 616 | 617 | def forward_question(self, question_data, dropout): 618 | net = self.apollo_net 619 | batch_size, length = question_data.shape 620 | 621 | wordvec_param = "QUESTION_wordvec_param" 622 | input_value_param = "QUESTION_input_value_param" 623 | input_gate_param = "QUESTION_input_gate_param" 624 | forget_gate_param = "QUESTION_forget_gate_param" 625 | output_gate_param = "QUESTION_output_gate_param" 626 | 627 | seed = "QUESTION_lstm_seed" 628 | question_dropout = "QUESTION_lstm_dropout" 629 | final_hidden = "QUESTION_lstm_final_hidden" 630 | 631 | prev_hidden = seed 632 | prev_mem = seed 633 | 634 | net.f(NumpyData(seed, np.zeros((batch_size, self.config.lstm_hidden)))) 635 | 636 | for t in range(length): 637 | word = "QUESTION_lstm_word_%d" % t 638 | wordvec = "QUESTION_lstm_wordvec_%d" % t 639 | concat = "QUESTION_lstm_concat_%d" % t 640 | lstm = "QUESTION_lstm_unit_%d" % t 641 | 642 | hidden = "QUESTION_lstm_hidden_%d" % t 643 | mem = "QUESTION_lstm_mem_%d" % t 644 | 645 | net.f(NumpyData(word, question_data[:,t])) 646 | net.f(Wordvec( 647 | wordvec, self.config.lstm_hidden, len(QUESTION_INDEX), 648 | bottoms=[word], param_names=[wordvec_param])) 649 | net.f(Concat(concat, bottoms=[prev_hidden, wordvec])) 650 | net.f(LstmUnit( 651 | lstm, self.config.lstm_hidden, bottoms=[concat, prev_mem], 652 | param_names=[input_value_param, input_gate_param, 653 | forget_gate_param, output_gate_param], 654 | tops=[hidden, mem])) 655 | 656 | prev_hidden = hidden 657 | prev_mem = mem 658 | 659 | # TODO consolidate with module? 660 | if hasattr(self.config, "pred_hidden"): 661 | pred_size = self.config.pred_hidden 662 | else: 663 | pred_size = len(ANSWER_INDEX) 664 | 665 | if dropout: 666 | net.f(Dropout(question_dropout, 0.5, bottoms=[prev_hidden])) 667 | net.f(InnerProduct(final_hidden, pred_size, bottoms=[question_dropout])) 668 | else: 669 | net.f(InnerProduct(final_hidden, pred_size, bottoms=[prev_hidden])) 670 | 671 | return final_hidden 672 | 673 | def forward_pred(self, question_hidden, nmn_hidden): 674 | net = self.apollo_net 675 | 676 | relu = "PRED_relu" 677 | ip = "PRED_ip" 678 | 679 | if self.config.combine_question: 680 | sum = "PRED_sum" 681 | net.f(Eltwise(sum, "SUM", bottoms=[question_hidden, nmn_hidden])) 682 | else: 683 | sum = nmn_hidden 684 | 685 | if hasattr(self.config, "pred_hidden"): 686 | net.f(ReLU(relu, bottoms=[sum])) 687 | net.f(InnerProduct(ip, len(ANSWER_INDEX), bottoms=[relu])) 688 | return ip 689 | else: 690 | return sum 691 | 692 | def loss(self, answer_data, multiclass=False): 693 | net = self.apollo_net 694 | 695 | target = "PRED_target_%d" % self.loss_counter 696 | loss = "PRED_loss_%d" % self.loss_counter 697 | datum_loss = "PRED_datum_loss_%d" % self.loss_counter 698 | self.loss_counter += 1 699 | 700 | if multiclass: 701 | net.f(NumpyData(target, answer_data)) 702 | acc_loss = net.f(EuclideanLoss( 703 | loss, bottoms=[self.prediction, target])) 704 | 705 | pred_probs = net.blobs[self.prediction].data 706 | batch_size = pred_probs.shape[0] 707 | pred_ans_probs = np.sum(np.abs(answer_data - pred_probs) ** 2, axis=1) 708 | # TODO 709 | pred_ans_log_probs = pred_ans_probs 710 | 711 | else: 712 | net.f(NumpyData(target, answer_data)) 713 | acc_loss = net.f(SoftmaxWithLoss( 714 | loss, bottoms=[self.prediction, target], ignore_label=UNK_ID)) 715 | 716 | net.f(Softmax(datum_loss, 
bottoms=[self.prediction]))
717 | 
718 |             pred_probs = net.blobs[datum_loss].data
719 |             batch_size = pred_probs.shape[0]
720 |             pred_ans_probs = pred_probs[np.arange(batch_size), answer_data.astype(np.int)]
721 |             pred_ans_log_probs = np.log(pred_ans_probs)
722 |             pred_ans_log_probs[answer_data == UNK_ID] = 0
723 | 
724 |         self.cumulative_datum_losses += pred_ans_log_probs
725 | 
726 |         return acc_loss
727 | 
728 |     def reinforce_layout(self, losses):
729 |         net = self.apollo_net
730 | 
731 |         choice_data = "REINFORCE_choice_data"
732 |         loss_data = "REINFORCE_loss_data"
733 |         index = "REINFORCE_index"
734 |         weight = "REINFORCE_weight"
735 |         reduction = "REINFORCE_reduction"
736 |         loss = "REINFORCE_loss"
737 | 
738 |         net.f(NumpyData(choice_data, self.layout_ids))
739 |         net.f(NumpyData(loss_data, losses))
740 |         net.f(Index(index, {}, bottoms=[self.layout_probs, choice_data]))
741 |         net.f(Eltwise(weight, "PROD", bottoms=[index, loss_data]))
742 | 
743 |         net.f(AsLoss(loss, bottoms=[weight]))
744 | 
745 |     def reset(self):
746 |         self.apollo_net.clear_forward()
747 |         self.loss_counter = 0
748 |         self.cumulative_datum_losses = np.zeros((self.opt_config.batch_size,))
749 |         self.question_hidden = None
750 | 
751 |     def train(self):
752 |         self.reinforce_layout(self.cumulative_datum_losses)
753 |         self.apollo_net.backward()
754 |         adadelta.update(self.apollo_net, self.opt_state, self.opt_config)
755 | 
--------------------------------------------------------------------------------
/opt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jacobandreas/nmn2/7e42dd98420f9580fd34185ba670490a5d86fb04/opt/__init__.py
--------------------------------------------------------------------------------
/opt/adadelta.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | 
3 | from collections import namedtuple
4 | import numpy as np
5 | 
6 | class State:
7 |     def __init__(self):
8 |         self.sq_updates = dict()
9 |         self.sq_grads = dict()
10 | 
11 | def update(net, state, config):
12 |     rho = config.rho
13 |     epsilon = config.eps
14 |     lr = config.lr
15 |     clip = config.clip
16 | 
17 |     all_norm = 0.
18 |     for param_name in net.active_param_names():
19 |         param = net.params[param_name]
20 |         grad = param.diff * net.param_lr_mults(param_name)
21 |         all_norm += np.sum(np.square(grad))
22 |     all_norm = np.sqrt(all_norm)
23 | 
24 |     for param_name in net.active_param_names():
25 |         param = net.params[param_name]
26 |         grad = param.diff * net.param_lr_mults(param_name)
27 | 
28 |         if all_norm > clip:
29 |             grad = clip * grad / all_norm
30 | 
31 |         if param_name in state.sq_grads:
32 |             state.sq_grads[param_name] = \
33 |                     (1 - rho) * np.square(grad) + rho * state.sq_grads[param_name]
34 |             rms_update = np.sqrt(state.sq_updates[param_name] + epsilon)
35 |             rms_grad = np.sqrt(state.sq_grads[param_name] + epsilon)
36 |             update = -rms_update / rms_grad * grad
37 | 
38 |             state.sq_updates[param_name] = \
39 |                     (1 - rho) * np.square(update) + rho * state.sq_updates[param_name]
40 |         else:
41 |             state.sq_grads[param_name] = (1 - rho) * np.square(grad)
42 |             update = -np.sqrt(epsilon) / np.sqrt(epsilon +
43 |                     state.sq_grads[param_name]) * grad
44 |             state.sq_updates[param_name] = (1 - rho) * np.square(update)
45 | 
46 |         param.data[...] += lr * update
47 |         param.diff[...]
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export APOLLO_ROOT=/home/jda/3p/apollocaffe
4 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$APOLLO_ROOT/build/lib
5 | export PYTHONPATH=$PYTHONPATH:$APOLLO_ROOT/python:$APOLLO_ROOT/python/caffe/proto
6 |
7 | python main.py -c config/vqa_nmn.yml
8 | #python main.py -c config/geo_nmn.yml
9 |
--------------------------------------------------------------------------------
/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 |
3 | def load_task(config):
4 |     if config.task.name == "vqa":
5 |         from vqa import VqaTask
6 |         return VqaTask(config)
7 |     elif config.task.name == "geo":
8 |         from geo import GeoTask
9 |         return GeoTask(config)
10 |     else:
11 |         raise NotImplementedError(
12 |             "Don't know how to build a %s task" % config.task.name)
13 |
--------------------------------------------------------------------------------
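load_task is the single entry point main.py reaches through the -c YAML file: the config selects the task, and everything downstream hangs off config.task, config.model, and so on. A minimal sketch of that plumbing, assuming an attribute-style config wrapper (the Struct class below is a hypothetical stand-in for whatever main.py actually builds from the YAML; PyYAML is assumed available):

import yaml

class Struct(object):
    """Wrap nested dicts so config["task"]["name"] reads as config.task.name."""
    def __init__(self, d):
        for key, value in d.items():
            setattr(self, key, Struct(value) if isinstance(value, dict) else value)

with open("config/geo_nmn.yml") as config_f:
    config = Struct(yaml.safe_load(config_f))

# task = load_task(config)  # dispatches on config.task.name ("vqa" or "geo")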
/tasks/geo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 |
3 | from misc.datum import Datum, Layout
4 | from misc.indices import QUESTION_INDEX, MODULE_INDEX, ANSWER_INDEX, UNK_ID
5 | from misc.parse import parse_tree
6 | import misc.util
7 | from models.nmn import MultiplicativeFindModule, LookupModule, AndModule, \
8 |     ExistsModule, RelateModule, AnswerAdaptor
9 |
10 | from collections import defaultdict, namedtuple
11 | import logging
12 | import numpy as np
13 | import os
14 | import re
15 | import xml.etree.ElementTree as ET
16 |
17 | DATA_FILE = "data/geo/environments/%s/training.txt"
18 | PARSE_FILE = "data/geo/environments/%s/training.sps"
19 | WORLD_FILE = "data/geo/environments/%s/world.txt"
20 | LOCATION_FILE = "data/geo/environments/%s/locations.txt"
21 |
22 | ENVIRONMENTS = ["fl", "ga", "mi", "nc", "ok", "pa", "sc", "tn", "va", "wv"]
23 |
24 | CATS = ["city", "state", "park", "island", "beach", "ocean", "lake", "forest",
25 |     "major", "peninsula", "capital"]
26 | RELS = ["in-rel", "north-rel", "south-rel", "east-rel", "west-rel", "border-rel"]
27 |
28 | DATABASE_SIZE = 10
29 |
30 | TRAIN = 0
31 | VAL = 1
32 | TEST = 2
33 |
34 | YES = "yes"
35 | NO = "no"
36 |
37 | World = namedtuple("World", ("name", "entities", "entity_features", "relation_features"))
38 |
39 | def parse_to_layout_helper(parse, world, config, modules):
40 |     if isinstance(parse, str):
41 |         if parse in world.entities:
42 |             return modules["lookup"], world.entities[parse]
43 |         else:
44 |             return modules["find"], MODULE_INDEX.index(parse)
45 |     head = parse[0]
46 |     below = [parse_to_layout_helper(c, world, config, modules) for c in parse[1:]]
47 |     modules_below, labels_below = zip(*below)
48 |     modules_below = tuple(modules_below)
49 |     labels_below = tuple(labels_below)
50 |     if head == "and":
51 |         module_head = modules["and"]
52 |     elif head == "exists":
53 |         module_head = modules["exists"]
54 |     else:
55 |         module_head = modules["relate"]
56 |     label_head = MODULE_INDEX.index(head)
57 |     modules_here = (module_head,) + modules_below
58 |     labels_here = (label_head,) + labels_below
59 |     return modules_here, labels_here
60 |
61 | def parse_to_layout(parse, world, config, modules):
62 |     mods, indices = parse_to_layout_helper(parse, world, config, modules)
63 |
64 |     head = mods[0] if isinstance(mods, tuple) else mods
65 |     if not isinstance(head, ExistsModule):
66 |         # wrap in translation module
67 |         # TODO bad naming
68 |         mapping = {i: ANSWER_INDEX[ent] for ent, i in world.entities.items()}
69 |         index = modules["answer_adaptor"].register(world.name, mapping)
70 |         mods = (modules["answer_adaptor"], mods)
71 |         indices = (index, indices)
72 |     return Layout(mods, indices)
73 |
74 | class GeoDatum(Datum):
75 |     def __init__(self, id, question, parses, layouts, answer, world):
76 |         Datum.__init__(self)
77 |         self.id = id
78 |         self.question = question
79 |         self.parses = parses
80 |         self.layouts = layouts
81 |         self.answers = [answer]
82 |         self.world = world
83 |
84 |     def load_features(self):
85 |         return self.world.entity_features
86 |
87 |     def load_rel_features(self):
88 |         return self.world.relation_features
89 |
90 | class GeoTask:
91 |     def __init__(self, config):
92 |         modules = {
93 |             "find": MultiplicativeFindModule(config.model),
94 |             "lookup": LookupModule(config.model),
95 |             "exists": ExistsModule(config.model),
96 |             "and": AndModule(config.model),
97 |             "relate": RelateModule(config.model),
98 |             "answer_adaptor": AnswerAdaptor(config.model)
99 |         }
100 |         self.train = GeoTaskSet(config.task, TRAIN, modules)
101 |         self.val = GeoTaskSet(config.task, VAL, modules)
102 |         self.test = GeoTaskSet(config.task, TEST, modules)
103 |
104 | class GeoTaskSet:
105 |     def __init__(self, config, set_name, modules):
106 |         if set_name == VAL:
107 |             self.data = []
108 |             return
109 |
110 |         questions = []
111 |         answers = []
112 |         parse_lists = []
113 |         worlds = []
114 |
115 |         if config.quant:
116 |             ANSWER_INDEX.index(YES)
117 |             ANSWER_INDEX.index(NO)
118 |
119 |         for i_env, environment in enumerate(ENVIRONMENTS):
120 |             if i_env == config.fold and set_name == TRAIN:
121 |                 continue
122 |             if i_env != config.fold and set_name == TEST:
123 |                 continue
124 |
125 |             places = list()
126 |             with open(LOCATION_FILE % environment) as loc_f:
127 |                 for line in loc_f:
128 |                     parts = line.strip().split(";")
129 |                     places.append(parts[0])
130 |
131 |             cats = {place: np.zeros((len(CATS),)) for place in places}
132 |             rels = {(pl1, pl2): np.zeros((len(RELS),)) for pl1 in places for pl2 in places}
133 |
134 |             with open(WORLD_FILE % environment) as world_f:
135 |                 for line in world_f:
136 |                     parts = line.strip().split(";")
137 |                     if len(parts) < 2:
138 |                         continue
139 |                     name = parts[0][1:]
140 |                     places_here = parts[1].split(",")
141 |                     if name in CATS:
142 |                         cat_id = CATS.index(name)
143 |                         for place in places_here:
144 |                             cats[place][cat_id] = 1
145 |                     elif name in RELS:
146 |                         rel_id = RELS.index(name)
147 |                         for place_pair in places_here:
148 |                             pl1, pl2 = place_pair.split("#")
149 |                             rels[pl1, pl2][rel_id] = 1
150 |                             rels[pl2, pl1][rel_id] = -1
151 |
152 |             clean_places = [p.lower().replace(" ", "_") for p in places]
153 |             place_index = {place: i for (i, place) in enumerate(places)}
154 |             clean_place_index = {place: i for (i, place) in enumerate(clean_places)}
155 |
156 |             cat_features = np.zeros((len(CATS), DATABASE_SIZE, 1))
157 |             rel_features = np.zeros((len(RELS), DATABASE_SIZE, DATABASE_SIZE))
158 |
159 |             for p1, i_p1 in place_index.items():
160 |                 cat_features[:, i_p1, 0] = cats[p1]
161 |                 for p2, i_p2 in place_index.items():
162 |                     rel_features[:, i_p1, i_p2] = rels[p1, p2]
163 |
164 |             world = World(environment, clean_place_index, cat_features, rel_features)
165 |
166 |             for place in clean_places:
167 |                 ANSWER_INDEX.index(place)
168 |
169 |             with open(DATA_FILE % environment) as data_f:
170 |                 for line in data_f:
171 |                     line = line.strip()
172 |                     if line == "" or line[0] == "#":
173 |                         continue
174 |
175 |                     parts = line.split(";")
176 |
177 |                     question = parts[0]
178 |                     if question[-1] != "?":
179 |                         question += " ?"
180 |                     question = question.lower()
181 |                     questions.append(question)
182 |
183 |                     answer = parts[1].lower().replace(" ", "_")
184 |                     if config.quant and question.split()[0] in ("is", "are"):
185 |                         answer = YES if answer else NO
186 |                     answers.append(answer)
187 |
188 |                     worlds.append(world)
189 |
190 |             with open(PARSE_FILE % environment) as parse_f:
191 |                 for line in parse_f:
192 |                     parse_strs = line.strip().split(";")
193 |                     trees = [parse_tree(s) for s in parse_strs]
194 |                     if not config.quant:
195 |                         trees = [t for t in trees if t[0] != "exists"]
196 |                     parse_lists.append(trees)
197 |
198 |         assert len(questions) == len(parse_lists)
199 |
200 |         data = []
201 |         i_datum = 0
202 |         for question, answer, parse_list, world in \
203 |                 zip(questions, answers, parse_lists, worlds):
204 |             tokens = ["<s>"] + question.split() + ["</s>"]
205 |
206 |             parse_list = parse_list[-config.k_best_parses:]
207 |
208 |             indexed_question = [QUESTION_INDEX.index(w) for w in tokens]
209 |             indexed_answer = \
210 |                 tuple(ANSWER_INDEX[a] for a in answer.split(",") if a != "")
211 |             assert all(a is not None for a in indexed_answer)
212 |             layouts = [parse_to_layout(p, world, config, modules) for p in parse_list]
213 |
214 |             data.append(GeoDatum(
215 |                 i_datum, indexed_question, parse_list, layouts, indexed_answer, world))
216 |             i_datum += 1
217 |
218 |         self.data = data
219 |
220 |         logging.info("%s:", set_name)
221 |         logging.info("%d items", len(self.data))
222 |         logging.info("%d words", len(QUESTION_INDEX))
223 |         logging.info("%d functions", len(MODULE_INDEX))
224 |         logging.info("%d answers", len(ANSWER_INDEX))
225 |
--------------------------------------------------------------------------------
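To make the recursive conversion in geo.py concrete: parse_to_layout_helper walks an S-expression bottom-up and produces parallel trees of module objects and parameter indices. A toy re-run of the same logic, with string stand-ins for the module classes and a fake entity table (all names and values here are hypothetical; the real code also maps non-entity leaves through MODULE_INDEX):

def to_layout(parse, entities):
    if isinstance(parse, str):
        if parse in entities:
            return "Lookup", entities[parse]
        return "Find", parse
    below = [to_layout(c, entities) for c in parse[1:]]
    mods, labels = zip(*below)
    head_mod = {"and": "And", "exists": "Exists"}.get(parse[0], "Relate")
    return (head_mod,) + mods, (parse[0],) + labels

entities = {"florida": 3}
print to_layout(("exists", ("and", "city", ("in-rel", "florida"))), entities)
# -> (('Exists', ('And', 'Find', ('Relate', 'Lookup'))),
#     ('exists', ('and', 'city', ('in-rel', 3))))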
").split() 50 | for part in parts: 51 | pred_counts[part] += 1 52 | for pred, count in pred_counts.items(): 53 | if count >= 10 * MIN_COUNT: 54 | MODULE_INDEX.index(pred) 55 | 56 | answer_counts = defaultdict(lambda: 0) 57 | with open(ANN_FILE % set_name) as ann_f: 58 | annotations = json.load(ann_f)["annotations"] 59 | for ann in annotations: 60 | for answer in ann["answers"]: 61 | if answer["answer_confidence"] != "yes": 62 | continue 63 | word = answer["answer"] 64 | if re.search(r"[^\w\s]", word): 65 | continue 66 | answer_counts[word] += 1 67 | 68 | keep_answers = reversed(sorted([(c, a) for a, c in answer_counts.items()])) 69 | keep_answers = list(keep_answers)[:config.answers] 70 | for count, answer in keep_answers: 71 | ANSWER_INDEX.index(answer) 72 | 73 | def compute_normalizers(config): 74 | mean = np.zeros((512,)) 75 | mmt2 = np.zeros((512,)) 76 | count = 0 77 | with open(QUESTION_FILE % "train2014") as question_f: 78 | questions = json.load(question_f)["questions"] 79 | image_ids = [q["image_id"] for q in questions] 80 | if hasattr(config, "debug"): 81 | image_ids = image_ids[:config.debug] 82 | for image_id in image_ids: 83 | with np.load(IMAGE_FILE % ("train2014", "train2014", image_id)) as zdata: 84 | assert len(zdata.keys()) == 1 85 | image_data = zdata[zdata.keys()[0]] 86 | sq_image_data = np.square(image_data) 87 | mean += np.sum(image_data, axis=(1,2)) 88 | mmt2 += np.sum(sq_image_data, axis=(1,2)) 89 | count += image_data.shape[1] * image_data.shape[2] 90 | mean /= count 91 | mmt2 /= count 92 | var = mmt2 - np.square(mean) 93 | std = np.sqrt(var) 94 | 95 | return mean, std 96 | 97 | def parse_to_layout_helper(parse, config, modules): 98 | if isinstance(parse, str): 99 | return modules["find"], MODULE_INDEX[parse] or UNK_ID 100 | head = parse[0] 101 | below = [parse_to_layout_helper(c, config, modules) for c in parse[1:]] 102 | modules_below, labels_below = zip(*below) 103 | modules_below = tuple(modules_below) 104 | labels_below = tuple(labels_below) 105 | if head == "and": 106 | module_head = modules["and"] 107 | else: 108 | module_head = modules["describe"] 109 | label_head = MODULE_INDEX[head] or UNK_ID 110 | modules_here = (module_head,) + modules_below 111 | labels_here = (label_head,) + labels_below 112 | return modules_here, labels_here 113 | 114 | def parse_to_layout(parse, config, modules): 115 | modules, indices = parse_to_layout_helper(parse, config, modules) 116 | return Layout(modules, indices) 117 | 118 | class VqaDatum(Datum): 119 | def __init__(self, id, question, parses, layouts, input_set, input_id, answers, mean, std): 120 | Datum.__init__(self) 121 | self.id = id 122 | self.question = question 123 | self.parses = parses 124 | self.layouts = layouts 125 | self.input_set = input_set 126 | self.input_id = input_id 127 | self.answers = answers 128 | 129 | self.input_path = IMAGE_FILE % (self.input_set, self.input_set, self.input_id) 130 | self.image_path = RAW_IMAGE_FILE % (self.input_set, self.input_set, self.input_id) 131 | 132 | self.mean = mean[:,np.newaxis,np.newaxis] 133 | self.std = std[:,np.newaxis,np.newaxis] 134 | 135 | if not os.path.exists(self.input_path): 136 | raise IOError("No such processed image: " + self.input_path) 137 | if not os.path.exists(self.input_path): 138 | raise IOError("No such source image: " + self.image_paht) 139 | 140 | def load_features(self): 141 | with np.load(self.input_path) as zdata: 142 | assert len(zdata.keys()) == 1 143 | image_data = zdata[zdata.keys()[0]] 144 | image_data -= self.mean 145 | image_data /= 
97 | def parse_to_layout_helper(parse, config, modules):
98 |     if isinstance(parse, str):
99 |         return modules["find"], MODULE_INDEX[parse] or UNK_ID
100 |     head = parse[0]
101 |     below = [parse_to_layout_helper(c, config, modules) for c in parse[1:]]
102 |     modules_below, labels_below = zip(*below)
103 |     modules_below = tuple(modules_below)
104 |     labels_below = tuple(labels_below)
105 |     if head == "and":
106 |         module_head = modules["and"]
107 |     else:
108 |         module_head = modules["describe"]
109 |     label_head = MODULE_INDEX[head] or UNK_ID
110 |     modules_here = (module_head,) + modules_below
111 |     labels_here = (label_head,) + labels_below
112 |     return modules_here, labels_here
113 |
114 | def parse_to_layout(parse, config, modules):
115 |     mods, indices = parse_to_layout_helper(parse, config, modules)
116 |     return Layout(mods, indices)
117 |
118 | class VqaDatum(Datum):
119 |     def __init__(self, id, question, parses, layouts, input_set, input_id, answers, mean, std):
120 |         Datum.__init__(self)
121 |         self.id = id
122 |         self.question = question
123 |         self.parses = parses
124 |         self.layouts = layouts
125 |         self.input_set = input_set
126 |         self.input_id = input_id
127 |         self.answers = answers
128 |
129 |         self.input_path = IMAGE_FILE % (self.input_set, self.input_set, self.input_id)
130 |         self.image_path = RAW_IMAGE_FILE % (self.input_set, self.input_set, self.input_id)
131 |
132 |         self.mean = mean[:,np.newaxis,np.newaxis]
133 |         self.std = std[:,np.newaxis,np.newaxis]
134 |
135 |         if not os.path.exists(self.input_path):
136 |             raise IOError("No such processed image: " + self.input_path)
137 |         if not os.path.exists(self.image_path):
138 |             raise IOError("No such source image: " + self.image_path)
139 |
140 |     def load_features(self):
141 |         with np.load(self.input_path) as zdata:
142 |             assert len(zdata.keys()) == 1
143 |             image_data = zdata[zdata.keys()[0]]
144 |         image_data -= self.mean
145 |         image_data /= self.std
146 |         channels, width, height = image_data.shape
147 |         image_data = image_data.reshape((channels, width * height, 1))
148 |         return image_data
149 |
150 |     def load_rel_features(self):
151 |         return None
152 |
153 | class VqaTask:
154 |     def __init__(self, config):
155 |         prepare_indices(config.task)
156 |         logging.debug("prepared indices")
157 |
158 |         modules = {
159 |             "find": MLPFindModule(config.model),
160 |             "describe": DescribeModule(config.model),
161 |             "exists": ExistsModule(config.model),
162 |             "and": AndModule(config.model),
163 |         }
164 |
165 |         mean, std = compute_normalizers(config.task)
166 |         logging.debug("computed image feature normalizers")
167 |         logging.debug("using %s chooser", config.task.chooser)
168 |
169 |         self.train = VqaTaskSet(config.task, ["train2014", "val2014"], modules, mean, std)
170 |         self.val = VqaTaskSet(config.task, ["test-dev2015"], modules, mean, std)
171 |         self.test = VqaTaskSet(config.task, ["test2015"], modules, mean, std)
172 |
173 | class VqaTaskSet:
174 |     def __init__(self, config, set_names, modules, mean, std):
175 |         size = config.debug if hasattr(config, "debug") else None
176 |
177 |         self.by_id = dict()
178 |         self.by_layout_type = defaultdict(list)
179 |
180 |         for set_name in set_names:
181 |             self.load_set(config, set_name, size, modules, mean, std)
182 |
183 |         for datum in self.by_id.values():
184 |             self.by_layout_type[datum.layouts[0].modules].append(datum)
185 |             datum.layout = datum.layouts[0]
186 |
187 |         self.layout_types = self.by_layout_type.keys()
188 |         self.data = self.by_id.values()
189 |
190 |         logging.info("%s:", ", ".join(set_names).upper())
191 |         logging.info("%d items", len(self.by_id))
192 |         logging.info("%d answers", len(ANSWER_INDEX))
193 |         logging.info("%d predicates", len(MODULE_INDEX))
194 |         logging.info("%d words", len(QUESTION_INDEX))
195 |         #logging.info("%d layouts", len(self.layout_types))
196 |         logging.info("")
197 |
198 |     def load_set(self, config, set_name, size, modules, mean, std):
199 |         parse_file = MULTI_PARSE_FILE
200 |         with open(QUESTION_FILE % set_name) as question_f, \
201 |                 open(parse_file % set_name) as parse_f:
202 |             questions = json.load(question_f)["questions"]
203 |             parse_groups = [l.strip() for l in parse_f]
204 |         assert len(questions) == len(parse_groups)
205 |         pairs = zip(questions, parse_groups)
206 |         if size is not None:
207 |             pairs = pairs[:size]
208 |         for question, parse_group in pairs:
209 |             id = question["question_id"]
210 |             question_str = proc_question(question["question"])
211 |             indexed_question = \
212 |                 [QUESTION_INDEX[w] or UNK_ID for w in question_str]
213 |
214 |             parse_strs = parse_group.split(";")
215 |             parses = [parse_tree(p) for p in parse_strs]
216 |             parses = [("_what", "_thing") if p == "none" else p for p in parses]
217 |             if config.chooser == "null":
218 |                 parses = [("_what", "_thing")]
219 |             elif config.chooser == "cvpr":
220 |                 if parses[0][0] == "is":
221 |                     parses = parses[-1:]
222 |                 else:
223 |                     parses = parses[:1]
224 |             elif config.chooser == "naacl":
225 |                 pass
226 |             else:
227 |                 assert False, "unknown chooser: %s" % config.chooser
228 |
229 |             layouts = [parse_to_layout(p, config, modules) for p in parses]
230 |             image_id = question["image_id"]
231 |             try:
232 |                 image_set_name = "test2015" if set_name == "test-dev2015" else set_name
233 |                 datum = VqaDatum(id, indexed_question, parses, layouts, image_set_name, image_id, [], mean, std)
234 |                 self.by_id[id] = datum
235 |             except IOError as e:
236 |                 print e
237 |
238 |         if set_name not in ("test2015", "test-dev2015"):
239 |             with open(ANN_FILE % set_name) as ann_f:
240 |                 annotations = json.load(ann_f)["annotations"]
241 |             for ann in annotations:
242 |                 question_id = ann["question_id"]
243 |                 if question_id not in self.by_id:
244 |                     continue
245 |
246 |                 answers = [a["answer"] for a in ann["answers"]]
247 |                 indexed_answers = [ANSWER_INDEX[a] or UNK_ID for a in answers]
248 |                 self.by_id[question_id].answers = indexed_answers
249 |
--------------------------------------------------------------------------------
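A closing note on the VQA answer vocabulary: prepare_indices keeps only the config.answers most frequent confident, punctuation-free answers, and anything outside that set falls back to UNK_ID when annotations are indexed above. A toy version of that selection rule (hypothetical counts):

answer_counts = {"yes": 900, "no": 850, "2": 300, "red": 120}
top_k = 2
keep = sorted(answer_counts.items(), key=lambda pair: -pair[1])[:top_k]
print keep  # [('yes', 900), ('no', 850)]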