├── LICENSE ├── config.py ├── configs ├── args.txt ├── args1.txt ├── args2.txt ├── args3.txt └── args4.txt ├── extract_features.py ├── mac_cell.py ├── main.py ├── mi_gru_cell.py ├── mi_lstm_cell.py ├── model.py ├── ops.py ├── preprocess.py ├── program_translator.py ├── readme.md ├── requirements.txt └── visualization.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | ###################################### configuration ###################################### 5 | class Config(object): 6 | 7 | typeFilters = [[], ["1_query_size_", 8 | "1_query_material_", 9 | "2_equal_color_", 10 | "2_equal_shape_"], 11 | ["1_query_color_", 12 | "1_query_shape_", 13 | "2_equal_size_", 14 | "2_equal_material_"]] 15 | 16 | #### files interface 17 | ## data files 18 | dataPath = "" # dataset specific 19 | datasetFilename = "" # dataset specific 20 | 21 | # file names 22 | imagesFilename = "{tier}.h5" # Images 23 | instancesFilename = "{tier}Instances.json" 24 | # symbols dictionaries 25 | questionDictFilename = "questionDict.pkl" 26 | answerDictFilename = "answerDict.pkl" 27 | qaDictFilename = "qaDict.pkl" 28 | 29 | ## experiment files 30 | expPathname = "{expName}" 31 | expName = "" # will be assigned through argparse 32 | 33 | weightsPath = "./weights" 34 | weightsFilename = "weights{epoch}.ckpt" 35 | 36 | # model predictions and optionally attention maps 37 | predsPath = "./preds" 38 | predsFilename = "{tier}Predictions-{expName}.json" 39 | answersFilename = "{tier}Answers-{expName}.txt" 40 | 41 | # logging of accuracy, loss etc. per epoch 42 | logPath = "./results" 43 | logFilename = "results-{expName}.csv" 44 | 45 | # configuration file of the used flags to run the experiment 46 | configPath = "./results" 47 | configFilename = "config-{expName}.json" 48 | 49 | def toString(self): 50 | return self.expName 51 | 52 | # make directories of experiment if not exist yet 53 | def makedirs(self, directory): 54 | directory = os.path.join(directory, self.expPath()) 55 | if not os.path.exists(directory): 56 | os.makedirs(directory) 57 | return directory 58 | 59 | ### filename builders 60 | ## data files 61 | def dataFile(self, filename): 62 | return os.path.join(self.dataPath, filename) 63 | 64 | def generatedFile(self, filename): 65 | return self.dataFile(self.generatedPrefix + filename) 66 | 67 | datasetFile = lambda self, tier: self.dataFile(self.datasetFilename.format(tier = tier)) 68 | imagesIdsFile = lambda self, tier: self.dataFile(self.imgIdsFilename.format(tier = tier)) # 69 | imagesFile = lambda self, tier: self.dataFile(self.imagesFilename.format(tier = tier)) 70 | instancesFile = lambda self, tier: self.generatedFile(self.instancesFilename.format(tier = tier)) 71 | 72 | questionDictFile = lambda self: self.generatedFile(self.questionDictFilename) 73 | answerDictFile = lambda self: self.generatedFile(self.answerDictFilename) 74 | qaDictFile = lambda self: self.generatedFile(self.qaDictFilename) 75 | 76 | ## experiment files 77 | expPath = lambda self: self.expPathname.format(expName = self.toString()) 78 | 79 | weightsDir = lambda self: self.makedirs(self.weightsPath) 80 | predsDir = lambda self: self.makedirs(self.predsPath) 81 | logDir = lambda self: self.makedirs(self.logPath) 82 | configDir = lambda self: self.makedirs(self.configPath) 83 | 84 | weightsFile = lambda self, epoch: os.path.join(self.weightsDir(), self.weightsFilename.format(epoch = str(epoch))) 85 | predsFile = lambda self, tier: os.path.join(self.predsDir(), self.predsFilename.format(tier = tier, expName = self.expName)) 86 | answersFile = lambda self, tier: os.path.join(self.predsDir(), self.answersFilename.format(tier = tier, expName = self.expName)) 87 | logFile = lambda self: 
os.path.join(self.logDir(), self.logFilename.format(expName = self.expName)) 88 | configFile = lambda self: os.path.join(self.configDir(), self.configFilename.format(expName = self.expName)) 89 | 90 | 91 | # global configuration variable. Holds file paths and program parameters 92 | config = Config() 93 | 94 | ###################################### arguments ###################################### 95 | def parseArgs(): 96 | parser = argparse.ArgumentParser(fromfile_prefix_chars = "@") 97 | 98 | ################ systems 99 | 100 | # gpus and memory 101 | parser.add_argument("--gpus", default = "", type = str, help = "comma-separated list of gpus to use") 102 | parser.add_argument("--gpusNum", default = 1, type = int, help = "number of gpus to use") 103 | 104 | parser.add_argument("--allowGrowth", action = "store_true", help = "allow gpu memory growth") 105 | parser.add_argument("--maxMemory", default = 1.0, type = float, help = "set maximum gpu memory usage") 106 | 107 | parser.add_argument("--parallel", action = "store_true", help = "load images in parallel to batch running") 108 | parser.add_argument("--workers", default = 1, type = int, help = "number of workers to load images") 109 | parser.add_argument("--taskSize", default = 8, type = int, help = "number of image batches to load in advance") # 40 110 | # parser.add_argument("--tasksNum", default = 20, type = int, help = "maximal queue size for tasks (to constrain ram usage)") # 2 111 | 112 | parser.add_argument("--useCPU", action = "store_true", help = "put word embeddings on cpu") 113 | 114 | # weight loading and training 115 | parser.add_argument("-r", "--restore", action = "store_true", help = "restore last epoch (based on results file)") 116 | parser.add_argument("--restoreEpoch", default = 0, type = int, help = "if positive, specific epoch to restore") 117 | parser.add_argument("--weightsToKeep", default = 2, type = int, help = "number of previous epochs' weights keep") 118 | parser.add_argument("--saveEvery", default = 3000, type = int, help = "number of iterations to save weights after") 119 | parser.add_argument("--calleEvery", default = 1500, type = int, help = "number of iterations to call custom function after") 120 | 121 | parser.add_argument("--saveSubset", action = "store_true", help = "save only subset of the weights") 122 | parser.add_argument("--trainSubset", action = "store_true", help = "train only subset of the weights") 123 | parser.add_argument("--varSubset", default = [], nargs = "*", type = str, help = "list of namespaces to train on") 124 | 125 | # trainReader = ["questionEmbeddings", "questionReader"] 126 | # saveControl = ["questionEmbeddings", "programEmbeddings", "seqReader", "programControl"] 127 | 128 | # experiment files 129 | parser.add_argument("--expName", default = "experiment", type = str, help = "experiment name") 130 | 131 | # data files 132 | parser.add_argument("--dataset", default = "CLEVR", choices = ["CLEVR", "NLVR"], type = str) # 133 | parser.add_argument("--dataBasedir", default = "./", type = str, help = "data base directory") # /jagupard14/scr1/dorarad/ 134 | parser.add_argument("--generatedPrefix", default = "gennew", type = str, help = "prefix for generated data files") 135 | parser.add_argument("--featureType", default = "norm_128x32", type = str, help = "features type") # 136 | # resnet101_512x128, norm_400x100, none_80x20, normPerImage_80x20, norm_80x20 137 | 138 | ################ optimization 139 | 140 | # training/testing 141 | parser.add_argument("--train", action = 
"store_true", help = "run training") 142 | parser.add_argument("--evalTrain", action = "store_true", help = "run eval with ema on train dataset") # 143 | parser.add_argument("--test", action = "store_true", help = "run testing every epoch and generate predictions file") # 144 | parser.add_argument("--finalTest", action = "store_true", help = "run testing on final epoch") 145 | parser.add_argument("--retainVal", action = "store_true", help = "retain validation order between runs") # 146 | 147 | parser.add_argument("--getPreds", action = "store_true", help = "store prediction") 148 | parser.add_argument("--getAtt", action = "store_true", help = "store attention maps") 149 | parser.add_argument("--analysisType", default = "", type = str, choices = ["", "questionLength, programLength","type", "arity"], help = "show breakdown of results according to type") # 150 | 151 | parser.add_argument("--trainedNum", default = 0, type = int, help = "if positive, train on subset of the data") 152 | parser.add_argument("--testedNum", default = 0, type = int, help = "if positive, test on subset of the data") 153 | 154 | # bucketing 155 | parser.add_argument("--noBucket", action = "store_true", help = "bucket data according to question length") 156 | parser.add_argument("--noRebucket", action = "store_true", help = "bucket data according to question and program length") # 157 | 158 | # filtering 159 | parser.add_argument("--tOnlyChain", action = "store_true", help = "train only chain questions") 160 | parser.add_argument("--vOnlyChain", action = "store_true", help = "test only chain questions") 161 | parser.add_argument("--tMaxQ", default = 0, type = int, help = "if positive, train on questions up to this length") 162 | parser.add_argument("--tMaxP", default = 0, type = int, help = "if positive, test on questions up to this length") 163 | parser.add_argument("--vMaxQ", default = 0, type = int, help = "if positive, train on questions with programs up to this length") 164 | parser.add_argument("--vMaxP", default = 0, type = int, help = "if positive, test on questions with programs up to this length") 165 | parser.add_argument("--tFilterOp", default = 0, type = int, help = "train questions by to be included in the types listed") 166 | parser.add_argument("--vFilterOp", default = 0, type = int, help = "test questions by to be included in the types listed") 167 | 168 | # extra and extraVal 169 | parser.add_argument("--extra", action = "store_true", help = "prepare extra data (add to vocabulary") # 170 | parser.add_argument("--trainExtra", action = "store_true", help = "train (only) on extra data") # 171 | parser.add_argument("--alterExtra", action = "store_true", help = "alter main data training with extra dataset") # 172 | parser.add_argument("--alterNum", default = 1, type = int, help = "alteration rate") # 173 | parser.add_argument("--extraVal", action = "store_true", help = "only extra validation data (for compositional clevr)") # 174 | parser.add_argument("--finetuneNum", default = 0, type = int, help = "if positive, finetune on that subset of val (for compositional clevr)") # 175 | 176 | # exponential moving average 177 | parser.add_argument("--useEMA", action = "store_true", help = "use exponential moving average for weights") 178 | parser.add_argument("--emaDecayRate", default = 0.999, type = float, help = "decay rate for exponential moving average") 179 | 180 | # sgd optimizer 181 | parser.add_argument("--batchSize", default = 64, type = int, help = "batch size") 182 | parser.add_argument("--epochs", 
default = 100, type = int, help = "number of epochs to run") 183 | parser.add_argument("--lr", default = 0.0001, type = float, help = "learning rate") 184 | parser.add_argument("--lrReduce", action = "store_true", help = "reduce learning rate if training loss doesn't go down (manual annealing)") 185 | parser.add_argument("--lrDecayRate", default = 0.5, type = float, help = "learning decay rate if training loss doesn't go down") 186 | parser.add_argument("--earlyStopping", default = 0, type = int, help = "if positive, stop if no improvement for that number of epochs") 187 | 188 | parser.add_argument("--adam", action = "store_true", help = "use adam") 189 | parser.add_argument("--l2", default = 0, type = float, help = "if positive, add l2 loss term") 190 | parser.add_argument("--clipGradients", action = "store_true", help = "clip gradients") 191 | parser.add_argument("--gradMaxNorm", default = 8, type = int, help = "clipping value") 192 | 193 | # batch normalization 194 | parser.add_argument("--memoryBN", action = "store_true", help = "use batch normalization on the recurrent memory") 195 | parser.add_argument("--stemBN", action = "store_true", help = "use batch normalization in the image input unit (stem)") 196 | parser.add_argument("--outputBN", action = "store_true", help = "use batch normalization in the output unit") 197 | parser.add_argument("--bnDecay", default = 0.999, type = float, help = "batch norm decay rate") 198 | parser.add_argument("--bnCenter", action = "store_true", help = "batch norm with centering") 199 | parser.add_argument("--bnScale", action = "store_true", help = "batch norm with scaling") 200 | 201 | ## dropouts 202 | parser.add_argument("--encInputDropout", default = 0.85, type = float, help = "dropout of the rnn inputs to the Question Input Unit") 203 | parser.add_argument("--encStateDropout", default = 1.0, type = float, help = "dropout of the rnn states of the Question Input Unit") 204 | parser.add_argument("--stemDropout", default = 0.82, type = float, help = "dropout of the Image Input Unit (the stem)") 205 | 206 | parser.add_argument("--qDropout", default = 0.92, type = float, help = "dropout on the question vector") 207 | # parser.add_argument("--qDropoutOut", default = 1.0, type = float, help = "dropout on the question vector the goes to the output unit") 208 | # parser.add_argument("--qDropoutMAC", default = 1.0, type = float, help = "dropout on the question vector the goes to MAC") 209 | 210 | parser.add_argument("--memoryDropout", default = 0.85, type = float, help = "dropout on the recurrent memory") 211 | parser.add_argument("--readDropout", default = 0.85, type = float, help = "dropout of the read unit") 212 | parser.add_argument("--writeDropout", default = 1.0, type = float, help = "dropout of the write unit") 213 | parser.add_argument("--outputDropout", default = 0.85, type = float, help = "dropout of the output unit") 214 | 215 | parser.add_argument("--parametricDropout", action = "store_true", help = "use parametric dropout") # 216 | parser.add_argument("--encVariationalDropout", action = "store_true", help = "use variational dropout in the RNN input unit") 217 | parser.add_argument("--memoryVariationalDropout", action = "store_true", help = "use variational dropout across the MAC network") 218 | 219 | ## nonlinearities 220 | parser.add_argument("--relu", default = "STD", choices = ["STD", "PRM", "ELU", "LKY", "SELU"], type = str, help = "type of ReLU to use: standard, parametric, ELU, or leaky") 221 | # parser.add_argument("--reluAlpha", default = 
0.2, type = float, help = "alpha value for the leaky ReLU") 222 | 223 | parser.add_argument("--mulBias", default = 0.0, type = float, help = "bias to add in multiplications (x + b) * (y + b) for better training") # 224 | 225 | parser.add_argument("--imageLinPool", default = 2, type = int, help = "pooling for image linearizion") 226 | 227 | ################ baseline model parameters 228 | 229 | parser.add_argument("--useBaseline", action = "store_true", help = "run the baseline model") 230 | parser.add_argument("--baselineLSTM", action = "store_true", help = "use LSTM in baseline") 231 | parser.add_argument("--baselineCNN", action = "store_true", help = "use CNN in baseline") 232 | parser.add_argument("--baselineAtt", action = "store_true", help = "use stacked attention baseline") 233 | 234 | parser.add_argument("--baselineProjDim", default = 64, type = int, help = "projection dimension for image linearizion") 235 | 236 | parser.add_argument("--baselineAttNumLayers", default = 2, type = int, help = "number of stacked attention layers") 237 | parser.add_argument("--baselineAttType", default = "ADD", type = str, choices = ["MUL", "DIAG", "BL", "ADD"], help = "attention type (multiplicative, additive, etc)") 238 | 239 | ################ image input unit (the "stem") 240 | 241 | parser.add_argument("--stemDim", default = 512, type = int, help = "dimension of stem CNNs") 242 | parser.add_argument("--stemNumLayers", default = 2, type = int, help = "number of stem layers") 243 | parser.add_argument("--stemKernelSize", default = 3, type = int, help = "kernel size for stem (same for all the stem layers)") 244 | parser.add_argument("--stemKernelSizes", default = None, nargs = "*", type = int, help = "kernel sizes for stem (per layer)") 245 | parser.add_argument("--stemStrideSizes", default = None, nargs = "*", type = int, help = "stride sizes for stem (per layer)") 246 | 247 | parser.add_argument("--stemLinear", action = "store_true", help = "use a linear stem (instead of CNNs)") # 248 | # parser.add_argument("--stemProjDim", default = 64, type = int, help = "projection dimension of in image linearization") # 249 | # parser.add_argument("--stemProjPooling", default = 2, type = int, help = "pooling for the image linearization") # 250 | 251 | parser.add_argument("--stemGridRnn", action = "store_true", help = "use grid RNN layer") # 252 | parser.add_argument("--stemGridRnnMod", default = "RNN", type = str, choices = ["RNN", "GRU"], help = "RNN type for grid") # 253 | parser.add_argument("--stemGridAct", default = "NON", type = str, choices = ["NON", "RELU", "TANH"], help = "nonlinearity type for grid") # 254 | 255 | ## location 256 | parser.add_argument("--locationAware", action = "store_true", help = "add positional features to image representation (linear meshgrid by default)") 257 | parser.add_argument("--locationType", default = "L", type = str, choices = ["L", "PE"], help = "L: linear features, PE: Positional Encoding") 258 | parser.add_argument("--locationBias", default = 1.0, type = float, help = "the scale of the positional features") 259 | parser.add_argument("--locationDim", default = 32, type = int, help = "the number of PE dimensions") 260 | 261 | ################ question input unit (the "encoder") 262 | parser.add_argument("--encType", default = "LSTM", choices = ["RNN", "GRU", "LSTM", "MiGRU", "MiLSTM"], help = "encoder RNN type") 263 | parser.add_argument("--encDim", default = 512, type = int, help = "dimension of encoder RNN") 264 | parser.add_argument("--encNumLayers", default = 1, 
type = int, help = "number of encoder RNN layers") 265 | parser.add_argument("--encBi", action = "store_true", help = "use bi-directional encoder") 266 | # parser.add_argument("--encOutProj", action = "store_true", help = "add projection layer for encoder outputs") 267 | # parser.add_argument("--encOutProjDim", default = 256, type = int, help = "dimension of the encoder projection layer") 268 | # parser.add_argument("--encQProj", action = "store_true", help = "add projection for the question representation") 269 | parser.add_argument("--encProj", action = "store_true", help = "project encoder outputs and question") 270 | parser.add_argument("--encProjQAct", default = "NON", type = str, choices = ["NON", "RELU", "TANH"], help = "project question vector with this activation") 271 | 272 | ##### word embeddings 273 | parser.add_argument("--wrdEmbDim", default = 300, type = int, help = "word embeddings dimension") 274 | parser.add_argument("--wrdEmbRandom", action = "store_true", help = "initialize word embeddings to random (normal)") 275 | parser.add_argument("--wrdEmbUniform", action = "store_true", help = "initialize with uniform distribution") 276 | parser.add_argument("--wrdEmbScale", default = 1.0, type = float, help = "word embeddings initialization scale") 277 | parser.add_argument("--wrdEmbFixed", action = "store_true", help = "set word embeddings fixed (don't train)") 278 | parser.add_argument("--wrdEmbUnknown", action = "store_true", help = "set words outside of training set to ") 279 | 280 | parser.add_argument("--ansEmbMod", default = "NON", choices = ["NON", "SHARED", "BOTH"], type = str, help = "BOTH: create word embeddings for answers. SHARED: share them with question embeddings.") # 281 | parser.add_argument("--answerMod", default = "NON", choices = ["NON", "MUL", "DIAG", "BL"], type = str, help = "operation for multiplication with answer embeddings: direct multiplication, scalar weighting, or bilinear") # 282 | 283 | ################ output unit (classifier) 284 | parser.add_argument("--outClassifierDims", default = [512], nargs = "*", type = int, help = "dimensions of the classifier") 285 | parser.add_argument("--outImage", action = "store_true", help = "feed the image to the output unit") 286 | parser.add_argument("--outImageDim", default = 1024, type = int, help = "dimension of linearized image fed to the output unit") 287 | parser.add_argument("--outQuestion", action = "store_true", help = "feed the question to the output unit") 288 | parser.add_argument("--outQuestionMul", action = "store_true", help = "feed the multiplication of question and memory to the output unit") 289 | 290 | ################ network 291 | 292 | parser.add_argument("--netLength", default = 16, type = int, help = "network length (number of cells)") 293 | # parser.add_argument("--netDim", default = 512, type = int) 294 | parser.add_argument("--memDim", default = 512, type = int, help = "dimension of memory state") 295 | parser.add_argument("--ctrlDim", default = 512, type = int, help = "dimension of control state") 296 | parser.add_argument("--attDim", default = 512, type = int, help = "dimension of pre-attention interactions space") 297 | parser.add_argument("--unsharedCells", default = False, type = bool, help = "unshare weights between cells ") 298 | 299 | # initialization 300 | parser.add_argument("--initCtrl", default = "PRM", type = str, choices = ["PRM", "ZERO", "Q"], help = "initialization mod for control") 301 | parser.add_argument("--initMem", default = "PRM", type = str, choices = ["PRM", 
"ZERO", "Q"], help = "initialization mod for memory") 302 | parser.add_argument("--initKBwithQ", default = "NON", type = str, choices = ["NON", "CNCT", "MUL"], help = "merge question with knowledge base") 303 | parser.add_argument("--addNullWord", action = "store_true", help = "add parametric word in the beginning of the question") 304 | 305 | ################ control unit 306 | # control ablations (use whole question or pre-attention continuous vectors as control) 307 | parser.add_argument("--controlWholeQ", action = "store_true", help = "use whole question vector as control") 308 | parser.add_argument("--controlContinuous", action = "store_true", help = "use continuous representation of control (without attention)") 309 | 310 | # step 0: inputs to control unit (word embeddings or encoder outputs, with optional projection) 311 | parser.add_argument("--controlContextual", action = "store_true", help = "use contextual words for attention (otherwise will use word embeddings)") 312 | parser.add_argument("--controlInWordsProj", action = "store_true", help = "apply linear projection over words for attention computation") 313 | parser.add_argument("--controlOutWordsProj", action = "store_true", help = "apply linear projection over words for summary computation") 314 | 315 | parser.add_argument("--controlInputUnshared", action = "store_true", help = "use different question representation for each cell") 316 | parser.add_argument("--controlInputAct", default = "TANH", type = str, choices = ["NON", "RELU", "TANH"], help = "activation for question projection") 317 | 318 | # step 1: merging previous control and whole question 319 | parser.add_argument("--controlFeedPrev", action = "store_true", help = "feed previous control state") 320 | parser.add_argument("--controlFeedPrevAtt", action = "store_true", help = "feed previous control post word attention (otherwise will feed continuous control)") 321 | parser.add_argument("--controlFeedInputs", action = "store_true", help = "feed question representation") 322 | parser.add_argument("--controlContAct", default = "NON", type = str, choices = ["NON", "RELU", "TANH"], help = "activation on the words interactions") 323 | 324 | # step 2: word attention and optional projection 325 | parser.add_argument("--controlConcatWords", action = "store_true", help = "concatenate words to interaction when computing attention") 326 | parser.add_argument("--controlProj", action = "store_true", help = "apply linear projection on words interactions") 327 | parser.add_argument("--controlProjAct", default = "NON", type = str, choices = ["NON", "RELU", "TANH"], help = "activation for control interactions") 328 | 329 | # parser.add_argument("--controlSelfAtt", default = False, type = bool) 330 | 331 | # parser.add_argument("--controlCoverage", default = False, type = bool) 332 | # parser.add_argument("--controlCoverageBias", default = 1.0, type = float) 333 | 334 | # parser.add_argument("--controlPostRNN", default = False, type = bool) 335 | # parser.add_argument("--controlPostRNNmod", default = "RNN", type = str) # GRU 336 | 337 | # parser.add_argument("--selfAttShareInter", default = False, type = bool) 338 | 339 | # parser.add_argument("--wordControl", default = False, type = bool) 340 | # parser.add_argument("--gradualControl", default = False, type = bool) 341 | 342 | ################ read unit 343 | # step 1: KB-memory interactions 344 | parser.add_argument("--readProjInputs", action = "store_true", help = "project read unit inputs") 345 | 
parser.add_argument("--readProjShared", action = "store_true", help = "use shared projection for all read unit inputs") 346 | 347 | parser.add_argument("--readMemAttType", default = "MUL", type = str, choices = ["MUL", "DIAG", "BL", "ADD"], help = "attention type for interaction with memory") 348 | parser.add_argument("--readMemConcatKB", action = "store_true", help = "concatenate KB elements to memory interaction") 349 | parser.add_argument("--readMemConcatProj", action = "store_true", help = "concatenate projected values instead or original to memory interaction") 350 | parser.add_argument("--readMemProj", action = "store_true", help = "project interactions with memory") 351 | parser.add_argument("--readMemAct", default = "RELU", type = str, choices = ["NON", "RELU", "TANH"], help = "activation for memory interaction") 352 | 353 | # step 2: interaction with control 354 | parser.add_argument("--readCtrl", action = "store_true", help = "compare KB-memory interactions to control") 355 | parser.add_argument("--readCtrlAttType", default = "MUL", type = str, choices = ["MUL", "DIAG", "BL", "ADD"], help = "attention type for interaction with control") 356 | parser.add_argument("--readCtrlConcatKB", action = "store_true", help = "concatenate KB elements to control interaction") 357 | parser.add_argument("--readCtrlConcatProj", action = "store_true", help = "concatenate projected values instead or original to control interaction") 358 | parser.add_argument("--readCtrlConcatInter", action = "store_true", help = "concatenate memory interactions to control interactions") 359 | parser.add_argument("--readCtrlAct", default = "RELU", type = str, choices = ["NON", "RELU", "TANH"], help = "activation for control interaction") 360 | 361 | # step 3: summarize attention over knowledge base 362 | parser.add_argument("--readSmryKBProj", action = "store_true", help = "use knowledge base projections when summing attention up (should be used only if KB is projected.") 363 | 364 | # parser.add_argument("--saAllMultiplicative", default = False, type = bool) 365 | # parser.add_argument("--saSumMultiplicative", default = False, type = bool) 366 | 367 | ################ write unit 368 | # step 1: input to the write unit (only previous memory, or new information, or both) 369 | parser.add_argument("--writeInputs", default = "BOTH", type = str, choices = ["MEM", "INFO", "BOTH", "SUM"], help = "inputs to the write unit") 370 | parser.add_argument("--writeConcatMul", action = "store_true", help = "add multiplicative integration between inputs") 371 | 372 | parser.add_argument("--writeInfoProj", action = "store_true", help = "project retrieved info") 373 | parser.add_argument("--writeInfoAct", default = "NON", type = str, choices = ["NON", "RELU", "TANH"], help = "new info activation") 374 | 375 | # step 2: self attention and following projection 376 | parser.add_argument("--writeSelfAtt", action = "store_true", help = "use self attention") 377 | parser.add_argument("--writeSelfAttMod", default = "NON", type = str, choices = ["NON", "CONT"], help = "control version to compare to") 378 | 379 | parser.add_argument("--writeMergeCtrl", action = "store_true", help = "merge control with memory") 380 | 381 | parser.add_argument("--writeMemProj", action = "store_true", help = "project new memory") 382 | parser.add_argument("--writeMemAct", default = "NON", type = str, choices = ["NON", "RELU", "TANH"], help = "new memory activation") 383 | 384 | # step 3: gate between new memory and previous value 385 | 
parser.add_argument("--writeGate", action = "store_true", help = "add gate to write unit") 386 | parser.add_argument("--writeGateShared", action = "store_true", help = "use one gate value for all dimensions of the memory state") 387 | parser.add_argument("--writeGateBias", default = 1.0, type = float, help = "bias for the write unit gate (positive to bias for taking new memory)") 388 | 389 | ## modular 390 | # parser.add_argument("--modulesNum", default = 10, type = int) 391 | # parser.add_argument("--controlBoth", default = False, type = bool) 392 | # parser.add_argument("--addZeroModule", default = False, type = bool) 393 | # parser.add_argument("--endModule", default = False, type = bool) 394 | 395 | ## hybrid 396 | # parser.add_argument("--hybrid", default = False, type = bool, help = "hybrid attention cnn model") 397 | # parser.add_argument("--earlyHybrid", default = False, type = bool) 398 | # parser.add_argument("--lateHybrid", default = False, type = bool) 399 | 400 | ## autoencoders 401 | # parser.add_argument("--autoEncMem", action = "store_true", help = "add memory2control auto-encoder loss") 402 | # parser.add_argument("--autoEncMemW", default = 0.0001, type = float, help = "weight for auto-encoder loss") 403 | # parser.add_argument("--autoEncMemInputs", default = "INFO", type = str, choices = ["MEM", "INFO"], help = "inputs to auto-encoder") 404 | # parser.add_argument("--autoEncMemAct", default = "NON", type = str, choices = ["NON", "RELU", "TANH"], help = "activation type in the auto-encoder") 405 | # parser.add_argument("--autoEncMemLoss", default = "CONT", type = str, choices = ["CONT", "PROB", "SMRY"], help = "target for the auto-encoder loss") 406 | # parser.add_argument("--autoEncMemCnct", action = "store_true", help = "concat word attentions to auto-encoder features") 407 | 408 | # parser.add_argument("--autoEncCtrl", action = "store_true") 409 | # parser.add_argument("--autoEncCtrlW", default = 0.0001, type = float) 410 | # parser.add_argument("--autoEncCtrlGRU", action = "store_true") 411 | 412 | ## temperature 413 | # parser.add_argument("--temperature", default = 1.0, type = float, help = "temperature for modules softmax") # 414 | # parser.add_argument("--tempParametric", action = "store_true", help = "parametric temperature") # 415 | # parser.add_argument("--tempDynamic", action = "store_true", help = "dynamic temperature") # 416 | # parser.add_argument("--tempAnnealRate", default = 0.000004, type = float, help = "temperature annealing rate") # 417 | # parser.add_argument("--tempMin", default = 0.5, type = float, help = "minimum temperature") # 418 | 419 | ## gumbel 420 | # parser.add_argument("--gumbelSoftmax", action = "store_true", help = "use gumbel for the module softmax (soft for training and hard for testing)") # 421 | # parser.add_argument("--gumbelSoftmaxBoth", action = "store_true", help = "use softmax for training and testing") # 422 | # parser.add_argument("--gumbelArgmaxBoth", action = "store_true", help = "use argmax for training and testing") # 423 | 424 | parser.parse_args(namespace = config) 425 | 426 | ###################################### dataset configuration ###################################### 427 | 428 | def configCLEVR(): 429 | config.dataPath = "{dataBasedir}/CLEVR_v1/data".format(dataBasedir = config.dataBasedir) 430 | config.datasetFilename = "CLEVR_{tier}_questions.json" 431 | config.wordVectorsFile = "./CLEVR_v1/data/glove/glove.6B.{dim}d.txt".format(dim = config.wrdEmbDim) # 432 | 433 | config.imageDims = [14, 14, 1024] 434 | 
config.programLims = [5, 10, 15, 20] 435 | config.questionLims = [10, 15, 20, 25] 436 | 437 | def configNLVR(): 438 | config.dataPath = "{dataBasedir}/nlvr".format(dataBasedir = config.dataBasedir) 439 | config.datasetFilename = "{tier}.json" 440 | config.imagesFilename = "{{tier}}_{featureType}.h5".format(featureType = config.featureType) 441 | config.imgIdsFilename = "{tier}ImgIds.json" 442 | config.wordVectorsFile = "./CLEVR_v1/data/glove/glove.6B.{dim}d.txt".format(dim = config.wrdEmbDim) # 443 | 444 | config.questionLims = [12] 445 | # config.noRebucket = True 446 | 447 | # if config.stemKernelSizes == []: 448 | # if config.featureType.endsWith("128x32"): 449 | # config.stemKernelSizes = [8, 4, 4] 450 | # config.stemStrideSizes = [2, 2, 1] 451 | # config.stemNumLayers = 3 452 | # if config.featureType.endsWith("512x128"): 453 | # config.stemKernelSizes = [8, 4, 4, 2] 454 | # config.stemStrideSizes = [4, 2, 2, 1] 455 | # config.stemNumLayers = 4 456 | # config.stemDim = 64 457 | 458 | if config.featureType == "resnet101_512x128": 459 | config.imageDims = [8, 32, 1024] 460 | else: 461 | stridesOverall = 1 462 | if config.stemStrideSizes is not None: 463 | for s in config.stemStrideSizes: 464 | stridesOverall *= int(s) 465 | size = config.featureType.split("_")[-1].split("x") 466 | config.imageDims = [int(size[1]) / stridesOverall, int(size[0]) / stridesOverall, 3] 467 | 468 | ## dataset specific configs 469 | loadDatasetConfig = { 470 | "CLEVR": configCLEVR, 471 | "NLVR": configNLVR 472 | } 473 | -------------------------------------------------------------------------------- /configs/args.txt: -------------------------------------------------------------------------------- 1 | --parallel 2 | --evalTrain 3 | --retainVal 4 | --useEMA 5 | --lrReduce 6 | --adam 7 | --clip 8 | --memoryVariationalDropout 9 | --relu=ELU 10 | --encBi 11 | --wrdEmbRandom 12 | --wrdEmbUniform 13 | --outQuestion 14 | --initCtrl=Q 15 | --controlContextual 16 | --controlInputUnshared 17 | --readProjInputs 18 | --readMemConcatKB 19 | --readMemConcatProj 20 | --readMemProj 21 | --readCtrl 22 | --writeMemProj -------------------------------------------------------------------------------- /configs/args1.txt: -------------------------------------------------------------------------------- 1 | --parallel 2 | --evalTrain 3 | --retainVal 4 | --useEMA 5 | --lrReduce 6 | --adam 7 | --clip 8 | --memoryVariationalDropout 9 | --relu=ELU 10 | --encBi 11 | --wrdEmbRandom 12 | --wrdEmbUniform 13 | --outQuestion 14 | --controlContextual 15 | --readProjInputs 16 | --readMemConcatKB 17 | --readMemConcatProj 18 | --readMemProj 19 | --readCtrl 20 | --writeMemProj 21 | --initCtrl=PRM 22 | --controlFeedPrev 23 | --controlFeedPrevAtt 24 | --controlFeedInputs 25 | --controlContAct=TANH 26 | -------------------------------------------------------------------------------- /configs/args2.txt: -------------------------------------------------------------------------------- 1 | --parallel 2 | --evalTrain 3 | --retainVal 4 | --useEMA 5 | --lrReduce 6 | --adam 7 | --clip 8 | --memoryVariationalDropout 9 | --relu=ELU 10 | --encBi 11 | --wrdEmbRandom 12 | --wrdEmbUniform 13 | --outQuestion 14 | --initCtrl=Q 15 | --controlContextual 16 | --controlInputUnshared 17 | --readProjInputs 18 | --readMemConcatKB 19 | --readMemConcatProj 20 | --readMemProj 21 | --readCtrl 22 | --writeMemProj 23 | --qDropout=0.85 24 | --stemDropout=0.85 25 | --noBucket 26 | --noRebucket 27 | --------------------------------------------------------------------------------
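The flag files under configs/ (args.txt, args1.txt, args2.txt above and args3.txt, args4.txt below) are consumed through argparse's file-expansion mechanism: parseArgs() in config.py builds the parser with fromfile_prefix_chars = "@" and parses into the global config object, so an argument of the form @configs/args.txt is expanded into one flag per line before parsing. A minimal, self-contained Python sketch of that behavior (the flags shown are a small subset of the real ones, and demoArgs.txt plus the example values are illustrative only, not the repository's documented command line):

import argparse

# Build a tiny parser the same way parseArgs() does, with "@" file expansion.
parser = argparse.ArgumentParser(fromfile_prefix_chars = "@")
parser.add_argument("--expName", default = "experiment", type = str)
parser.add_argument("--netLength", default = 16, type = int)
parser.add_argument("--encBi", action = "store_true")
parser.add_argument("--relu", default = "STD", type = str)

# A flag file in the same one-flag-per-line format as configs/args.txt.
with open("demoArgs.txt", "w") as f:
    f.write("--encBi\n--relu=ELU\n")

# Roughly equivalent to running: python main.py --expName=exp1 --netLength=8 @demoArgs.txt
args = parser.parse_args(["--expName=exp1", "--netLength=8", "@demoArgs.txt"])
print(args.expName, args.netLength, args.encBi, args.relu)   # exp1 8 True ELU

Flags given directly and flags read from the "@" file end up in a single namespace; in the real code that namespace is the config instance itself (parser.parse_args(namespace = config)), which is why the rest of the codebase reads options as config.<flag>.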
/configs/args3.txt: -------------------------------------------------------------------------------- 1 | --parallel 2 | --evalTrain 3 | --retainVal 4 | --useEMA 5 | --lrReduce 6 | --adam 7 | --clip 8 | --memoryVariationalDropout 9 | --relu=ELU 10 | --encBi 11 | --wrdEmbRandom 12 | --wrdEmbUniform 13 | --outQuestion 14 | --initCtrl=Q 15 | --controlContextual 16 | --controlInputUnshared 17 | --readProjInputs 18 | --readMemConcatKB 19 | --readMemConcatProj 20 | --readMemProj 21 | --readCtrl 22 | --writeMemProj 23 | --writeSelfAtt 24 | --writeSelfAttMod=CONT 25 | -------------------------------------------------------------------------------- /configs/args4.txt: -------------------------------------------------------------------------------- 1 | --parallel 2 | --evalTrain 3 | --retainVal 4 | --useEMA 5 | --lrReduce 6 | --adam 7 | --clip 8 | --memoryVariationalDropout 9 | --relu=ELU 10 | --encBi 11 | --wrdEmbRandom 12 | --wrdEmbUniform 13 | --outQuestion 14 | --initCtrl=Q 15 | --controlContextual 16 | --controlInputUnshared 17 | --readProjInputs 18 | --readMemConcatKB 19 | --readMemConcatProj 20 | --readMemProj 21 | --readCtrl 22 | --writeMemProj 23 | --writeGate 24 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse, os, json 8 | import h5py 9 | import numpy as np 10 | from scipy.misc import imread, imresize 11 | 12 | import torch 13 | import torchvision 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--input_image_dir', required=True) 18 | parser.add_argument('--max_images', default=None, type=int) 19 | parser.add_argument('--output_h5_file', required=True) 20 | 21 | parser.add_argument('--image_height', default=224, type=int) 22 | parser.add_argument('--image_width', default=224, type=int) 23 | 24 | parser.add_argument('--model', default='resnet101') 25 | parser.add_argument('--model_stage', default=3, type=int) 26 | parser.add_argument('--batch_size', default=128, type=int) 27 | 28 | 29 | def build_model(args): 30 | if not hasattr(torchvision.models, args.model): 31 | raise ValueError('Invalid model "%s"' % args.model) 32 | if not 'resnet' in args.model: 33 | raise ValueError('Feature extraction only supports ResNets') 34 | cnn = getattr(torchvision.models, args.model)(pretrained=True) 35 | layers = [ 36 | cnn.conv1, 37 | cnn.bn1, 38 | cnn.relu, 39 | cnn.maxpool, 40 | ] 41 | for i in range(args.model_stage): 42 | name = 'layer%d' % (i + 1) 43 | layers.append(getattr(cnn, name)) 44 | model = torch.nn.Sequential(*layers) 45 | model.cuda() 46 | model.eval() 47 | return model 48 | 49 | 50 | def run_batch(cur_batch, model): 51 | mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) 52 | std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1) 53 | 54 | image_batch = np.concatenate(cur_batch, 0).astype(np.float32) 55 | image_batch = (image_batch / 255.0 - mean) / std 56 | image_batch = torch.FloatTensor(image_batch).cuda() 57 | image_batch = torch.autograd.Variable(image_batch, volatile=True) 58 | 59 | feats = model(image_batch) 60 | feats = feats.data.cpu().clone().numpy() 61 | 62 | return feats 63 | 64 | 65 | def main(args): 66 | input_paths = [] 67 | idx_set = set() 68 | for fn in 
os.listdir(args.input_image_dir): 69 | if not fn.endswith('.png'): continue 70 | idx = int(os.path.splitext(fn)[0].split('_')[-1]) 71 | input_paths.append((os.path.join(args.input_image_dir, fn), idx)) 72 | idx_set.add(idx) 73 | input_paths.sort(key=lambda x: x[1]) 74 | assert len(idx_set) == len(input_paths) 75 | assert min(idx_set) == 0 and max(idx_set) == len(idx_set) - 1 76 | if args.max_images is not None: 77 | input_paths = input_paths[:args.max_images] 78 | print(input_paths[0]) 79 | print(input_paths[-1]) 80 | 81 | model = build_model(args) 82 | 83 | img_size = (args.image_height, args.image_width) 84 | with h5py.File(args.output_h5_file, 'w') as f: 85 | feat_dset = None 86 | i0 = 0 87 | cur_batch = [] 88 | for i, (path, idx) in enumerate(input_paths): 89 | img = imread(path, mode='RGB') 90 | img = imresize(img, img_size, interp='bicubic') 91 | img = img.transpose(2, 0, 1)[None] 92 | cur_batch.append(img) 93 | if len(cur_batch) == args.batch_size: 94 | feats = run_batch(cur_batch, model) 95 | if feat_dset is None: 96 | N = len(input_paths) 97 | _, C, H, W = feats.shape 98 | feat_dset = f.create_dataset('features', (N, C, H, W), 99 | dtype=np.float32) 100 | i1 = i0 + len(cur_batch) 101 | feat_dset[i0:i1] = feats 102 | i0 = i1 103 | print('Processed %d / %d images' % (i1, len(input_paths))) 104 | cur_batch = [] 105 | if len(cur_batch) > 0: 106 | feats = run_batch(cur_batch, model) 107 | i1 = i0 + len(cur_batch) 108 | feat_dset[i0:i1] = feats 109 | print('Processed %d / %d images' % (i1, len(input_paths))) 110 | 111 | 112 | if __name__ == '__main__': 113 | args = parser.parse_args() 114 | main(args) -------------------------------------------------------------------------------- /mac_cell.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import ops 6 | from config import config 7 | 8 | MACCellTuple = collections.namedtuple("MACCellTuple", ("control", "memory")) 9 | 10 | ''' 11 | The MAC cell. 12 | 13 | Recurrent cell for multi-step reasoning. Presented in https://arxiv.org/abs/1803.03067. 14 | The cell has recurrent control and memory states that interact with the question 15 | and knowledge base (image) respectively. 16 | 17 | The hidden state structure is MACCellTuple(control, memory) 18 | 19 | At each step the cell performs by calling to three subunits: control, read and write. 20 | 21 | 1. The Control Unit computes the control state by computing attention over the question words. 22 | The control state represents the current reasoning operation the cell performs. 23 | 24 | 2. The Read Unit retrieves information from the knowledge base, given the control and previous 25 | memory values, by computing 2-stages attention over the knowledge base. 26 | 27 | 3. The Write Unit integrates the retrieved information to the previous hidden memory state, 28 | given the value of the control state, to perform the current reasoning operation. 29 | ''' 30 | class MACCell(tf.nn.rnn_cell.RNNCell): 31 | 32 | '''Initialize the MAC cell. 33 | (Note that in the current version the cell is stateful -- 34 | updating its own internals when being called) 35 | 36 | Args: 37 | vecQuestions: the vector representation of the questions. 38 | [batchSize, ctrlDim] 39 | 40 | questionWords: the question words embeddings. 41 | [batchSize, questionLength, ctrlDim] 42 | 43 | questionCntxWords: the encoder outputs -- the "contextual" question words. 
44 | [batchSize, questionLength, ctrlDim] 45 | 46 | questionLengths: the length of each question. 47 | [batchSize] 48 | 49 | memoryDropout: dropout on the memory state (Tensor scalar). 50 | readDropout: dropout inside the read unit (Tensor scalar). 51 | writeDropout: dropout on the new information that gets into the write unit (Tensor scalar). 52 | 53 | batchSize: batch size (Tensor scalar). 54 | train: train or test mod (Tensor boolean). 55 | reuse: reuse cell 56 | 57 | knowledgeBase: 58 | ''' 59 | def __init__(self, vecQuestions, questionWords, questionCntxWords, questionLengths, 60 | knowledgeBase, memoryDropout, readDropout, writeDropout, 61 | batchSize, train, reuse = None): 62 | 63 | self.vecQuestions = vecQuestions 64 | self.questionWords = questionWords 65 | self.questionCntxWords = questionCntxWords 66 | self.questionLengths = questionLengths 67 | 68 | self.knowledgeBase = knowledgeBase 69 | 70 | self.dropouts = {} 71 | self.dropouts["memory"] = memoryDropout 72 | self.dropouts["read"] = readDropout 73 | self.dropouts["write"] = writeDropout 74 | 75 | self.none = tf.zeros((batchSize, 1), dtype = tf.float32) 76 | 77 | self.batchSize = batchSize 78 | self.train = train 79 | self.reuse = reuse 80 | 81 | ''' 82 | Cell state size. 83 | ''' 84 | @property 85 | def state_size(self): 86 | return MACCellTuple(config.ctrlDim, config.memDim) 87 | 88 | ''' 89 | Cell output size. Currently it doesn't have any outputs. 90 | ''' 91 | @property 92 | def output_size(self): 93 | return 1 94 | 95 | # pass encoder hidden states to control? 96 | ''' 97 | The Control Unit: computes the new control state -- the reasoning operation, 98 | by summing up the word embeddings according to a computed attention distribution. 99 | 100 | The unit is recurrent: it receives the whole question and the previous control state, 101 | merge them together (resulting in the "continuous control"), and then uses that 102 | to compute attentions over the question words. Finally, it combines the words 103 | together according to the attention distribution to get the new control state. 104 | 105 | Args: 106 | controlInput: external inputs to control unit (the question vector). 107 | [batchSize, ctrlDim] 108 | 109 | inWords: the representation of the words used to compute the attention. 110 | [batchSize, questionLength, ctrlDim] 111 | 112 | outWords: the representation of the words that are summed up. 113 | (by default inWords == outWords) 114 | [batchSize, questionLength, ctrlDim] 115 | 116 | questionLengths: the length of each question. 117 | [batchSize] 118 | 119 | control: the previous control hidden state value. 120 | [batchSize, ctrlDim] 121 | 122 | contControl: optional corresponding continuous control state 123 | (before casting the attention over the words). 124 | [batchSize, ctrlDim] 125 | 126 | Returns: 127 | the new control state 128 | [batchSize, ctrlDim] 129 | 130 | the continuous (pre-attention) control 131 | [batchSize, ctrlDim] 132 | ''' 133 | def control(self, controlInput, inWords, outWords, questionLengths, 134 | control, contControl = None, name = "", reuse = None): 135 | 136 | with tf.variable_scope("control" + name, reuse = reuse): 137 | dim = config.ctrlDim 138 | 139 | ## Step 1: compute "continuous" control state given previous control and question. 
140 | # control inputs: question and previous control 141 | newContControl = controlInput 142 | if config.controlFeedPrev: 143 | newContControl = control if config.controlFeedPrevAtt else contControl 144 | if config.controlFeedInputs: 145 | newContControl = tf.concat([newContControl, controlInput], axis = -1) 146 | dim += config.ctrlDim 147 | 148 | # merge inputs together 149 | newContControl = ops.linear(newContControl, dim, config.ctrlDim, 150 | act = config.controlContAct, name = "contControl") 151 | dim = config.ctrlDim 152 | 153 | ## Step 2: compute attention distribution over words and sum them up accordingly. 154 | # compute interactions with question words 155 | interactions = tf.expand_dims(newContControl, axis = 1) * inWords 156 | 157 | # optionally concatenate words 158 | if config.controlConcatWords: 159 | interactions = tf.concat([interactions, inWords], axis = -1) 160 | dim += config.ctrlDim 161 | 162 | # optional projection 163 | if config.controlProj: 164 | interactions = ops.linear(interactions, dim, config.ctrlDim, 165 | act = config.controlProjAct) 166 | dim = config.ctrlDim 167 | 168 | # compute attention distribution over words and summarize them accordingly 169 | logits = ops.inter2logits(interactions, dim) 170 | # self.interL = (interW, interb) 171 | 172 | # if config.controlCoverage: 173 | # logits += coverageBias * coverage 174 | 175 | attention = tf.nn.softmax(ops.expMask(logits, questionLengths)) 176 | self.attentions["question"].append(attention) 177 | 178 | # if config.controlCoverage: 179 | # coverage += attention # Add logits instead? 180 | 181 | newControl = ops.att2Smry(attention, outWords) 182 | 183 | # ablation: use continuous control (pre-attention) instead 184 | if config.controlContinuous: 185 | newControl = newContControl 186 | 187 | return newControl, newContControl 188 | 189 | ''' 190 | The read unit extracts relevant information from the knowledge base given the 191 | cell's memory and control states. It computes attention distribution over 192 | the knowledge base by comparing it first to the memory and then to the control. 193 | Finally, it uses the attention distribution to sum up the knowledge base accordingly, 194 | resulting in an extraction of relevant information. 195 | 196 | Args: 197 | knowledge base: representation of the knowledge base (image). 198 | [batchSize, kbSize (Height * Width), memDim] 199 | 200 | memory: the cell's memory state 201 | [batchSize, memDim] 202 | 203 | control: the cell's control state 204 | [batchSize, ctrlDim] 205 | 206 | Returns the information extracted. 
207 |         [batchSize, memDim]
208 |     '''
209 |     def read(self, knowledgeBase, memory, control, name = "", reuse = None):
210 |         with tf.variable_scope("read" + name, reuse = reuse):
211 |             dim = config.memDim
212 | 
213 |             ## memory dropout
214 |             if config.memoryVariationalDropout:
215 |                 memory = ops.applyVarDpMask(memory, self.memDpMask, self.dropouts["memory"])
216 |             else:
217 |                 memory = tf.nn.dropout(memory, self.dropouts["memory"])
218 | 
219 |             ## Step 1: knowledge base / memory interactions
220 |             # parameters for knowledge base and memory projection
221 |             proj = None
222 |             if config.readProjInputs:
223 |                 proj = {"dim": config.attDim, "shared": config.readProjShared, "dropout": self.dropouts["read"]}
224 |                 dim = config.attDim
225 | 
226 |             # parameters for concatenating knowledge base elements
227 |             concat = {"x": config.readMemConcatKB, "proj": config.readMemConcatProj}
228 | 
229 |             # compute interactions between knowledge base and memory
230 |             interactions, interDim = ops.mul(x = knowledgeBase, y = memory, dim = config.memDim,
231 |                 proj = proj, concat = concat, interMod = config.readMemAttType, name = "memInter")
232 | 
233 |             projectedKB = proj.get("x") if proj else None
234 | 
235 |             # project memory interactions back to hidden dimension
236 |             if config.readMemProj:
237 |                 interactions = ops.linear(interactions, interDim, dim, act = config.readMemAct,
238 |                     name = "memKbProj")
239 |             else:
240 |                 dim = interDim
241 | 
242 |             ## Step 2: compute interactions with control
243 |             if config.readCtrl:
244 |                 # project control to the interaction dimension if needed
245 |                 if config.ctrlDim != dim:
246 |                     control = ops.linear(control, config.ctrlDim, dim, name = "ctrlProj")
247 | 
248 |                 interactions, interDim = ops.mul(interactions, control, dim,
249 |                     interMod = config.readCtrlAttType, concat = {"x": config.readCtrlConcatInter},
250 |                     name = "ctrlInter")
251 | 
252 |                 # optionally concatenate knowledge base elements
253 |                 if config.readCtrlConcatKB:
254 |                     if config.readCtrlConcatProj:
255 |                         addedInp, addedDim = projectedKB, config.attDim
256 |                     else:
257 |                         addedInp, addedDim = knowledgeBase, config.memDim
258 |                     interactions = tf.concat([interactions, addedInp], axis = -1)
259 |                     dim += addedDim
260 | 
261 |                 # optional nonlinearity
262 |                 interactions = ops.activations[config.readCtrlAct](interactions)
263 | 
264 |             ## Step 3: sum attentions up over the knowledge base
265 |             # transform vectors to attention distribution
266 |             attention = ops.inter2att(interactions, dim, dropout = self.dropouts["read"])
267 | 
268 |             self.attentions["kb"].append(attention)
269 | 
270 |             # optionally use projected knowledge base instead of original
271 |             if config.readSmryKBProj:
272 |                 knowledgeBase = projectedKB
273 | 
274 |             # sum up the knowledge base according to the distribution
275 |             information = ops.att2Smry(attention, knowledgeBase)
276 | 
277 |             return information
278 | 
279 |     '''
280 |     The write unit integrates the newly retrieved information (from the read unit)
281 |     with the cell's previous hidden memory state, resulting in a new memory value.
282 |     The unit optionally supports:
283 |     1. Self-attention to previous control / memory states, in order to consider previous steps
284 |     in the reasoning process.
285 |     2. Gating between the new memory and previous memory states, to allow dynamic adjustment
286 |     of the reasoning process length (see the illustrative sketch below).
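    As an illustration (a sketch mirroring the gate logic implemented at the end of this
    method, not additional functionality): when config.writeGate is enabled, the final
    update is a convex combination of the candidate memory and the previous memory,
    with a sigmoid gate computed from the control state:

        z = sigmoid(linear(control))                     # [batchSize, memDim], or a single shared gate
        newMemory = z * newMemoryCandidate + (1 - z) * memory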
287 | 288 | Args: 289 | memory: the cell's memory state 290 | [batchSize, memDim] 291 | 292 | info: the information to integrate with the memory 293 | [batchSize, memDim] 294 | 295 | control: the cell's control state 296 | [batchSize, ctrlDim] 297 | 298 | contControl: optional corresponding continuous control state 299 | (before casting the attention over the words). 300 | [batchSize, ctrlDim] 301 | 302 | Return the new memory 303 | [batchSize, memDim] 304 | ''' 305 | def write(self, memory, info, control, contControl = None, name = "", reuse = None): 306 | with tf.variable_scope("write" + name, reuse = reuse): 307 | 308 | # optionally project info 309 | if config.writeInfoProj: 310 | info = ops.linear(info, config.memDim, config.memDim, name = "info") 311 | 312 | # optional info nonlinearity 313 | info = ops.activations[config.writeInfoAct](info) 314 | 315 | # compute self-attention vector based on previous controls and memories 316 | if config.writeSelfAtt: 317 | selfControl = control 318 | if config.writeSelfAttMod == "CONT": 319 | selfControl = contControl 320 | # elif config.writeSelfAttMod == "POST": 321 | # selfControl = postControl 322 | selfControl = ops.linear(selfControl, config.ctrlDim, config.ctrlDim, name = "ctrlProj") 323 | 324 | interactions = self.controls * tf.expand_dims(selfControl, axis = 1) 325 | 326 | # if config.selfAttShareInter: 327 | # selfAttlogits = self.linearP(selfAttInter, config.encDim, 1, self.interL[0], self.interL[1], name = "modSelfAttInter") 328 | attention = ops.inter2att(interactions, config.ctrlDim, name = "selfAttention") 329 | self.attentions["self"].append(attention) 330 | selfSmry = ops.att2Smry(attention, self.memories) 331 | 332 | # get write unit inputs: previous memory, the new info, optionally self-attention / control 333 | newMemory, dim = memory, config.memDim 334 | if config.writeInputs == "INFO": 335 | newMemory = info 336 | elif config.writeInputs == "SUM": 337 | newMemory += info 338 | elif config.writeInputs == "BOTH": 339 | newMemory, dim = ops.concat(newMemory, info, dim, mul = config.writeConcatMul) 340 | # else: MEM 341 | 342 | if config.writeSelfAtt: 343 | newMemory = tf.concat([newMemory, selfSmry], axis = -1) 344 | dim += config.memDim 345 | 346 | if config.writeMergeCtrl: 347 | newMemory = tf.concat([newMemory, control], axis = -1) 348 | dim += config.memDim 349 | 350 | # project memory back to memory dimension 351 | if config.writeMemProj or (dim != config.memDim): 352 | newMemory = ops.linear(newMemory, dim, config.memDim, name = "newMemory") 353 | 354 | # optional memory nonlinearity 355 | newMemory = ops.activations[config.writeMemAct](newMemory) 356 | 357 | # write unit gate 358 | if config.writeGate: 359 | gateDim = config.memDim 360 | if config.writeGateShared: 361 | gateDim = 1 362 | 363 | z = tf.sigmoid(ops.linear(control, config.ctrlDim, gateDim, name = "gate", bias = config.writeGateBias)) 364 | 365 | self.attentions["gate"].append(z) 366 | 367 | newMemory = newMemory * z + memory * (1 - z) 368 | 369 | # optional batch normalization 370 | if config.memoryBN: 371 | newMemory = tf.contrib.layers.batch_norm(newMemory, decay = config.bnDecay, 372 | center = config.bnCenter, scale = config.bnScale, 373 | is_training = self.train, updates_collections = None) 374 | 375 | return newMemory 376 | 377 | def memAutoEnc(newMemory, info, control, name = "", reuse = None): 378 | with tf.variable_scope("memAutoEnc" + name, reuse = reuse): 379 | # inputs to auto encoder 380 | features = info if config.autoEncMemInputs == "INFO" 
else newMemory 381 | features = ops.linear(features, config.memDim, config.ctrlDim, 382 | act = config.autoEncMemAct, name = "aeMem") 383 | 384 | # reconstruct control 385 | if config.autoEncMemLoss == "CONT": 386 | loss = tf.reduce_mean(tf.squared_difference(control, features)) 387 | else: 388 | interactions, dim = ops.mul(self.questionCntxWords, features, config.ctrlDim, 389 | concat = {"x": config.autoEncMemCnct}, mulBias = config.mulBias, name = "aeMem") 390 | 391 | logits = ops.inter2logits(interactions, dim) 392 | logits = self.expMask(logits, self.questionLengths) 393 | 394 | # reconstruct word attentions 395 | if config.autoEncMemLoss == "PROB": 396 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 397 | labels = self.attentions["question"][-1], logits = logits)) 398 | 399 | # reconstruct control through words attentions 400 | else: 401 | attention = tf.nn.softmax(logits) 402 | summary = ops.att2Smry(attention, self.questionCntxWords) 403 | loss = tf.reduce_mean(tf.squared_difference(control, summary)) 404 | 405 | return loss 406 | 407 | ''' 408 | Call the cell to get new control and memory states. 409 | 410 | Args: 411 | inputs: in the current implementation the cell don't get recurrent inputs 412 | every iteration (argument for comparability with rnn interface). 413 | 414 | state: the cell current state (control, memory) 415 | MACCellTuple([batchSize, ctrlDim],[batchSize, memDim]) 416 | 417 | Returns the new state -- the new memory and control values. 418 | MACCellTuple([batchSize, ctrlDim],[batchSize, memDim]) 419 | ''' 420 | def __call__(self, inputs, state, scope = None): 421 | scope = scope or type(self).__name__ 422 | with tf.variable_scope(scope, reuse = self.reuse): # as tfscope 423 | control = state.control 424 | memory = state.memory 425 | 426 | # cell sharing 427 | inputName = "qInput" 428 | inputNameU = "qInputU" 429 | inputReuseU = inputReuse = (self.iteration > 0) 430 | if config.controlInputUnshared: 431 | inputNameU = "qInput%d" % self.iteration 432 | inputReuseU = None 433 | 434 | cellName = "" 435 | cellReuse = (self.iteration > 0) 436 | if config.unsharedCells: 437 | cellName = str(self.iteration) 438 | cellReuse = None 439 | 440 | ## control unit 441 | # prepare question input to control 442 | controlInput = ops.linear(self.vecQuestions, config.ctrlDim, config.ctrlDim, 443 | name = inputName, reuse = inputReuse) 444 | 445 | controlInput = ops.activations[config.controlInputAct](controlInput) 446 | 447 | controlInput = ops.linear(controlInput, config.ctrlDim, config.ctrlDim, 448 | name = inputNameU, reuse = inputReuseU) 449 | 450 | newControl, self.contControl = self.control(controlInput, self.inWords, self.outWords, 451 | self.questionLengths, control, self.contControl, name = cellName, reuse = cellReuse) 452 | 453 | # read unit 454 | # ablation: use whole question as control 455 | if config.controlWholeQ: 456 | newControl = self.vecQuestions 457 | # ops.linear(self.vecQuestions, config.ctrlDim, projDim, name = "qMod") 458 | 459 | info = self.read(self.knowledgeBase, memory, newControl, name = cellName, reuse = cellReuse) 460 | 461 | if config.writeDropout < 1.0: 462 | # write unit 463 | info = tf.nn.dropout(info, self.dropouts["write"]) 464 | 465 | newMemory = self.write(memory, info, newControl, self.contControl, name = cellName, reuse = cellReuse) 466 | 467 | # add auto encoder loss for memory 468 | # if config.autoEncMem: 469 | # self.autoEncLosses["memory"] += memAutoEnc(newMemory, info, newControl) 470 | 471 | # append as standard list? 
472 |             self.controls = tf.concat([self.controls, tf.expand_dims(newControl, axis = 1)], axis = 1)
473 |             self.memories = tf.concat([self.memories, tf.expand_dims(newMemory, axis = 1)], axis = 1)
474 |             self.infos = tf.concat([self.infos, tf.expand_dims(info, axis = 1)], axis = 1)
475 | 
476 |             # self.contControls = tf.concat([self.contControls, tf.expand_dims(contControl, axis = 1)], axis = 1)
477 |             # self.postControls = tf.concat([self.controls, tf.expand_dims(postControls, axis = 1)], axis = 1)
478 | 
479 |             newState = MACCellTuple(newControl, newMemory)
480 |             return self.none, newState
481 | 
482 |     '''
483 |     Initializes a hidden state based on the value of initType:
484 |     "PRM" for parametric initialization
485 |     "ZERO" for zero initialization
486 |     "Q" to initialize to the question vectors.
487 | 
488 |     Args:
489 |         name: the state variable name.
490 |         dim: the dimension of the state.
491 |         initType: the type of the initialization.
492 |         batchSize: the batch size.
493 | 
494 |     Returns the initialized hidden state.
495 |     '''
496 |     def initState(self, name, dim, initType, batchSize):
497 |         if initType == "PRM":
498 |             prm = tf.get_variable(name, shape = (dim, ),
499 |                 initializer = tf.random_normal_initializer())
500 |             initState = tf.tile(tf.expand_dims(prm, axis = 0), [batchSize, 1])
501 |         elif initType == "ZERO":
502 |             initState = tf.zeros((batchSize, dim), dtype = tf.float32)
503 |         else: # "Q"
504 |             initState = self.vecQuestions
505 |         return initState
506 | 
507 |     '''
508 |     Adds a parametric null word to the questions.
509 | 
510 |     Args:
511 |         words: the words to add a null word to.
512 |         [batchSize, questionLength]
513 | 
514 |         lengths: question lengths.
515 |         [batchSize]
516 | 
517 |     Returns the updated word sequence and lengths.
518 |     '''
519 |     def addNullWord(self, words, lengths):
520 |         nullWord = tf.get_variable("zeroWord", shape = (1, config.ctrlDim), initializer = tf.random_normal_initializer())
521 |         nullWord = tf.tile(tf.expand_dims(nullWord, axis = 0), [self.batchSize, 1, 1])
522 |         words = tf.concat([nullWord, words], axis = 1)
523 |         lengths += 1
524 |         return words, lengths
525 | 
526 |     '''
527 |     Initializes the cell's internal state (the cell is currently stateful). In particular:
528 |     1. Data structures (lists of attention maps and accumulated losses).
529 |     2. The memory and control states.
530 |     3. The knowledge base (optionally merging it with the question vectors).
531 |     4. The question words used by the cell (either the original word embeddings, or the
532 |     encoder outputs, with optional projection).
533 | 
534 |     Args:
535 |         batchSize: the batch size
536 | 
537 |     Returns the initial cell state.
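    For orientation, a minimal driving loop would look roughly as follows. This is an
    illustrative sketch only (the actual loop lives in model.py); it assumes the constructor
    tensors are already built and that the caller sets the cell's iteration counter before
    each step, as __call__ expects:

        cell = MACCell(vecQuestions, questionWords, questionCntxWords, questionLengths,
                       knowledgeBase, memoryDropout, readDropout, writeDropout,
                       batchSize, train)
        state = cell.zero_state(batchSize)
        for i in range(config.netLength):
            cell.iteration = i
            _, state = cell(None, state)
        # state.memory now holds the final memory passed on to the output unit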
538 |     '''
539 |     def zero_state(self, batchSize, dtype = tf.float32):
540 |         ## initialize data-structures
541 |         self.attentions = {"kb": [], "question": [], "self": [], "gate": []}
542 |         self.autoEncLosses = {"control": tf.constant(0.0), "memory": tf.constant(0.0)}
543 | 
544 | 
545 |         ## initialize state
546 |         initialControl = self.initState("initCtrl", config.ctrlDim, config.initCtrl, batchSize)
547 |         initialMemory = self.initState("initMem", config.memDim, config.initMem, batchSize)
548 | 
549 |         self.controls = tf.expand_dims(initialControl, axis = 1)
550 |         self.memories = tf.expand_dims(initialMemory, axis = 1)
551 |         self.infos = tf.expand_dims(initialMemory, axis = 1)
552 | 
553 |         self.contControl = initialControl
554 |         # self.contControls = tf.expand_dims(initialControl, axis = 1)
555 |         # self.postControls = tf.expand_dims(initialControl, axis = 1)
556 | 
557 | 
558 |         ## initialize knowledge base
559 |         # optionally merge question into knowledge base representation
560 |         if config.initKBwithQ != "NON":
561 |             iVecQuestions = ops.linear(self.vecQuestions, config.ctrlDim, config.memDim, name = "questions")
562 | 
563 |             concatMul = (config.initKBwithQ == "MUL")
564 |             cnct, dim = ops.concat(self.knowledgeBase, iVecQuestions, config.memDim, mul = concatMul, expandY = True)
565 |             self.knowledgeBase = ops.linear(cnct, dim, config.memDim, name = "initKB")
566 | 
567 | 
568 |         ## initialize question words
569 |         # choose question words to work with (original embeddings or encoder outputs)
570 |         words = self.questionCntxWords if config.controlContextual else self.questionWords
571 | 
572 |         # optionally add a parametric "null" word to all questions
573 |         if config.addNullWord:
574 |             words, self.questionLengths = self.addNullWord(words, self.questionLengths)
575 | 
576 |         # project words
577 |         self.inWords = self.outWords = words
578 |         if config.controlInWordsProj or config.controlOutWordsProj:
579 |             pWords = ops.linear(words, config.ctrlDim, config.ctrlDim, name = "wordsProj")
580 |             self.inWords = pWords if config.controlInWordsProj else words
581 |             self.outWords = pWords if config.controlOutWordsProj else words
582 | 
583 |         # if config.controlCoverage:
584 |         #     self.coverage = tf.zeros((batchSize, tf.shape(words)[1]), dtype = tf.float32)
585 |         #     self.coverageBias = tf.get_variable("coverageBias", shape = (),
586 |         #         initializer = config.controlCoverageBias)
587 | 
588 |         ## initialize memory variational dropout mask
589 |         if config.memoryVariationalDropout:
590 |             self.memDpMask = ops.generateVarDpMask((batchSize, config.memDim), self.dropouts["memory"])
591 | 
592 |         return MACCellTuple(initialControl, initialMemory)
593 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import warnings
3 | warnings.filterwarnings("ignore", category=FutureWarning)
4 | warnings.filterwarnings("ignore", message="size changed")
5 | import sys
6 | import os
7 | import time
8 | import math
9 | import random
10 | try:
11 |     import Queue as queue
12 | except ImportError:
13 |     import queue
14 | import threading
15 | import h5py
16 | import json
17 | import numpy as np
18 | import tensorflow as tf
19 | from termcolor import colored, cprint
20 | 
21 | from config import config, loadDatasetConfig, parseArgs
22 | from preprocess import Preprocesser, bold, bcolored, writeline, writelist
23 | from model import MACnet
24 | from collections import defaultdict
25 | 
26 | 
############################################# loggers ############################################# 27 | 28 | # Writes log header to file 29 | def logInit(): 30 | with open(config.logFile(), "a+") as outFile: 31 | writeline(outFile, config.expName) 32 | headers = ["epoch", "trainAcc", "valAcc", "trainLoss", "valLoss"] 33 | if config.evalTrain: 34 | headers += ["evalTrainAcc", "evalTrainLoss"] 35 | if config.extra: 36 | if config.evalTrain: 37 | headers += ["thAcc", "thLoss"] 38 | headers += ["vhAcc", "vhLoss"] 39 | headers += ["time", "lr"] 40 | 41 | writelist(outFile, headers) 42 | # lr assumed to be last 43 | 44 | # Writes log record to file 45 | def logRecord(epoch, epochTime, lr, trainRes, evalRes, extraEvalRes): 46 | with open(config.logFile(), "a+") as outFile: 47 | record = [epoch, trainRes["acc"], evalRes["val"]["acc"], trainRes["loss"], evalRes["val"]["loss"]] 48 | if config.evalTrain: 49 | record += [evalRes["evalTrain"]["acc"], evalRes["evalTrain"]["loss"]] 50 | if config.extra: 51 | if config.evalTrain: 52 | record += [extraEvalRes["evalTrain"]["acc"], extraEvalRes["evalTrain"]["loss"]] 53 | record += [extraEvalRes["val"]["acc"], extraEvalRes["val"]["loss"]] 54 | record += [epochTime, lr] 55 | 56 | writelist(outFile, record) 57 | 58 | # Gets last logged epoch and learning rate 59 | def lastLoggedEpoch(): 60 | with open(config.logFile(), "r") as inFile: 61 | lastLine = list(inFile)[-1].split(",") 62 | epoch = int(lastLine[0]) 63 | lr = float(lastLine[-1]) 64 | return epoch, lr 65 | 66 | ################################## printing, output and analysis ################################## 67 | 68 | # Analysis by type 69 | analysisQuestionLims = [(0,18),(19,float("inf"))] 70 | analysisProgramLims = [(0,12),(13,float("inf"))] 71 | 72 | toArity = lambda instance: instance["programSeq"][-1].split("_", 1)[0] 73 | toType = lambda instance: instance["programSeq"][-1].split("_", 1)[1] 74 | 75 | def fieldLenIsInRange(field): 76 | return lambda instance, group: \ 77 | (len(instance[field]) >= group[0] and 78 | len(instance[field]) <= group[1]) 79 | 80 | # Groups instances based on a key 81 | def grouperKey(toKey): 82 | def grouper(instances): 83 | res = defaultdict(list) 84 | for instance in instances: 85 | res[toKey(instance)].append(instance) 86 | return res 87 | return grouper 88 | 89 | # Groups instances according to their match to condition 90 | def grouperCond(groups, isIn): 91 | def grouper(instances): 92 | res = {} 93 | for group in groups: 94 | res[group] = (instance for instance in instances if isIn(instance, group)) 95 | return res 96 | return grouper 97 | 98 | groupers = { 99 | "questionLength": grouperCond(analysisQuestionLims, fieldLenIsInRange("questionSeq")), 100 | "programLength": grouperCond(analysisProgramLims, fieldLenIsInRange("programSeq")), 101 | "arity": grouperKey(toArity), 102 | "type": grouperKey(toType) 103 | } 104 | 105 | # Computes average 106 | def avg(instances, field): 107 | if len(instances) == 0: 108 | return 0.0 109 | return sum(instances[field]) / len(instances) 110 | 111 | # Prints analysis of questions loss and accuracy by their group 112 | def printAnalysis(res): 113 | if config.analysisType != "": 114 | print("Analysis by {type}".format(type = config.analysisType)) 115 | groups = groupers[config.analysisType](res["preds"]) 116 | for key in groups: 117 | instances = groups[key] 118 | avgLoss = avg(instances, "loss") 119 | avgAcc = avg(instances, "acc") 120 | num = len(instances) 121 | print("Group {key}: Loss: {loss}, Acc: {acc}, Num: {num}".format(key, 
avgLoss, avgAcc, num)) 122 | 123 | # Print results for a tier 124 | def printTierResults(tierName, res, color): 125 | if res is None: 126 | return 127 | 128 | print("{tierName} Loss: {loss}, {tierName} accuracy: {acc}".format(tierName = tierName, 129 | loss = bcolored(res["loss"], color), 130 | acc = bcolored(res["acc"], color))) 131 | 132 | printAnalysis(res) 133 | 134 | # Prints dataset results (for several tiers) 135 | def printDatasetResults(trainRes, evalRes, extraEvalRes): 136 | printTierResults("Training", trainRes, "magenta") 137 | printTierResults("Training EMA", evalRes["evalTrain"], "red") 138 | printTierResults("Validation", evalRes["val"], "cyan") 139 | printTierResults("Extra Training EMA", extraEvalRes["evalTrain"], "red") 140 | printTierResults("Extra Validation", extraEvalRes["val"], "cyan") 141 | 142 | # Writes predictions for several tiers 143 | def writePreds(preprocessor, evalRes, extraEvalRes): 144 | preprocessor.writePreds(evalRes["evalTrain"], "evalTrain") 145 | preprocessor.writePreds(evalRes["val"], "val") 146 | preprocessor.writePreds(evalRes["test"], "test") 147 | preprocessor.writePreds(extraEvalRes["evalTrain"], "evalTrain", "H") 148 | preprocessor.writePreds(extraEvalRes["val"], "val", "H") 149 | preprocessor.writePreds(extraEvalRes["test"], "test", "H") 150 | 151 | ############################################# session ############################################# 152 | # Initializes TF session. Sets GPU memory configuration. 153 | def setSession(): 154 | sessionConfig = tf.ConfigProto(allow_soft_placement = True, log_device_placement = False) 155 | if config.allowGrowth: 156 | sessionConfig.gpu_options.allow_growth = True 157 | if config.maxMemory < 1.0: 158 | sessionConfig.gpu_options.per_process_gpu_memory_fraction = config.maxMemory 159 | return sessionConfig 160 | 161 | ############################################## savers ############################################# 162 | # Initializes savers (standard, optional exponential-moving-average and optional for subset of variables) 163 | def setSavers(model): 164 | saver = tf.train.Saver(max_to_keep = config.weightsToKeep) 165 | 166 | subsetSaver = None 167 | if config.saveSubset: 168 | isRelevant = lambda var: any(s in var.name for s in config.varSubset) 169 | relevantVars = [var for var in tf.global_variables() if isRelevant(var)] 170 | subsetSaver = tf.train.Saver(relevantVars, max_to_keep = config.weightsToKeep, allow_empty = True) 171 | 172 | emaSaver = None 173 | if config.useEMA: 174 | emaSaver = tf.train.Saver(model.emaDict, max_to_keep = config.weightsToKeep) 175 | 176 | return { 177 | "saver": saver, 178 | "subsetSaver": subsetSaver, 179 | "emaSaver": emaSaver 180 | } 181 | 182 | ################################### restore / initialize weights ################################## 183 | # Restores weights of specified / last epoch if on restore mod. 184 | # Otherwise, initializes weights. 
185 | def loadWeights(sess, saver, init): 186 | if config.restoreEpoch > 0 or config.restore: 187 | # restore last epoch only if restoreEpoch isn't set 188 | if config.restoreEpoch == 0: 189 | # restore last logged epoch 190 | config.restoreEpoch, config.lr = lastLoggedEpoch() 191 | print(bcolored("Restoring epoch {} and lr {}".format(config.restoreEpoch, config.lr),"cyan")) 192 | print(bcolored("Restoring weights", "blue")) 193 | saver.restore(sess, config.weightsFile(config.restoreEpoch)) 194 | epoch = config.restoreEpoch 195 | else: 196 | print(bcolored("Initializing weights", "blue")) 197 | sess.run(init) 198 | logInit() 199 | epoch = 0 200 | 201 | return epoch 202 | 203 | ###################################### training / evaluation ###################################### 204 | # Chooses data to train on (main / extra) data. 205 | def chooseTrainingData(data): 206 | trainingData = data["main"]["train"] 207 | alterData = None 208 | 209 | if config.extra: 210 | if config.trainExtra: 211 | if config.extraVal: 212 | trainingData = data["extra"]["val"] 213 | else: 214 | trainingData = data["extra"]["train"] 215 | if config.alterExtra: 216 | alterData = data["extra"]["train"] 217 | 218 | return trainingData, alterData 219 | 220 | #### evaluation 221 | # Runs evaluation on train / val / test datasets. 222 | def runEvaluation(sess, model, data, epoch, evalTrain = True, evalTest = False, getAtt = None): 223 | if getAtt is None: 224 | getAtt = config.getAtt 225 | res = {"evalTrain": None, "val": None, "test": None} 226 | 227 | if data is not None: 228 | if evalTrain and config.evalTrain: 229 | res["evalTrain"] = runEpoch(sess, model, data["evalTrain"], train = False, epoch = epoch, getAtt = getAtt) 230 | 231 | res["val"] = runEpoch(sess, model, data["val"], train = False, epoch = epoch, getAtt = getAtt) 232 | 233 | if evalTest or config.test: 234 | res["test"] = runEpoch(sess, model, data["test"], train = False, epoch = epoch, getAtt = getAtt) 235 | 236 | return res 237 | 238 | ## training conditions (comparing current epoch result to prior ones) 239 | def improveEnough(curr, prior, lr): 240 | prevRes = prior["prev"]["res"] 241 | currRes = curr["res"] 242 | 243 | if prevRes is None: 244 | return True 245 | 246 | prevTrainLoss = prevRes["train"]["loss"] 247 | currTrainLoss = currRes["train"]["loss"] 248 | lossDiff = prevTrainLoss - currTrainLoss 249 | 250 | notImprove = ((lossDiff < 0.015 and prevTrainLoss < 0.5 and lr > 0.00002) or \ 251 | (lossDiff < 0.008 and prevTrainLoss < 0.15 and lr > 0.00001) or \ 252 | (lossDiff < 0.003 and prevTrainLoss < 0.10 and lr > 0.000005)) 253 | #(prevTrainLoss < 0.2 and config.lr > 0.000015) 254 | 255 | return not notImprove 256 | 257 | def better(currRes, bestRes): 258 | return currRes["val"]["acc"] > bestRes["val"]["acc"] 259 | 260 | ############################################## data ############################################### 261 | #### instances and batching 262 | # Trims sequences based on their max length. 263 | def trim2DVectors(vectors, vectorsLengths): 264 | maxLength = np.max(vectorsLengths) 265 | return vectors[:,:maxLength] 266 | 267 | # Trims batch based on question length. 268 | def trimData(data): 269 | data["questions"] = trim2DVectors(data["questions"], data["questionLengths"]) 270 | return data 271 | 272 | # Gets batch / bucket size. 273 | def getLength(data): 274 | return len(data["instances"]) 275 | 276 | # Selects the data entries that match the indices. 
277 | def selectIndices(data, indices): 278 | def select(field, indices): 279 | if type(field) is np.ndarray: 280 | return field[indices] 281 | if type(field) is list: 282 | return [field[i] for i in indices] 283 | else: 284 | return field 285 | selected = {k : select(d, indices) for k,d in data.items()} 286 | return selected 287 | 288 | # Batches data into a a list of batches of batchSize. 289 | # Shuffles the data by default. 290 | def getBatches(data, batchSize = None, shuffle = True): 291 | batches = [] 292 | 293 | dataLen = getLength(data) 294 | if batchSize is None or batchSize > dataLen: 295 | batchSize = dataLen 296 | 297 | indices = np.arange(dataLen) 298 | if shuffle: 299 | np.random.shuffle(indices) 300 | 301 | for batchStart in range(0, dataLen, batchSize): 302 | batchIndices = indices[batchStart : batchStart + batchSize] 303 | # if len(batchIndices) == batchSize? 304 | if len(batchIndices) >= config.gpusNum: 305 | batch = selectIndices(data, batchIndices) 306 | batches.append(batch) 307 | # batchesIndices.append((data, batchIndices)) 308 | 309 | return batches 310 | 311 | #### image batches 312 | # Opens image files. 313 | def openImageFiles(images): 314 | images["imagesFile"] = h5py.File(images["imagesFilename"], "r") 315 | images["imagesIds"] = None 316 | if config.dataset == "NLVR": 317 | with open(images["imageIdsFilename"], "r") as imageIdsFile: 318 | images["imagesIds"] = json.load(imageIdsFile) 319 | 320 | # Closes image files. 321 | def closeImageFiles(images): 322 | images["imagesFile"].close() 323 | 324 | # Loads an images from file for a given data batch. 325 | def loadImageBatch(images, batch): 326 | imagesFile = images["imagesFile"] 327 | id2idx = images["imagesIds"] 328 | 329 | toIndex = lambda imageId: imageId 330 | if id2idx is not None: 331 | toIndex = lambda imageId: id2idx[imageId] 332 | imageBatch = np.stack([imagesFile["features"][toIndex(imageId)] for imageId in batch["imageIds"]], axis = 0) 333 | 334 | return {"images": imageBatch, "imageIds": batch["imageIds"]} 335 | 336 | # Loads images for several num batches in the batches list from start index. 337 | def loadImageBatches(images, batches, start, num): 338 | batches = batches[start: start + num] 339 | return [loadImageBatch(images, batch) for batch in batches] 340 | 341 | #### data alternation 342 | # Alternates main training batches with extra data. 
343 | def alternateData(batches, alterData, dataLen): 344 | alterData = alterData["data"][0] # data isn't bucketed for altered data 345 | 346 | # computes number of repetitions 347 | needed = math.ceil(len(batches) / config.alterNum) 348 | print(bold("Extra batches needed: %d") % needed) 349 | perData = math.ceil(getLength(alterData) / config.batchSize) 350 | print(bold("Batches per extra data: %d") % perData) 351 | repetitions = math.ceil(needed / perData) 352 | print(bold("reps: %d") % repetitions) 353 | 354 | # make alternate batches 355 | alterBatches = [] 356 | for _ in range(repetitions): 357 | repBatches = getBatches(alterData, batchSize = config.batchSize) 358 | random.shuffle(repBatches) 359 | alterBatches += repBatches 360 | print(bold("Batches num: %d") + len(alterBatches)) 361 | 362 | # alternate data with extra data 363 | curr = len(batches) - 1 364 | for alterBatch in alterBatches: 365 | if curr < 0: 366 | # print(colored("too many" + str(curr) + " " + str(len(batches)),"red")) 367 | break 368 | batches.insert(curr, alterBatch) 369 | dataLen += getLength(alterBatch) 370 | curr -= config.alterNum 371 | 372 | return batches, dataLen 373 | 374 | ############################################ threading ############################################ 375 | 376 | imagesQueue = queue.Queue(maxsize = 20) # config.tasksNum 377 | inQueue = queue.Queue(maxsize = 1) 378 | outQueue = queue.Queue(maxsize = 1) 379 | 380 | # Runs a worker thread(s) to load images while training . 381 | class StoppableThread(threading.Thread): 382 | # Thread class with a stop() method. The thread itself has to check 383 | # regularly for the stopped() condition. 384 | 385 | def __init__(self, images, batches): # i 386 | super(StoppableThread, self).__init__() 387 | # self.i = i 388 | self.images = images 389 | self.batches = batches 390 | self._stop_event = threading.Event() 391 | 392 | # def __init__(self, args): 393 | # super(StoppableThread, self).__init__(args = args) 394 | # self._stop_event = threading.Event() 395 | 396 | # def __init__(self, target, args): 397 | # super(StoppableThread, self).__init__(target = target, args = args) 398 | # self._stop_event = threading.Event() 399 | 400 | def stop(self): 401 | self._stop_event.set() 402 | 403 | def stopped(self): 404 | return self._stop_event.is_set() 405 | 406 | def run(self): 407 | while not self.stopped(): 408 | try: 409 | batchNum = inQueue.get(timeout = 60) 410 | nextItem = loadImageBatches(self.images, self.batches, batchNum, int(config.taskSize / 2)) 411 | outQueue.put(nextItem) 412 | # inQueue.task_done() 413 | except: 414 | pass 415 | # print("worker %d done", self.i) 416 | 417 | def loaderRun(images, batches): 418 | batchNum = 0 419 | 420 | # if config.workers == 2: 421 | # worker = StoppableThread(images, batches) # i, 422 | # worker.daemon = True 423 | # worker.start() 424 | 425 | # while batchNum < len(batches): 426 | # inQueue.put(batchNum + int(config.taskSize / 2)) 427 | # nextItem1 = loadImageBatches(images, batches, batchNum, int(config.taskSize / 2)) 428 | # nextItem2 = outQueue.get() 429 | 430 | # nextItem = nextItem1 + nextItem2 431 | # assert len(nextItem) == min(config.taskSize, len(batches) - batchNum) 432 | # batchNum += config.taskSize 433 | 434 | # imagesQueue.put(nextItem) 435 | 436 | # worker.stop() 437 | # else: 438 | while batchNum < len(batches): 439 | nextItem = loadImageBatches(images, batches, batchNum, config.taskSize) 440 | assert len(nextItem) == min(config.taskSize, len(batches) - batchNum) 441 | batchNum += 
config.taskSize 442 | imagesQueue.put(nextItem) 443 | 444 | # print("manager loader done") 445 | 446 | ########################################## stats tracking ######################################### 447 | # Computes exponential moving average. 448 | def emaAvg(avg, value): 449 | if avg is None: 450 | return value 451 | emaRate = 0.98 452 | return avg * emaRate + value * (1 - emaRate) 453 | 454 | # Initializes training statistics. 455 | def initStats(): 456 | return { 457 | "totalBatches": 0, 458 | "totalData": 0, 459 | "totalLoss": 0.0, 460 | "totalCorrect": 0, 461 | "loss": 0.0, 462 | "acc": 0.0, 463 | "emaLoss": None, 464 | "emaAcc": None, 465 | } 466 | 467 | # Updates statistics with training results of a batch 468 | def updateStats(stats, res, batch): 469 | stats["totalBatches"] += 1 470 | stats["totalData"] += getLength(batch) 471 | 472 | stats["totalLoss"] += res["loss"] 473 | stats["totalCorrect"] += res["correctNum"] 474 | 475 | stats["loss"] = stats["totalLoss"] / stats["totalBatches"] 476 | stats["acc"] = stats["totalCorrect"] / stats["totalData"] 477 | 478 | stats["emaLoss"] = emaAvg(stats["emaLoss"], res["loss"]) 479 | stats["emaAcc"] = emaAvg(stats["emaAcc"], res["acc"]) 480 | 481 | return stats 482 | 483 | # auto-encoder ae = {:2.4f} autoEncLoss, 484 | # Translates training statistics into a string to print 485 | def statsToStr(stats, res, epoch, batchNum, dataLen, startTime): 486 | formatStr = "\reb {epoch},{batchNum} ({dataProcessed} / {dataLen:5d}), " + \ 487 | "t = {time} ({loadTime:2.2f}+{trainTime:2.2f}), " + \ 488 | "lr {lr}, l = {loss}, a = {acc}, avL = {avgLoss}, " + \ 489 | "avA = {avgAcc}, g = {gradNorm:2.4f}, " + \ 490 | "emL = {emaLoss:2.4f}, emA = {emaAcc:2.4f}; " + \ 491 | "{expname}" # {machine}/{gpu}" 492 | 493 | s_epoch = bcolored("{:2d}".format(epoch),"green") 494 | s_batchNum = "{:3d}".format(batchNum) 495 | s_dataProcessed = bcolored("{:5d}".format(stats["totalData"]),"green") 496 | s_dataLen = dataLen 497 | s_time = bcolored("{:2.2f}".format(time.time() - startTime),"green") 498 | s_loadTime = res["readTime"] 499 | s_trainTime = res["trainTime"] 500 | s_lr = bold(config.lr) 501 | s_loss = bcolored("{:2.4f}".format(res["loss"]), "blue") 502 | s_acc = bcolored("{:2.4f}".format(res["acc"]),"blue") 503 | s_avgLoss = bcolored("{:2.4f}".format(stats["loss"]), "blue") 504 | s_avgAcc = bcolored("{:2.4f}".format(stats["acc"]),"red") 505 | s_gradNorm = res["gradNorm"] 506 | s_emaLoss = stats["emaLoss"] 507 | s_emaAcc = stats["emaAcc"] 508 | s_expname = config.expName 509 | # s_machine = bcolored(config.dataPath[9:11],"green") 510 | # s_gpu = bcolored(config.gpus,"green") 511 | 512 | return formatStr.format(epoch = s_epoch, batchNum = s_batchNum, dataProcessed = s_dataProcessed, 513 | dataLen = s_dataLen, time = s_time, loadTime = s_loadTime, 514 | trainTime = s_trainTime, lr = s_lr, loss = s_loss, acc = s_acc, 515 | avgLoss = s_avgLoss, avgAcc = s_avgAcc, gradNorm = s_gradNorm, 516 | emaLoss = s_emaLoss, emaAcc = s_emaAcc, expname = s_expname) 517 | # machine = s_machine, gpu = s_gpu) 518 | 519 | # collectRuntimeStats, writer = None, 520 | ''' 521 | Runs an epoch with model and session over the data. 522 | 1. Batches the data and optionally mix it with the extra alterData. 523 | 2. Start worker threads to load images in parallel to training. 524 | 3. Runs model for each batch, and gets results (e.g. loss, accuracy). 525 | 4. Updates and prints statistics based on batch results. 526 | 5. Once in a while (every config.saveEvery), save weights. 
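    For example, the training and evaluation code below (main / runEvaluation) invokes it
    along these lines:

        trainRes = runEpoch(sess, model, trainingData, train = True, epoch = epoch,
                            saver = saver, alterData = alterData)
        valRes = runEpoch(sess, model, data["val"], train = False, epoch = epoch, getAtt = getAtt)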
527 | 528 | Args: 529 | sess: TF session to run with. 530 | 531 | model: model to process data. Has runBatch method that process a given batch. 532 | (See model.py for further details). 533 | 534 | data: data to use for training/evaluation. 535 | 536 | epoch: epoch number. 537 | 538 | saver: TF saver to save weights 539 | 540 | calle: a method to call every number of iterations (config.calleEvery) 541 | 542 | alterData: extra data to mix with main data while training. 543 | 544 | getAtt: True to return model attentions. 545 | ''' 546 | def runEpoch(sess, model, data, train, epoch, saver = None, calle = None, 547 | alterData = None, getAtt = False): 548 | # train = data["train"] better than outside argument 549 | 550 | # initialization 551 | startTime0 = time.time() 552 | 553 | stats = initStats() 554 | preds = [] 555 | 556 | # open image files 557 | openImageFiles(data["images"]) 558 | 559 | ## prepare batches 560 | buckets = data["data"] 561 | dataLen = sum(getLength(bucket) for bucket in buckets) 562 | 563 | # make batches and randomize 564 | batches = [] 565 | for bucket in buckets: 566 | batches += getBatches(bucket, batchSize = config.batchSize) 567 | random.shuffle(batches) 568 | 569 | # alternate with extra data 570 | if train and alterData is not None: 571 | batches, dataLen = alternateData(batches, alterData, dataLen) 572 | 573 | # start image loaders 574 | if config.parallel: 575 | loader = threading.Thread(target = loaderRun, args = (data["images"], batches)) 576 | loader.daemon = True 577 | loader.start() 578 | 579 | for batchNum, batch in enumerate(batches): 580 | startTime = time.time() 581 | 582 | # prepare batch 583 | batch = trimData(batch) 584 | 585 | # load images batch 586 | if config.parallel: 587 | if batchNum % config.taskSize == 0: 588 | imagesBatches = imagesQueue.get() 589 | imagesBatch = imagesBatches[batchNum % config.taskSize] # len(imagesBatches) 590 | else: 591 | imagesBatch = loadImageBatch(data["images"], batch) 592 | for i, imageId in enumerate(batch["imageIds"]): 593 | assert imageId == imagesBatch["imageIds"][i] 594 | 595 | # run batch 596 | res = model.runBatch(sess, batch, imagesBatch, train, getAtt) 597 | 598 | # update stats 599 | stats = updateStats(stats, res, batch) 600 | preds += res["preds"] 601 | 602 | # if config.summerize and writer is not None: 603 | # writer.add_summary(res["summary"], epoch) 604 | 605 | sys.stdout.write(statsToStr(stats, res, epoch, batchNum, dataLen, startTime)) 606 | sys.stdout.flush() 607 | 608 | # save weights 609 | if saver is not None: 610 | if batchNum > 0 and batchNum % config.saveEvery == 0: 611 | print("") 612 | print(bold("saving weights")) 613 | saver.save(sess, config.weightsFile(epoch)) 614 | 615 | # calle 616 | if calle is not None: 617 | if batchNum > 0 and batchNum % config.calleEvery == 0: 618 | calle() 619 | 620 | sys.stdout.write("\r") 621 | sys.stdout.flush() 622 | 623 | print("") 624 | 625 | closeImageFiles(data["images"]) 626 | 627 | if config.parallel: 628 | loader.join() # should work 629 | 630 | return {"loss": stats["loss"], 631 | "acc": stats["acc"], 632 | "preds": preds 633 | } 634 | 635 | ''' 636 | Trains/evaluates the model: 637 | 1. Set GPU configurations. 638 | 2. Preprocess data: reads from datasets, and convert into numpy arrays. 639 | 3. Builds the TF computational graph for the MAC model. 640 | 4. Starts a session and initialize / restores weights. 641 | 5. If config.train is True, trains the model for number of epochs: 642 | a. Trains the model on training data 643 | b. 
Evaluates the model on training / validation data, optionally with 644 | exponential-moving-average weights. 645 | c. Prints and logs statistics, and optionally saves model predictions. 646 | d. Optionally reduces learning rate if losses / accuracies don't improve, 647 | and applies early stopping. 648 | 6. If config.test is True, runs a final evaluation on the dataset and print 649 | final results! 650 | ''' 651 | def main(): 652 | with open(config.configFile(), "a+") as outFile: 653 | json.dump(vars(config), outFile) 654 | 655 | # set gpus 656 | if config.gpus != "": 657 | config.gpusNum = len(config.gpus.split(",")) 658 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus 659 | 660 | tf.logging.set_verbosity(tf.logging.ERROR) 661 | 662 | # process data 663 | print(bold("Preprocess data...")) 664 | start = time.time() 665 | preprocessor = Preprocesser() 666 | data, embeddings, answerDict = preprocessor.preprocessData() 667 | print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue"))) 668 | 669 | # build model 670 | print(bold("Building model...")) 671 | start = time.time() 672 | model = MACnet(embeddings, answerDict) 673 | print("took {} seconds".format(bcolored("{:.2f}".format(time.time() - start), "blue"))) 674 | 675 | # initializer 676 | init = tf.global_variables_initializer() 677 | 678 | # savers 679 | savers = setSavers(model) 680 | saver, emaSaver = savers["saver"], savers["emaSaver"] 681 | 682 | # sessionConfig 683 | sessionConfig = setSession() 684 | 685 | with tf.Session(config = sessionConfig) as sess: 686 | 687 | # ensure no more ops are added after model is built 688 | sess.graph.finalize() 689 | 690 | # restore / initialize weights, initialize epoch variable 691 | epoch = loadWeights(sess, saver, init) 692 | 693 | if config.train: 694 | start0 = time.time() 695 | 696 | bestEpoch = epoch 697 | bestRes = None 698 | prevRes = None 699 | 700 | # epoch in [restored + 1, epochs] 701 | for epoch in range(config.restoreEpoch + 1, config.epochs + 1): 702 | print(bcolored("Training epoch {}...".format(epoch), "green")) 703 | start = time.time() 704 | 705 | # train 706 | # calle = lambda: model.runEpoch(), collectRuntimeStats, writer 707 | trainingData, alterData = chooseTrainingData(data) 708 | trainRes = runEpoch(sess, model, trainingData, train = True, epoch = epoch, 709 | saver = saver, alterData = alterData) 710 | 711 | # save weights 712 | saver.save(sess, config.weightsFile(epoch)) 713 | if config.saveSubset: 714 | subsetSaver.save(sess, config.subsetWeightsFile(epoch)) 715 | 716 | # load EMA weights 717 | if config.useEMA: 718 | print(bold("Restoring EMA weights")) 719 | emaSaver.restore(sess, config.weightsFile(epoch)) 720 | 721 | # evaluation 722 | evalRes = runEvaluation(sess, model, data["main"], epoch) 723 | extraEvalRes = runEvaluation(sess, model, data["extra"], epoch, 724 | evalTrain = not config.extraVal) 725 | 726 | # restore standard weights 727 | if config.useEMA: 728 | print(bold("Restoring standard weights")) 729 | saver.restore(sess, config.weightsFile(epoch)) 730 | 731 | print("") 732 | 733 | epochTime = time.time() - start 734 | print("took {:.2f} seconds".format(epochTime)) 735 | 736 | # print results 737 | printDatasetResults(trainRes, evalRes, extraEvalRes) 738 | 739 | # stores predictions and optionally attention maps 740 | if config.getPreds: 741 | print(bcolored("Writing predictions...", "white")) 742 | writePreds(preprocessor, evalRes, extraEvalRes) 743 | 744 | logRecord(epoch, epochTime, config.lr, trainRes, evalRes, 
extraEvalRes) 745 | 746 | # update best result 747 | # compute curr and prior 748 | currRes = {"train": trainRes, "val": evalRes["val"]} 749 | curr = {"res": currRes, "epoch": epoch} 750 | 751 | if bestRes is None or better(currRes, bestRes): 752 | bestRes = currRes 753 | bestEpoch = epoch 754 | 755 | prior = {"best": {"res": bestRes, "epoch": bestEpoch}, 756 | "prev": {"res": prevRes, "epoch": epoch - 1}} 757 | 758 | # lr reducing 759 | if config.lrReduce: 760 | if not improveEnough(curr, prior, config.lr): 761 | config.lr *= config.lrDecayRate 762 | print(colored("Reducing LR to {}".format(config.lr), "red")) 763 | 764 | # early stopping 765 | if config.earlyStopping > 0: 766 | if epoch - bestEpoch > config.earlyStopping: 767 | break 768 | 769 | # update previous result 770 | prevRes = currRes 771 | 772 | # reduce epoch back to the last one we trained on 773 | epoch -= 1 774 | print("Training took {:.2f} seconds ({:} epochs)".format(time.time() - start0, 775 | epoch - config.restoreEpoch)) 776 | 777 | if config.finalTest: 778 | print("Testing on epoch {}...".format(epoch)) 779 | 780 | start = time.time() 781 | if epoch > 0: 782 | if config.useEMA: 783 | emaSaver.restore(sess, config.weightsFile(epoch)) 784 | else: 785 | saver.restore(sess, config.weightsFile(epoch)) 786 | 787 | evalRes = runEvaluation(sess, model, data["main"], epoch, evalTest = True) 788 | extraEvalRes = runEvaluation(sess, model, data["extra"], epoch, 789 | evalTrain = not config.extraVal, evalTest = True) 790 | 791 | print("took {:.2f} seconds".format(time.time() - start)) 792 | printDatasetResults(None, evalRes, extraEvalRes) 793 | 794 | print("Writing predictions...") 795 | writePreds(preprocessor, evalRes, extraEvalRes) 796 | 797 | print(bcolored("Done!","white")) 798 | 799 | if __name__ == '__main__': 800 | parseArgs() 801 | loadDatasetConfig[config.dataset]() 802 | main() 803 | -------------------------------------------------------------------------------- /mi_gru_cell.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class MiGRUCell(tf.nn.rnn_cell.RNNCell): 5 | def __init__(self, num_units, input_size = None, activation = tf.tanh, reuse = None): 6 | self.numUnits = num_units 7 | self.activation = activation 8 | self.reuse = reuse 9 | 10 | @property 11 | def state_size(self): 12 | return self.numUnits 13 | 14 | @property 15 | def output_size(self): 16 | return self.numUnits 17 | 18 | def mulWeights(self, inp, inDim, outDim, name = ""): 19 | with tf.variable_scope("weights" + name): 20 | W = tf.get_variable("weights", shape = (inDim, outDim), 21 | initializer = tf.contrib.layers.xavier_initializer()) 22 | 23 | output = tf.matmul(inp, W) 24 | return output 25 | 26 | def addBiases(self, inp1, inp2, dim, bInitial = 0, name = ""): 27 | with tf.variable_scope("additiveBiases" + name): 28 | b = tf.get_variable("biases", shape = (dim,), 29 | initializer = tf.zeros_initializer()) + bInitial 30 | with tf.variable_scope("multiplicativeBias" + name): 31 | beta = tf.get_variable("biases", shape = (3 * dim,), 32 | initializer = tf.ones_initializer()) 33 | 34 | Wx, Uh, inter = tf.split(beta * tf.concat([inp1, inp2, inp1 * inp2], axis = 1), 35 | num_or_size_splits = 3, axis = 1) 36 | output = Wx + Uh + inter + b 37 | return output 38 | 39 | def __call__(self, inputs, state, scope = None): 40 | scope = scope or type(self).__name__ 41 | with tf.variable_scope(scope, reuse = self.reuse): 42 | inputSize = int(inputs.shape[1]) 43 | 44 | Wxr 
= self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxr") 45 | Uhr = self.mulWeights(state, self.numUnits, self.numUnits, name = "Uhr") 46 | 47 | r = tf.nn.sigmoid(self.addBiases(Wxr, Uhr, self.numUnits, bInitial = 1, name = "r")) 48 | 49 | Wxu = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxu") 50 | Uhu = self.mulWeights(state, self.numUnits, self.numUnits, name = "Uhu") 51 | 52 | u = tf.nn.sigmoid(self.addBiases(Wxu, Uhu, self.numUnits, bInitial = 1, name = "u")) 53 | # r, u = tf.split(gates, num_or_size_splits = 2, axis = 1) 54 | 55 | Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxl") 56 | Urh = self.mulWeights(r * state, self.numUnits, self.numUnits, name = "Uhl") 57 | c = self.activation(self.addBiases(Wx, Urh, self.numUnits, name = "2")) 58 | 59 | newH = u * state + (1 - u) * c # switch u and 1-u? 60 | return newH, newH 61 | 62 | def zero_state(self, batchSize, dtype = tf.float32): 63 | return tf.zeros((batchSize, self.numUnits), dtype = dtype) 64 | -------------------------------------------------------------------------------- /mi_lstm_cell.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class MiLSTMCell(tf.nn.rnn_cell.RNNCell): 5 | def __init__(self, num_units, forget_bias = 1.0, input_size = None, 6 | state_is_tuple = True, activation = tf.tanh, reuse = None): 7 | self.numUnits = num_units 8 | self.forgetBias = forget_bias 9 | self.activation = activation 10 | self.reuse = reuse 11 | 12 | @property 13 | def state_size(self): 14 | return tf.nn.rnn_cell.LSTMStateTuple(self.numUnits, self.numUnits) 15 | 16 | @property 17 | def output_size(self): 18 | return self.numUnits 19 | 20 | def mulWeights(self, inp, inDim, outDim, name = ""): 21 | with tf.variable_scope("weights" + name): 22 | W = tf.get_variable("weights", shape = (inDim, outDim), 23 | initializer = tf.contrib.layers.xavier_initializer()) 24 | output = tf.matmul(inp, W) 25 | return output 26 | 27 | def addBiases(self, inp1, inp2, dim, name = ""): 28 | with tf.variable_scope("additiveBiases" + name): 29 | b = tf.get_variable("biases", shape = (dim,), 30 | initializer = tf.zeros_initializer()) 31 | with tf.variable_scope("multiplicativeBias" + name): 32 | beta = tf.get_variable("biases", shape = (3 * dim,), 33 | initializer = tf.ones_initializer()) 34 | 35 | Wx, Uh, inter = tf.split(beta * tf.concat([inp1, inp2, inp1 * inp2], axis = 1), 36 | num_or_size_splits = 3, axis = 1) 37 | output = Wx + Uh + inter + b 38 | return output 39 | 40 | def __call__(self, inputs, state, scope = None): 41 | scope = scope or type(self).__name__ 42 | with tf.variable_scope(scope, reuse = self.reuse): 43 | c, h = state 44 | inputSize = int(inputs.shape[1]) 45 | 46 | Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxi") 47 | Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uhi") 48 | 49 | i = self.addBiases(Wx, Uh, self.numUnits, name = "i") 50 | 51 | Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxj") 52 | Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uhj") 53 | 54 | j = self.addBiases(Wx, Uh, self.numUnits, name = "l") 55 | 56 | Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxf") 57 | Uh = self.mulWeights(h, self.numUnits, self.numUnits, name = "Uhf") 58 | 59 | f = self.addBiases(Wx, Uh, self.numUnits, name = "f") 60 | 61 | Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxo") 62 | Uh = self.mulWeights(h, self.numUnits, 
self.numUnits, name = "Uho") 63 | 64 | o = self.addBiases(Wx, Uh, self.numUnits, name = "o") 65 | # i, j, f, o = tf.split(value = concat, num_or_size_splits = 4, axis = 1) 66 | 67 | newC = (c * tf.nn.sigmoid(f + self.forgetBias) + tf.nn.sigmoid(i) * 68 | self.activation(j)) 69 | newH = self.activation(newC) * tf.nn.sigmoid(o) 70 | 71 | newState = tf.nn.rnn_cell.LSTMStateTuple(newC, newH) 72 | return newH, newState 73 | 74 | def zero_state(self, batchSize, dtype = tf.float32): 75 | return tf.nn.rnn_cell.LSTMStateTuple(tf.zeros((batchSize, self.numUnits), dtype = dtype), 76 | tf.zeros((batchSize, self.numUnits), dtype = dtype)) 77 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | import ops 7 | from config import config 8 | from mac_cell import MACCell 9 | ''' 10 | The MAC network model. It performs reasoning processes to answer a question over 11 | knowledge base (the image) by decomposing it into attention-based computational steps, 12 | each perform by a recurrent MAC cell. 13 | 14 | The network has three main components. 15 | Input unit: processes the network inputs: raw question strings and image into 16 | distributional representations. 17 | 18 | The MAC network: calls the MACcells (mac_cell.py) config.netLength number of times, 19 | to perform the reasoning process over the question and image. 20 | 21 | The output unit: a classifier that receives the question and final state of the MAC 22 | network and uses them to compute log-likelihood over the possible one-word answers. 23 | ''' 24 | class MACnet(object): 25 | 26 | '''Initialize the class. 27 | 28 | Args: 29 | embeddingsInit: initialization for word embeddings (random / glove). 30 | answerDict: answers dictionary (mapping between integer id and symbol). 31 | ''' 32 | def __init__(self, embeddingsInit, answerDict): 33 | self.embeddingsInit = embeddingsInit 34 | self.answerDict = answerDict 35 | self.build() 36 | 37 | ''' 38 | Initializes placeholders. 39 | questionsIndicesAll: integer ids of question words. 40 | [batchSize, questionLength] 41 | 42 | questionLengthsAll: length of each question. 43 | [batchSize] 44 | 45 | imagesPlaceholder: image features. 46 | [batchSize, channels, height, width] 47 | (converted internally to [batchSize, height, width, channels]) 48 | 49 | answersIndicesAll: integer ids of answer words. 50 | [batchSize] 51 | 52 | lr: learning rate (tensor scalar) 53 | train: train / evaluation (tensor boolean) 54 | 55 | dropout values dictionary (tensor scalars) 56 | ''' 57 | # change to H x W x C? 58 | def addPlaceholders(self): 59 | with tf.variable_scope("Placeholders"): 60 | ## data 61 | # questions 62 | self.questionsIndicesAll = tf.placeholder(tf.int32, shape = (None, None)) 63 | self.questionLengthsAll = tf.placeholder(tf.int32, shape = (None, )) 64 | 65 | # images 66 | # put image known dimension as last dim? 
67 | self.imagesPlaceholder = tf.placeholder(tf.float32, shape = (None, None, None, None)) 68 | self.imagesAll = tf.transpose(self.imagesPlaceholder, (0, 2, 3, 1)) 69 | # self.imageH = tf.shape(self.imagesAll)[1] 70 | # self.imageW = tf.shape(self.imagesAll)[2] 71 | 72 | # answers 73 | self.answersIndicesAll = tf.placeholder(tf.int32, shape = (None, )) 74 | 75 | ## optimization 76 | self.lr = tf.placeholder(tf.float32, shape = ()) 77 | self.train = tf.placeholder(tf.bool, shape = ()) 78 | self.batchSizeAll = tf.shape(self.questionsIndicesAll)[0] 79 | 80 | ## dropouts 81 | # TODO: change dropouts to be 1 - current 82 | self.dropouts = { 83 | "encInput": tf.placeholder(tf.float32, shape = ()), 84 | "encState": tf.placeholder(tf.float32, shape = ()), 85 | "stem": tf.placeholder(tf.float32, shape = ()), 86 | "question": tf.placeholder(tf.float32, shape = ()), 87 | # self.dropouts["question"]Out = tf.placeholder(tf.float32, shape = ()) 88 | # self.dropouts["question"]MAC = tf.placeholder(tf.float32, shape = ()) 89 | "read": tf.placeholder(tf.float32, shape = ()), 90 | "write": tf.placeholder(tf.float32, shape = ()), 91 | "memory": tf.placeholder(tf.float32, shape = ()), 92 | "output": tf.placeholder(tf.float32, shape = ()) 93 | } 94 | 95 | # batch norm params 96 | self.batchNorm = {"decay": config.bnDecay, "train": self.train} 97 | 98 | # if config.parametricDropout: 99 | # self.dropouts["question"] = parametricDropout("qDropout", self.train) 100 | # self.dropouts["read"] = parametricDropout("readDropout", self.train) 101 | # else: 102 | # self.dropouts["question"] = self.dropouts["_q"] 103 | # self.dropouts["read"] = self.dropouts["_read"] 104 | 105 | # if config.tempDynamic: 106 | # self.tempAnnealRate = tf.placeholder(tf.float32, shape = ()) 107 | 108 | self.H, self.W, self.imageInDim = config.imageDims 109 | 110 | # Feeds data into placeholders. See addPlaceholders method for further details. 
111 | def createFeedDict(self, data, images, train): 112 | feedDict = { 113 | self.questionsIndicesAll: data["questions"], 114 | self.questionLengthsAll: data["questionLengths"], 115 | self.imagesPlaceholder: images["images"], 116 | self.answersIndicesAll: data["answers"], 117 | 118 | self.dropouts["encInput"]: config.encInputDropout if train else 1.0, 119 | self.dropouts["encState"]: config.encStateDropout if train else 1.0, 120 | self.dropouts["stem"]: config.stemDropout if train else 1.0, 121 | self.dropouts["question"]: config.qDropout if train else 1.0, #_ 122 | self.dropouts["memory"]: config.memoryDropout if train else 1.0, 123 | self.dropouts["read"]: config.readDropout if train else 1.0, #_ 124 | self.dropouts["write"]: config.writeDropout if train else 1.0, 125 | self.dropouts["output"]: config.outputDropout if train else 1.0, 126 | # self.dropouts["question"]Out: config.qDropoutOut if train else 1.0, 127 | # self.dropouts["question"]MAC: config.qDropoutMAC if train else 1.0, 128 | 129 | self.lr: config.lr, 130 | self.train: train 131 | } 132 | 133 | # if config.tempDynamic: 134 | # feedDict[self.tempAnnealRate] = tempAnnealRate 135 | 136 | return feedDict 137 | 138 | # Splits data to a specific GPU (tower) for parallelization 139 | def initTowerBatch(self, towerI, towersNum, dataSize): 140 | towerBatchSize = tf.floordiv(dataSize, towersNum) 141 | start = towerI * towerBatchSize 142 | end = (towerI + 1) * towerBatchSize if towerI < towersNum - 1 else dataSize 143 | 144 | self.questionsIndices = self.questionsIndicesAll[start:end] 145 | self.questionLengths = self.questionLengthsAll[start:end] 146 | self.images = self.imagesAll[start:end] 147 | self.answersIndices = self.answersIndicesAll[start:end] 148 | 149 | self.batchSize = end - start 150 | 151 | ''' 152 | The Image Input Unit (stem). Passes the image features through a CNN-network 153 | Optionally adds position encoding (doesn't in the default behavior). 154 | Flatten the image into Height * Width "Knowledge base" array. 155 | 156 | Args: 157 | images: image input. [batchSize, height, width, inDim] 158 | inDim: input image dimension 159 | outDim: image out dimension 160 | addLoc: if not None, adds positional encoding to the image 161 | 162 | Returns preprocessed images. 
163 | [batchSize, height * width, outDim] 164 | ''' 165 | def stem(self, images, inDim, outDim, addLoc = None): 166 | 167 | with tf.variable_scope("stem"): 168 | if addLoc is None: 169 | addLoc = config.locationAware 170 | 171 | if config.stemLinear: 172 | features = ops.linear(images, inDim, outDim) 173 | else: 174 | dims = [inDim] + ([config.stemDim] * (config.stemNumLayers - 1)) + [outDim] 175 | 176 | if addLoc: 177 | images, inDim = ops.addLocation(images, inDim, config.locationDim, 178 | h = self.H, w = self.W, locType = config.locationType) 179 | dims[0] = inDim 180 | 181 | # if config.locationType == "PE": 182 | # dims[-1] /= 4 183 | # dims[-1] *= 3 184 | # else: 185 | # dims[-1] -= 2 186 | features = ops.CNNLayer(images, dims, 187 | batchNorm = self.batchNorm if config.stemBN else None, 188 | dropout = self.dropouts["stem"], 189 | kernelSizes = config.stemKernelSizes, 190 | strides = config.stemStrideSizes) 191 | 192 | # if addLoc: 193 | # lDim = outDim / 4 194 | # lDim /= 4 195 | # features, _ = addLocation(features, dims[-1], lDim, h = H, w = W, 196 | # locType = config.locationType) 197 | 198 | if config.stemGridRnn: 199 | features = ops.multigridRNNLayer(features, H, W, outDim) 200 | 201 | # flatten the 2d images into a 1d KB 202 | features = tf.reshape(features, (self.batchSize, -1, outDim)) 203 | 204 | return features 205 | 206 | # Embed question using parametrized word embeddings. 207 | # The embedding are initialized to the values supported to the class initialization 208 | def qEmbeddingsOp(self, qIndices, embInit): 209 | with tf.variable_scope("qEmbeddings"): 210 | # if config.useCPU: 211 | # with tf.device('/cpu:0'): 212 | # embeddingsVar = tf.Variable(self.embeddingsInit, name = "embeddings", dtype = tf.float32) 213 | # else: 214 | # embeddingsVar = tf.Variable(self.embeddingsInit, name = "embeddings", dtype = tf.float32) 215 | embeddingsVar = tf.get_variable("emb", initializer = tf.to_float(embInit), 216 | dtype = tf.float32, trainable = (not config.wrdEmbFixed)) 217 | embeddings = tf.concat([tf.zeros((1, config.wrdEmbDim)), embeddingsVar], axis = 0) 218 | questions = tf.nn.embedding_lookup(embeddings, qIndices) 219 | 220 | return questions, embeddings 221 | 222 | # Embed answer words 223 | def aEmbeddingsOp(self, embInit): 224 | with tf.variable_scope("aEmbeddings"): 225 | if embInit is None: 226 | return None 227 | answerEmbeddings = tf.get_variable("emb", initializer = tf.to_float(embInit), 228 | dtype = tf.float32) 229 | return answerEmbeddings 230 | 231 | # Embed question and answer words with tied embeddings 232 | def qaEmbeddingsOp(self, qIndices, embInit): 233 | questions, qaEmbeddings = self.qEmbeddingsOp(qIndices, embInit["qa"]) 234 | aEmbeddings = tf.nn.embedding_lookup(qaEmbeddings, embInit["ansMap"]) 235 | 236 | return questions, qaEmbeddings, aEmbeddings 237 | 238 | ''' 239 | Embed question (and optionally answer) using parametrized word embeddings. 240 | The embedding are initialized to the values supported to the class initialization 241 | ''' 242 | def embeddingsOp(self, qIndices, embInit): 243 | if config.ansEmbMod == "SHARED": 244 | questions, qEmb, aEmb = self.qaEmbeddingsOp(qIndices, embInit) 245 | else: 246 | questions, qEmb = self.qEmbeddingsOp(qIndices, embInit["q"]) 247 | aEmb = self.aEmbeddingsOp(embInit["a"]) 248 | 249 | return questions, qEmb, aEmb 250 | 251 | ''' 252 | The Question Input Unit embeds the questions to randomly-initialized word vectors, 253 | and runs a recurrent bidirectional encoder (RNN/LSTM etc.) 
that gives back 254 | vector representations for each question (the RNN final hidden state), and 255 | representations for each of the question words (the RNN outputs for each word). 256 | 257 | The method uses bidirectional LSTM, by default. 258 | Optionally projects the outputs of the LSTM (with linear projection / 259 | optionally with some activation). 260 | 261 | Args: 262 | questions: question word embeddings 263 | [batchSize, questionLength, wordEmbDim] 264 | 265 | questionLengths: the question lengths. 266 | [batchSize] 267 | 268 | projWords: True to apply projection on RNN outputs. 269 | projQuestion: True to apply projection on final RNN state. 270 | projDim: projection dimension in case projection is applied. 271 | 272 | Returns: 273 | Contextual Words: RNN outputs for the words. 274 | [batchSize, questionLength, ctrlDim] 275 | 276 | Vectorized Question: Final hidden state representing the whole question. 277 | [batchSize, ctrlDim] 278 | ''' 279 | def encoder(self, questions, questionLengths, projWords = False, 280 | projQuestion = False, projDim = None): 281 | 282 | with tf.variable_scope("encoder"): 283 | # variational dropout option 284 | varDp = None 285 | if config.encVariationalDropout: 286 | varDp = {"stateDp": self.dropouts["stateInput"], 287 | "inputDp": self.dropouts["encInput"], 288 | "inputSize": config.wrdEmbDim} 289 | 290 | # rnns 291 | for i in range(config.encNumLayers): 292 | questionCntxWords, vecQuestions = ops.RNNLayer(questions, questionLengths, 293 | config.encDim, bi = config.encBi, cellType = config.encType, 294 | dropout = self.dropouts["encInput"], varDp = varDp, name = "rnn%d" % i) 295 | 296 | # dropout for the question vector 297 | vecQuestions = tf.nn.dropout(vecQuestions, self.dropouts["question"]) 298 | 299 | # projection of encoder outputs 300 | if projWords: 301 | questionCntxWords = ops.linear(questionCntxWords, config.encDim, projDim, 302 | name = "projCW") 303 | if projQuestion: 304 | vecQuestions = ops.linear(vecQuestions, config.encDim, projDim, 305 | act = config.encProjQAct, name = "projQ") 306 | 307 | return questionCntxWords, vecQuestions 308 | 309 | ''' 310 | Stacked Attention Layer for baseline. Computes interaction between images 311 | and the previous memory, and casts it back to compute attention over the 312 | image, which in turn is summed up with the previous memory to result in the 313 | new one. 314 | 315 | Args: 316 | images: input image. 317 | [batchSize, H * W, inDim] 318 | 319 | memory: previous memory value 320 | [batchSize, inDim] 321 | 322 | inDim: inputs dimension 323 | hDim: hidden dimension to compute interactions between image and memory 324 | 325 | Returns the new memory value. 
326 | ''' 327 | def baselineAttLayer(self, images, memory, inDim, hDim, name = "", reuse = None): 328 | with tf.variable_scope("attLayer" + name, reuse = reuse): 329 | # projImages = ops.linear(images, inDim, hDim, name = "projImage") 330 | # projMemory = tf.expand_dims(ops.linear(memory, inDim, hDim, name = "projMemory"), axis = -2) 331 | # if config.saMultiplicative: 332 | # interactions = projImages * projMemory 333 | # else: 334 | # interactions = tf.tanh(projImages + projMemory) 335 | interactions, _ = ops.mul(images, memory, inDim, proj = {"dim": hDim, "shared": False}, 336 | interMod = config.baselineAttType) 337 | 338 | attention = ops.inter2att(interactions, hDim) 339 | summary = ops.att2Smry(attention, images) 340 | newMemory = memory + summary 341 | 342 | return newMemory 343 | 344 | ''' 345 | Baseline approach: 346 | If baselineAtt is True, applies several layers (baselineAttNumLayers) 347 | of stacked attention to image and memory, when memory is initialized 348 | to the vector questions. See baselineAttLayer for further details. 349 | 350 | Otherwise, computes result output features based on image representation 351 | (baselineCNN), or question (baselineLSTM) or both. 352 | 353 | Args: 354 | vecQuestions: question vector representation 355 | [batchSize, questionDim] 356 | 357 | questionDim: dimension of question vectors 358 | 359 | images: (flattened) image representation 360 | [batchSize, imageDim] 361 | 362 | imageDim: dimension of image representations. 363 | 364 | hDim: hidden dimension to compute interactions between image and memory 365 | (for attention-based baseline). 366 | 367 | Returns final features to use in later classifier. 368 | [batchSize, outDim] (out dimension depends on baseline method) 369 | ''' 370 | def baseline(self, vecQuestions, questionDim, images, imageDim, hDim): 371 | with tf.variable_scope("baseline"): 372 | if config.baselineAtt: 373 | memory = self.linear(vecQuestions, questionDim, hDim, name = "qProj") 374 | images = self.linear(images, imageDim, hDim, name = "iProj") 375 | 376 | for i in range(config.baselineAttNumLayers): 377 | memory = self.baselineAttLayer(images, memory, hDim, hDim, 378 | name = "baseline%d" % i) 379 | memDim = hDim 380 | else: 381 | images, imagesDim = ops.linearizeFeatures(images, self.H, self.W, 382 | imageDim, projDim = config.baselineProjDim) 383 | if config.baselineLSTM and config.baselineCNN: 384 | memory = tf.concat([vecQuestions, images], axis = -1) 385 | memDim = questionDim + imageDim 386 | elif config.baselineLSTM: 387 | memory = vecQuestions 388 | memDim = questionDim 389 | else: # config.baselineCNN 390 | memory = images 391 | memDim = imageDim 392 | 393 | return memory, memDim 394 | 395 | ''' 396 | Runs the MAC recurrent network to perform the reasoning process. 397 | Initializes a MAC cell and runs netLength iterations. 398 | 399 | Currently it passes the question and knowledge base to the cell during 400 | its creating, such that it doesn't need to interact with it through 401 | inputs / outputs while running. The recurrent computation happens 402 | by working iteratively over the hidden (control, memory) states. 403 | 404 | Args: 405 | images: flattened image features. Used as the "Knowledge Base". 406 | (Received by default model behavior from the Image Input Units). 407 | [batchSize, H * W, memDim] 408 | 409 | vecQuestions: vector questions representations. 410 | (Received by default model behavior from the Question Input Units 411 | as the final RNN state). 
412 | [batchSize, ctrlDim] 413 | 414 | questionWords: question word embeddings. 415 | [batchSize, questionLength, ctrlDim] 416 | 417 | questionCntxWords: question contextual words. 418 | (Received by default model behavior from the Question Input Units 419 | as the series of RNN output states). 420 | [batchSize, questionLength, ctrlDim] 421 | 422 | questionLengths: question lengths. 423 | [batchSize] 424 | 425 | Returns the final control state and memory state resulted from the network. 426 | ([batchSize, ctrlDim], [bathSize, memDim]) 427 | ''' 428 | def MACnetwork(self, images, vecQuestions, questionWords, questionCntxWords, 429 | questionLengths, name = "", reuse = None): 430 | 431 | with tf.variable_scope("MACnetwork" + name, reuse = reuse): 432 | 433 | self.macCell = MACCell( 434 | vecQuestions = vecQuestions, 435 | questionWords = questionWords, 436 | questionCntxWords = questionCntxWords, 437 | questionLengths = questionLengths, 438 | knowledgeBase = images, 439 | memoryDropout = self.dropouts["memory"], 440 | readDropout = self.dropouts["read"], 441 | writeDropout = self.dropouts["write"], 442 | # qDropoutMAC = self.qDropoutMAC, 443 | batchSize = self.batchSize, 444 | train = self.train, 445 | reuse = reuse) 446 | 447 | state = self.macCell.zero_state(self.batchSize, tf.float32) 448 | 449 | # inSeq = tf.unstack(inSeq, axis = 1) 450 | none = tf.zeros((self.batchSize, 1), dtype = tf.float32) 451 | 452 | # for i, inp in enumerate(inSeq): 453 | for i in range(config.netLength): 454 | self.macCell.iteration = i 455 | # if config.unsharedCells: 456 | # with tf.variable_scope("iteration%d" % i): 457 | # macCell.myNameScope = "iteration%d" % i 458 | _, state = self.macCell(none, state) 459 | # else: 460 | # _, state = macCell(none, state) 461 | # macCell.reuse = True 462 | 463 | # self.autoEncMMLoss = macCell.autoEncMMLossI 464 | # inputSeqL = None 465 | # _, lastOutputs = tf.nn.dynamic_rnn(macCell, inputSeq, # / static 466 | # sequence_length = inputSeqL, 467 | # initial_state = initialState, 468 | # swap_memory = True) 469 | 470 | # self.postModules = None 471 | # if (config.controlPostRNN or config.selfAttentionMod == "POST"): # may not work well with dlogits 472 | # self.postModules, _ = self.RNNLayer(cLogits, None, config.encDim, bi = False, 473 | # name = "decPostRNN", cellType = config.controlPostRNNmod) 474 | # if config.controlPostRNN: 475 | # logits = self.postModules 476 | # self.postModules = tf.unstack(self.postModules, axis = 1) 477 | 478 | # self.autoEncCtrlLoss = tf.constant(0.0) 479 | # if config.autoEncCtrl: 480 | # autoEncCtrlCellType = ("GRU" if config.autoEncCtrlGRU else "RNN") 481 | # autoEncCtrlinp = logits 482 | # _, autoEncHid = self.RNNLayer(autoEncCtrlinp, None, config.encDim, 483 | # bi = True, name = "autoEncCtrl", cellType = autoEncCtrlCellType) 484 | # self.autoEncCtrlLoss = (tf.nn.l2_loss(vecQuestions - autoEncHid)) / tf.to_float(self.batchSize) 485 | 486 | finalControl = state.control 487 | finalMemory = state.memory 488 | 489 | return finalControl, finalMemory 490 | 491 | ''' 492 | Output Unit (step 1): chooses the inputs to the output classifier. 493 | 494 | By default the classifier input will be the the final memory state of the MAC network. 495 | If outQuestion is True, concatenate the question representation to that. 496 | If outImage is True, concatenate the image flattened representation. 497 | 498 | Args: 499 | memory: (final) memory state of the MAC network. 500 | [batchSize, memDim] 501 | 502 | vecQuestions: question vector representation. 
503 | [batchSize, ctrlDim] 504 | 505 | images: image features. 506 | [batchSize, H, W, imageInDim] 507 | 508 | imageInDim: images dimension. 509 | 510 | Returns the resulted features and their dimension. 511 | ''' 512 | def outputOp(self, memory, vecQuestions, images, imageInDim): 513 | with tf.variable_scope("outputUnit"): 514 | features = memory 515 | dim = config.memDim 516 | 517 | if config.outQuestion: 518 | eVecQuestions = ops.linear(vecQuestions, config.ctrlDim, config.memDim, name = "outQuestion") 519 | features, dim = ops.concat(features, eVecQuestions, config.memDim, mul = config.outQuestionMul) 520 | 521 | if config.outImage: 522 | images, imagesDim = ops.linearizeFeatures(images, self.H, self.W, self.imageInDim, 523 | outputDim = config.outImageDim) 524 | images = ops.linear(images, config.memDim, config.outImageDim, name = "outImage") 525 | features = tf.concat([features, images], axis = -1) 526 | dim += config.outImageDim 527 | 528 | return features, dim 529 | 530 | ''' 531 | Output Unit (step 2): Computes the logits for the answers. Passes the features 532 | through fully-connected network to get the logits over the possible answers. 533 | Optionally uses answer word embeddings in computing the logits (by default, it doesn't). 534 | 535 | Args: 536 | features: features used to compute logits 537 | [batchSize, inDim] 538 | 539 | inDim: features dimension 540 | 541 | aEmbedding: supported word embeddings for answer words in case answerMod is not NON. 542 | Optionally computes logits by computing dot-product with answer embeddings. 543 | 544 | Returns: the computed logits. 545 | [batchSize, answerWordsNum] 546 | ''' 547 | def classifier(self, features, inDim, aEmbeddings = None): 548 | with tf.variable_scope("classifier"): 549 | outDim = config.answerWordsNum 550 | dims = [inDim] + config.outClassifierDims + [outDim] 551 | if config.answerMod != "NON": 552 | dims[-1] = config.wrdEmbDim 553 | 554 | 555 | logits = ops.FCLayer(features, dims, 556 | batchNorm = self.batchNorm if config.outputBN else None, 557 | dropout = self.dropouts["output"]) 558 | 559 | if config.answerMod != "NON": 560 | logits = tf.nn.dropout(logits, self.dropouts["output"]) 561 | interactions = ops.mul(aEmbeddings, logits, dims[-1], interMod = config.answerMod) 562 | logits = ops.inter2logits(interactions, dims[-1], sumMod = "SUM") 563 | logits += ops.getBias((outputDim, ), "ans") 564 | 565 | # answersWeights = tf.transpose(aEmbeddings) 566 | 567 | # if config.answerMod == "BL": 568 | # Wans = ops.getWeight((dims[-1], config.wrdEmbDim), "ans") 569 | # logits = tf.matmul(logits, Wans) 570 | # elif config.answerMod == "DIAG": 571 | # Wans = ops.getWeight((config.wrdEmbDim, ), "ans") 572 | # logits = logits * Wans 573 | 574 | # logits = tf.matmul(logits, answersWeights) 575 | 576 | return logits 577 | 578 | # def getTemp(): 579 | # with tf.variable_scope("temperature"): 580 | # if config.tempParametric: 581 | # self.temperatureVar = tf.get_variable("temperature", shape = (), 582 | # initializer = tf.constant_initializer(5), dtype = tf.float32) 583 | # temperature = tf.sigmoid(self.temperatureVar) 584 | # else: 585 | # temperature = config.temperature 586 | 587 | # if config.tempDynamic: 588 | # temperature *= self.tempAnnealRate 589 | 590 | # return temperature 591 | 592 | # Computes mean cross entropy loss between logits and answers. 
593 | def addAnswerLossOp(self, logits, answers): 594 | with tf.variable_scope("answerLoss"): 595 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels = answers, logits = logits) 596 | loss = tf.reduce_mean(losses) 597 | self.answerLossList.append(loss) 598 | 599 | return loss, losses 600 | 601 | # Computes predictions (by finding maximal logit value, corresponding to highest probability) 602 | # and mean accuracy between predictions and answers. 603 | def addPredOp(self, logits, answers): 604 | with tf.variable_scope("pred"): 605 | preds = tf.to_int32(tf.argmax(logits, axis = -1)) # tf.nn.softmax( 606 | corrects = tf.equal(preds, answers) 607 | correctNum = tf.reduce_sum(tf.to_int32(corrects)) 608 | acc = tf.reduce_mean(tf.to_float(corrects)) 609 | self.correctNumList.append(correctNum) 610 | self.answerAccList.append(acc) 611 | 612 | return preds, corrects, correctNum 613 | 614 | # Creates optimizer (adam) 615 | def addOptimizerOp(self): 616 | with tf.variable_scope("trainAddOptimizer"): 617 | self.globalStep = tf.Variable(0, dtype = tf.int32, trainable = False, name = "globalStep") # init to 0 every run? 618 | optimizer = tf.train.AdamOptimizer(learning_rate = self.lr) 619 | 620 | return optimizer 621 | 622 | ''' 623 | Computes gradients for all variables or subset of them, based on provided loss, 624 | using optimizer. 625 | ''' 626 | def computeGradients(self, optimizer, loss, trainableVars = None): # tf.trainable_variables() 627 | with tf.variable_scope("computeGradients"): 628 | if config.trainSubset: 629 | trainableVars = [] 630 | allVars = tf.trainable_variables() 631 | for var in allVars: 632 | if any((s in var.name) for s in config.varSubset): 633 | trainableVars.append(var) 634 | 635 | gradients_vars = optimizer.compute_gradients(loss, trainableVars) 636 | return gradients_vars 637 | 638 | ''' 639 | Apply gradients. Optionally clip them, and update exponential moving averages 640 | for parameters. 641 | ''' 642 | def addTrainingOp(self, optimizer, gradients_vars): 643 | with tf.variable_scope("train"): 644 | gradients, variables = zip(*gradients_vars) 645 | norm = tf.global_norm(gradients) 646 | 647 | # gradient clipping 648 | if config.clipGradients: 649 | clippedGradients, _ = tf.clip_by_global_norm(gradients, config.gradMaxNorm, use_norm = norm) 650 | gradients_vars = zip(clippedGradients, variables) 651 | 652 | # updates ops (for batch norm) and train op 653 | updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 654 | with tf.control_dependencies(updateOps): 655 | train = optimizer.apply_gradients(gradients_vars, global_step = self.globalStep) 656 | 657 | # exponential moving average 658 | if config.useEMA: 659 | ema = tf.train.ExponentialMovingAverage(decay = config.emaDecayRate) 660 | maintainAveragesOp = ema.apply(tf.trainable_variables()) 661 | 662 | with tf.control_dependencies([train]): 663 | trainAndUpdateOp = tf.group(maintainAveragesOp) 664 | 665 | train = trainAndUpdateOp 666 | 667 | self.emaDict = ema.variables_to_restore() 668 | 669 | return train, norm 670 | 671 | # TODO (add back support for multi-gpu..) 
672 | def averageAcrossTowers(self, gpusNum): 673 | self.lossAll = self.lossList[0] 674 | 675 | self.answerLossAll = self.answerLossList[0] 676 | self.correctNumAll = self.correctNumList[0] 677 | self.answerAccAll = self.answerAccList[0] 678 | self.predsAll = self.predsList[0] 679 | self.gradientVarsAll = self.gradientVarsList[0] 680 | 681 | def trim2DVectors(self, vectors, vectorsLengths): 682 | maxLength = np.max(vectorsLengths) 683 | return vectors[:,:maxLength] 684 | 685 | def trimData(self, data): 686 | data["questions"] = self.trim2DVectors(data["questions"], data["questionLengths"]) 687 | return data 688 | 689 | ''' 690 | Builds predictions JSON, by adding the model's predictions and attention maps 691 | back to the original data JSON. 692 | ''' 693 | def buildPredsList(self, data, predictions, attentionMaps): 694 | predsList = [] 695 | 696 | for i, instance in enumerate(data["instances"]): 697 | 698 | if predictions is not None: 699 | pred = self.answerDict.decodeId(predictions[i]) 700 | instance["prediction"] = pred 701 | 702 | # aggregate np attentions of instance i in the batch into 2d list 703 | attMapToList = lambda attMap: [step[i].tolist() for step in attMap] 704 | if attentionMaps is not None: 705 | attentions = {k: attMapToList(attentionMaps[k]) for k in attentionMaps} 706 | instance["attentions"] = attentions 707 | 708 | predsList.append(instance) 709 | 710 | return predsList 711 | 712 | ''' 713 | Processes a batch of data with the model. 714 | 715 | Args: 716 | sess: TF session 717 | 718 | data: Data batch. Dictionary that contains numpy array for: 719 | questions, questionLengths, answers. 720 | See preprocess.py for further information of the batch structure. 721 | 722 | images: batch of image features, as numpy array. images["images"] contains 723 | [batchSize, channels, h, w] 724 | 725 | train: True to run batch for training. 726 | 727 | getAtt: True to return attention maps for question and image (and optionally 728 | self-attention and gate values). 729 | 730 | Returns results: e.g. loss, accuracy, running time. 
731 | ''' 732 | def runBatch(self, sess, data, images, train, getAtt = False): 733 | data = self.trimData(data) 734 | 735 | trainOp = self.trainOp if train else self.noOp 736 | gradNormOp = self.gradNorm if train else self.noOp 737 | 738 | predsOp = (self.predsAll, self.correctNumAll, self.answerAccAll) 739 | 740 | attOp = self.macCell.attentions 741 | 742 | time0 = time.time() 743 | feed = self.createFeedDict(data, images, train) 744 | 745 | time1 = time.time() 746 | _, loss, predsInfo, gradNorm, attentionMaps = sess.run( 747 | [trainOp, self.lossAll, predsOp, gradNormOp, attOp], 748 | feed_dict = feed) 749 | 750 | time2 = time.time() 751 | 752 | predsList = self.buildPredsList(data, predsInfo[0], attentionMaps if getAtt else None) 753 | 754 | return {"loss": loss, 755 | "correctNum": predsInfo[1], 756 | "acc": predsInfo[2], 757 | "preds": predsList, 758 | "gradNorm": gradNorm if train else -1, 759 | "readTime": time1 - time0, 760 | "trainTime": time2 - time1} 761 | 762 | def build(self): 763 | self.addPlaceholders() 764 | self.optimizer = self.addOptimizerOp() 765 | 766 | self.gradientVarsList = [] 767 | self.lossList = [] 768 | 769 | self.answerLossList = [] 770 | self.correctNumList = [] 771 | self.answerAccList = [] 772 | self.predsList = [] 773 | 774 | with tf.variable_scope("macModel"): 775 | for i in range(config.gpusNum): 776 | with tf.device("/gpu:{}".format(i)): 777 | with tf.name_scope("tower{}".format(i)) as scope: 778 | self.initTowerBatch(i, config.gpusNum, self.batchSizeAll) 779 | 780 | self.loss = tf.constant(0.0) 781 | 782 | # embed questions words (and optionally answer words) 783 | questionWords, qEmbeddings, aEmbeddings = \ 784 | self.embeddingsOp(self.questionsIndices, self.embeddingsInit) 785 | 786 | projWords = projQuestion = ((config.encDim != config.ctrlDim) or config.encProj) 787 | questionCntxWords, vecQuestions = self.encoder(questionWords, 788 | self.questionLengths, projWords, projQuestion, config.ctrlDim) 789 | 790 | # Image Input Unit (stem) 791 | imageFeatures = self.stem(self.images, self.imageInDim, config.memDim) 792 | 793 | # baseline model 794 | if config.useBaseline: 795 | output, dim = self.baseline(vecQuestions, config.ctrlDim, 796 | self.images, self.imageInDim, config.attDim) 797 | # MAC model 798 | else: 799 | # self.temperature = self.getTemp() 800 | 801 | finalControl, finalMemory = self.MACnetwork(imageFeatures, vecQuestions, 802 | questionWords, questionCntxWords, self.questionLengths) 803 | 804 | # Output Unit - step 1 (preparing classifier inputs) 805 | output, dim = self.outputOp(finalMemory, vecQuestions, 806 | self.images, self.imageInDim) 807 | 808 | # Output Unit - step 2 (classifier) 809 | logits = self.classifier(output, dim, aEmbeddings) 810 | 811 | # compute loss, predictions, accuracy 812 | answerLoss, self.losses = self.addAnswerLossOp(logits, self.answersIndices) 813 | self.preds, self.corrects, self.correctNum = self.addPredOp(logits, self.answersIndices) 814 | self.loss += answerLoss 815 | self.predsList.append(self.preds) 816 | 817 | self.lossList.append(self.loss) 818 | 819 | # compute gradients 820 | gradient_vars = self.computeGradients(self.optimizer, self.loss, trainableVars = None) 821 | self.gradientVarsList.append(gradient_vars) 822 | 823 | # reuse variables in next towers 824 | tf.get_variable_scope().reuse_variables() 825 | 826 | self.averageAcrossTowers(config.gpusNum) 827 | 828 | self.trainOp, self.gradNorm = self.addTrainingOp(self.optimizer, self.gradientVarsAll) 829 | self.noOp = tf.no_op() 830 | 
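The model class above is normally driven by main.py together with the preprocessing pipeline below. As a rough usage sketch only (hypothetical driver code, not part of the repository): it assumes `config` has already been populated (main.py does this by parsing `@configs/args.txt`) and that `data` and `images` follow the batch layout documented for runBatch.

```python
# Hypothetical single-batch driver for MACnet (a sketch; the real entry point is main.py).
# Assumes config is already set up, and that the inputs follow runBatch's documented layout:
#   data:   {"questions", "questionLengths", "answers", ...} numpy arrays from preprocess.py
#   images: {"images": [batchSize, channels, H, W]} extracted feature batch
import tensorflow as tf
from model import MACnet

def runSingleBatch(embeddingsInit, answerDict, data, images, train = True):
    model = MACnet(embeddingsInit, answerDict)  # __init__ builds the full graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        res = model.runBatch(sess, data, images, train)
        # res holds "loss", "correctNum", "acc", "preds", "gradNorm" and timing info
        return res
```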
-------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import random 4 | import json 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from termcolor import colored 9 | from program_translator import ProgramTranslator # 10 | from config import config 11 | 12 | # Print bold tex 13 | def bold(txt): 14 | return colored(str(txt),attrs = ["bold"]) 15 | 16 | # Print bold and colored text 17 | def bcolored(txt, color): 18 | return colored(str(txt), color, attrs = ["bold"]) 19 | 20 | # Write a line to file 21 | def writeline(f, line): 22 | f.write(str(line) + "\n") 23 | 24 | # Write a list to file 25 | def writelist(f, l): 26 | writeline(f, ",".join(map(str, l))) 27 | 28 | # 2d list to numpy 29 | def vectorize2DList(items, minX = 0, minY = 0, dtype = np.int): 30 | maxX = max(len(items), minX) 31 | maxY = max([len(item) for item in items] + [minY]) 32 | t = np.zeros((maxX, maxY), dtype = dtype) 33 | tLengths = np.zeros((maxX, ), dtype = np.int) 34 | for i, item in enumerate(items): 35 | t[i, 0:len(item)] = np.array(item, dtype = dtype) 36 | tLengths[i] = len(item) 37 | return t, tLengths 38 | 39 | # 3d list to numpy 40 | def vectorize3DList(items, minX = 0, minY = 0, minZ = 0, dtype = np.int): 41 | maxX = max(len(items), minX) 42 | maxY = max([len(item) for item in items] + [minY]) 43 | maxZ = max([len(subitem) for item in items for subitem in item] + [minZ]) 44 | t = np.zeros((maxX, maxY, maxZ), dtype = dtype) 45 | tLengths = np.zeros((maxX, maxY), dtype = np.int) 46 | for i, item in enumerate(items): 47 | for j, subitem in enumerate(item): 48 | t[i, j, 0:len(subitem)] = np.array(subitem, dtype = dtype) 49 | tLengths[i, j] = len(subitem) 50 | return t, tLengths 51 | 52 | ''' 53 | Encodes text into integers. Keeps dictionary between string words (symbols) 54 | and their matching integers. Supports encoding and decoding. 55 | ''' 56 | class SymbolDict(object): 57 | def __init__(self, empty = False): 58 | self.padding = "" 59 | self.unknown = "" 60 | self.start = "" 61 | self.end = "" 62 | 63 | self.invalidSymbols = [self.padding, self.unknown, self.start, self.end] 64 | 65 | if empty: 66 | self.sym2id = {} 67 | self.id2sym = [] 68 | else: 69 | self.sym2id = {self.padding: 0, self.unknown: 1, self.start: 2, self.end: 3} 70 | self.id2sym = [self.padding, self.unknown, self.start, self.end] 71 | self.allSeqs = [] 72 | 73 | def getNumSymbols(self): 74 | return len(self.sym2id) 75 | 76 | def isPadding(self, enc): 77 | return enc == 0 78 | 79 | def isUnknown(self, enc): 80 | return enc == 1 81 | 82 | def isStart(self, enc): 83 | return enc == 2 84 | 85 | def isEnd(self, enc): 86 | return enc == 3 87 | 88 | def isValid(self, enc): 89 | return enc < self.getNumSymbols() and enc >= len(self.invalidSymbols) 90 | 91 | def resetSeqs(self): 92 | self.allSeqs = [] 93 | 94 | def addSeq(self, seq): 95 | self.allSeqs += seq 96 | 97 | # Call to create the words-to-integers vocabulary after (reading word sequences with addSeq). 98 | def createVocab(self, minCount = 0): 99 | counter = {} 100 | for symbol in self.allSeqs: 101 | counter[symbol] = counter.get(symbol, 0) + 1 102 | for symbol in counter: 103 | if counter[symbol] > minCount and (symbol not in self.sym2id): 104 | self.sym2id[symbol] = self.getNumSymbols() 105 | self.id2sym.append(symbol) 106 | 107 | # Encodes a symbol. Returns the matching integer. 
108 | def encodeSym(self, symbol): 109 | if symbol not in self.sym2id: 110 | symbol = self.unknown 111 | return self.sym2id[symbol] 112 | 113 | ''' 114 | Encodes a sequence of symbols. 115 | Optionally add start, or end symbols. 116 | Optionally reverse sequence 117 | ''' 118 | def encodeSequence(self, decoded, addStart = False, addEnd = False, reverse = False): 119 | if reverse: 120 | decoded.reverse() 121 | if addStart: 122 | decoded = [self.start] + decoded 123 | if addEnd: 124 | decoded = decoded + [self.end] 125 | encoded = [self.encodeSym(symbol) for symbol in decoded] 126 | return encoded 127 | 128 | # Decodes an integer into its symbol 129 | def decodeId(self, enc): 130 | return self.id2sym[enc] if enc < self.getNumSymbols() else self.unknown 131 | 132 | ''' 133 | Decodes a sequence of integers into their symbols. 134 | If delim is given, joins the symbols using delim, 135 | Optionally reverse the resulted sequence 136 | ''' 137 | def decodeSequence(self, encoded, delim = None, reverse = False, stopAtInvalid = True): 138 | length = 0 139 | for i in range(len(encoded)): 140 | if not self.isValid(encoded[i]) and stopAtInvalid: 141 | break 142 | length += 1 143 | encoded = encoded[:length] 144 | 145 | decoded = [self.decodeId(enc) for enc in encoded] 146 | if reverse: 147 | decoded.reverse() 148 | 149 | if delim is not None: 150 | return delim.join(decoded) 151 | 152 | return decoded 153 | 154 | ''' 155 | Preprocesses a given dataset into numpy arrays. 156 | By calling preprocess, the class: 157 | 1. Reads the input data files into dictionary. 158 | 2. Saves the results jsons in files and loads them instead of parsing input if files exist/ 159 | 3. Initializes word embeddings to random / GloVe. 160 | 4. Optionally filters data according to given filters. 161 | 5. Encodes and vectorize the data into numpy arrays. 162 | 6. Buckets the data according to the instances length. 163 | ''' 164 | class Preprocesser(object): 165 | def __init__(self): 166 | self.questionDict = SymbolDict() 167 | self.answerDict = SymbolDict(empty = True) 168 | self.qaDict = SymbolDict() 169 | 170 | self.specificDatasetDicts = None 171 | 172 | self.programDict = SymbolDict() 173 | self.programTranslator = ProgramTranslator(self.programDict, 2) 174 | ''' 175 | Tokenizes string into list of symbols. 176 | 177 | Args: 178 | text: raw string to tokenize. 
179 | ignorePuncts: punctuation to ignore 180 | keptPunct: punctuation to keep (as symbol) 181 | endPunct: punctuation to remove if appears at the end 182 | delim: delimiter between symbols 183 | clean: True to replace text in string 184 | replacelistPre: dictionary of replacement to perform on the text before tokanization 185 | replacelistPost: dictionary of replacement to perform on the text after tokanization 186 | ''' 187 | # sentence tokenizer 188 | allPunct = ["?", "!", "\\", "/", ")", "(", ".", ",", ";", ":"] 189 | def tokenize(self, text, ignoredPuncts = ["?", "!", "\\", "/", ")", "("], 190 | keptPuncts = [".", ",", ";", ":"], endPunct = [">", "<", ":"], delim = " ", 191 | clean = False, replacelistPre = dict(), replacelistPost = dict()): 192 | 193 | if clean: 194 | for word in replacelistPre: 195 | origText = text 196 | text = text.replace(word, replacelistPre[word]) 197 | if (origText != text): 198 | print(origText) 199 | print(text) 200 | print("") 201 | 202 | for punct in endPunct: 203 | if text[-1] == punct: 204 | print(text) 205 | text = text[:-1] 206 | print(text) 207 | print("") 208 | 209 | for punct in keptPuncts: 210 | text = text.replace(punct, delim + punct + delim) 211 | 212 | for punct in ignoredPuncts: 213 | text = text.replace(punct, "") 214 | 215 | ret = text.lower().split(delim) 216 | 217 | if clean: 218 | origRet = ret 219 | ret = [replacelistPost.get(word, word) for word in ret] 220 | if origRet != ret: 221 | print(origRet) 222 | print(ret) 223 | 224 | ret = [t for t in ret if t != ""] 225 | return ret 226 | 227 | 228 | # Read class' generated files. 229 | # files interface 230 | def readFiles(self, instancesFilename): 231 | with open(instancesFilename, "r") as inFile: 232 | instances = json.load(inFile) 233 | 234 | with open(config.questionDictFile(), "rb") as inFile: 235 | self.questionDict = pickle.load(inFile) 236 | 237 | with open(config.answerDictFile(), "rb") as inFile: 238 | self.answerDict = pickle.load(inFile) 239 | 240 | with open(config.qaDictFile(), "rb") as inFile: 241 | self.qaDict = pickle.load(inFile) 242 | 243 | return instances 244 | 245 | ''' 246 | Generate class' files. Save json representation of instances and 247 | symbols-to-integers dictionaries. 248 | ''' 249 | def writeFiles(self, instances, instancesFilename): 250 | with open(instancesFilename, "w") as outFile: 251 | json.dump(instances, outFile) 252 | 253 | with open(config.questionDictFile(), "wb") as outFile: 254 | pickle.dump(self.questionDict, outFile) 255 | 256 | with open(config.answerDictFile(), "wb") as outFile: 257 | pickle.dump(self.answerDict, outFile) 258 | 259 | with open(config.qaDictFile(), "wb") as outFile: 260 | pickle.dump(self.qaDict, outFile) 261 | 262 | # Write prediction json to file and optionally a one-answer-per-line output file 263 | def writePreds(self, res, tier, suffix = ""): 264 | if res is None: 265 | return 266 | preds = res["preds"] 267 | sortedPreds = sorted(preds, key = lambda instance: instance["index"]) 268 | with open(config.predsFile(tier + suffix), "w") as outFile: 269 | outFile.write(json.dumps(sortedPreds)) 270 | with open(config.answersFile(tier + suffix), "w") as outFile: 271 | for instance in sortedPreds: 272 | writeline(outFile, instance["prediction"]) 273 | 274 | # Reads NLVR data entries and create a json dictionary. 
275 | def readNLVR(self, datasetFilename, instancesFilename, train): 276 | instances = [] 277 | i = 0 278 | 279 | if os.path.exists(instancesFilename): 280 | instances = self.readFiles(instancesFilename) 281 | else: 282 | with open(datasetFilename, "r") as datasetFile: 283 | for line in datasetFile: 284 | instance = json.loads(line) 285 | question = instance["sentence"] 286 | questionSeq = self.tokenize(question, 287 | ignoredPuncts = Preprocesser.allPunct, keptPuncts = []) 288 | 289 | if train or (not config.wrdEmbUnknown): 290 | self.questionDict.addSeq(question) 291 | self.qaDict.addSeq(question) 292 | 293 | answer = instance["label"] 294 | self.answerDict.addSeq([answer]) 295 | self.qaDict.addSeq([answer]) 296 | 297 | for k in range(6): 298 | instances.append({ 299 | "question": question, 300 | "questionSeq": questionSeq, 301 | "answer": answer, 302 | "imageId": instance["identifier"] + "-" + str(k), 303 | "index": i 304 | }) 305 | i += 1 306 | 307 | random.shuffle(instances) 308 | 309 | self.questionDict.createVocab() 310 | self.answerDict.createVocab() 311 | self.qaDict.createVocab() 312 | 313 | self.writeFiles(instances, instancesFilename) 314 | 315 | return instances 316 | 317 | # Reads CLEVR data entries and create a json dictionary. 318 | def readCLEVR(self, datasetFilename, instancesFilename, train): 319 | instances = [] 320 | 321 | if os.path.exists(instancesFilename): 322 | instances = self.readFiles(instancesFilename) 323 | else: 324 | with open(datasetFilename, "r") as datasetFile: 325 | data = json.load(datasetFile)["questions"] 326 | for i in tqdm(range(len(data)), desc = "Preprocessing"): 327 | instance = data[i] 328 | 329 | question = instance["question"] 330 | questionSeq = self.tokenize(question) 331 | 332 | if train or (not config.wrdEmbUnknown): 333 | self.questionDict.addSeq(questionSeq) 334 | self.qaDict.addSeq(questionSeq) 335 | 336 | answer = instance.get("answer", "yes") # DUMMY_ANSWER 337 | self.answerDict.addSeq([answer]) 338 | self.qaDict.addSeq([answer]) 339 | 340 | dummyProgram = [{"function": "FUNC", "value_inputs": [], "inputs": []}] 341 | program = instance.get("program", dummyProgram) 342 | postfixProgram = self.programTranslator.programToPostfixProgram(program) 343 | programSeq = self.programTranslator.programToSeq(postfixProgram) 344 | programInputs = self.programTranslator.programToInputs(postfixProgram, 345 | offset = 2) 346 | 347 | # pass other fields to instance? 348 | instances.append({ 349 | "question": question, 350 | "questionSeq": questionSeq, 351 | "answer": answer, 352 | "imageId": instance["image_index"], 353 | "program": program, 354 | "programSeq": programSeq, 355 | "programInputs": programInputs, 356 | "index": i 357 | }) 358 | 359 | random.shuffle(instances) 360 | 361 | self.questionDict.createVocab() 362 | self.answerDict.createVocab() 363 | self.qaDict.createVocab() 364 | 365 | self.writeFiles(instances, instancesFilename) 366 | 367 | return instances 368 | 369 | ''' 370 | Reads data in datasetFilename, and creates json dictionary. 371 | If instancesFilename exists, restore dictionary from this file. 372 | Otherwise, save created dictionary to instancesFilename. 
373 | ''' 374 | def readData(self, datasetFilename, instancesFilename, train): 375 | # data extraction 376 | datasetReader = { 377 | "CLEVR": self.readCLEVR, 378 | "NLVR": self.readNLVR 379 | } 380 | 381 | return datasetReader[config.dataset](datasetFilename, instancesFilename, train) 382 | 383 | # Reads dataset tier (train, val, test) and returns the loaded instances 384 | # and image relevant filenames 385 | def readTier(self, tier, train): 386 | imagesFilename = config.imagesFile(tier) 387 | datasetFilename = config.datasetFile(tier) 388 | instancesFilename = config.instancesFile(tier) 389 | 390 | instances = self.readData(datasetFilename, instancesFilename, train) 391 | 392 | images = {"imagesFilename": imagesFilename} 393 | if config.dataset == "NLVR": 394 | images["imageIdsFilename"] = config.imagesIdsFile(tier) 395 | 396 | return {"instances": instances, "images": images, "train": train} 397 | 398 | ''' 399 | Reads all tiers of a dataset (train if exists, val, test). 400 | Creates also evalTrain tier which will optionally be used for evaluation. 401 | ''' 402 | def readDataset(self, suffix = "", hasTrain = True): 403 | dataset = {"train": None, "evalTrain": None, "val": None, "test": None} 404 | if hasTrain: 405 | dataset["train"] = self.readTier("train" + suffix, train = True) 406 | dataset["val"] = self.readTier("val" + suffix, train = False) 407 | dataset["test"] = self.readTier("test" + suffix, train = False) 408 | 409 | if hasTrain: 410 | dataset["evalTrain"] = {} 411 | for k in dataset["train"]: 412 | dataset["evalTrain"][k] = dataset["train"][k] 413 | dataset["evalTrain"]["train"] = False 414 | 415 | return dataset 416 | 417 | # Transform symbols to corresponding integers and vectorize into numpy array 418 | def vectorizeData(self, data): 419 | # if "SHARED" tie symbol representations in questions and answers 420 | if config.ansEmbMod == "SHARED": 421 | qDict = self.qaDict 422 | else: 423 | qDict = self.questionDict 424 | 425 | encodedQuestions = [qDict.encodeSequence(d["questionSeq"]) for d in data] 426 | questions, questionsL = vectorize2DList(encodedQuestions) 427 | 428 | answers = np.array([self.answerDict.encodeSym(d["answer"]) for d in data]) 429 | 430 | # pass the whole instances? 
if heavy then not good 431 | imageIds = [d["imageId"] for d in data] 432 | indices = [d["index"] for d in data] 433 | instances = data 434 | 435 | return { "questions": questions, 436 | "questionLengths": questionsL, 437 | "answers": answers, 438 | "imageIds": imageIds, 439 | "indices": indices, 440 | "instances": instances 441 | } 442 | 443 | # Separates data based on a field length 444 | def lseparator(self, key, lims): 445 | maxI = len(lims) 446 | def separatorFn(x): 447 | v = x[key] 448 | for i, lim in enumerate(lims): 449 | if len(v) < lim: 450 | return i 451 | return maxI 452 | return {"separate": separatorFn, "groupsNum": maxI + 1} 453 | 454 | # # separate data based on a field type 455 | # def tseparator(self, key, types): 456 | # typesNum = len(types) + 1 457 | # def separatorFn(x): 458 | # v = str(x[key][-1]) 459 | # return types.get(v, len(types)) 460 | # return {"separate": separatorFn, "groupsNum": typesNum} 461 | 462 | # # separate data based on field arity 463 | # def bseparator(self, key): 464 | # def separatorFn(x): 465 | # cond = (len(x[key][-1]) == 2) 466 | # return (1 if cond else 0) 467 | # return {"separate": separatorFn, "groupsNum": 2} 468 | 469 | # Buckets data to groups using a separator 470 | def bucket(self, instances, separator): 471 | buckets = [[] for i in range(separator["groupsNum"])] 472 | for instance in instances: 473 | bucketI = separator["separate"](instance) 474 | buckets[bucketI].append(instance) 475 | return [bucket for bucket in buckets if len(bucket) > 0] 476 | 477 | # Re-buckets bucket list given a seperator 478 | def rebucket(self, buckets, separator): 479 | res = [] 480 | for bucket in buckets: 481 | res += self.bucket(bucket, separator) 482 | return res 483 | 484 | # Buckets data based on question / program length 485 | def bucketData(self, data, noBucket = False): 486 | if noBucket: 487 | buckets = [data] 488 | else: 489 | if config.noBucket: 490 | buckets = [data] 491 | elif config.noRebucket: 492 | questionSep = self.lseparator("questionSeq", config.questionLims) 493 | buckets = self.bucket(data, questionSep) 494 | else: 495 | programSep = self.lseparator("programSeq", config.programLims) 496 | questionSep = self.lseparator("questionSeq", config.questionLims) 497 | buckets = self.bucket(data, programSep) 498 | buckets = self.rebucket(buckets, questionSep) 499 | return buckets 500 | 501 | ''' 502 | Prepares data: 503 | 1. Filters data according to above arguments. 504 | 2. Takes only a subset of the data based on config.trainedNum / config.testedNum 505 | 3. Buckets data according to question / program length 506 | 4. 
Vectorizes data into numpy arrays 507 | ''' 508 | def prepareData(self, data, train, filterKey = None, noBucket = False): 509 | filterDefault = {"maxQLength": 0, "maxPLength": 0, "onlyChain": False, "filterOp": 0} 510 | 511 | filterTrain = {"maxQLength": config.tMaxQ, "maxPLength": config.tMaxP, 512 | "onlyChain": config.tOnlyChain, "filterOp": config.tFilterOp} 513 | 514 | filterVal = {"maxQLength": config.vMaxQ, "maxPLength": config.vMaxP, 515 | "onlyChain": config.vOnlyChain, "filterOp": config.vFilterOp} 516 | 517 | filters = {"train": filterTrain, "evalTrain": filterTrain, 518 | "val": filterVal, "test": filterDefault} 519 | 520 | if filterKey is None: 521 | fltr = filterDefault 522 | else: 523 | fltr = filters[filterKey] 524 | 525 | # split data when finetuning on validation set 526 | if config.trainExtra and config.extraVal and (config.finetuneNum > 0): 527 | if train: 528 | data = data[:config.finetuneNum] 529 | else: 530 | data = data[config.finetuneNum:] 531 | 532 | typeFilter = config.typeFilters[fltr["filterOp"]] 533 | # filter specific settings 534 | if fltr["onlyChain"]: 535 | data = [d for d in data if all((len(inputNum) < 2) for inputNum in d["programInputs"])] 536 | if fltr["maxQLength"] > 0: 537 | data = [d for d in data if len(d["questionSeq"]) <= fltr["maxQLength"]] 538 | if fltr["maxPLength"] > 0: 539 | data = [d for d in data if len(d["programSeq"]) <= fltr["maxPLength"]] 540 | if len(typeFilter) > 0: 541 | data = [d for d in data if d["programSeq"][-1] not in typeFilter] 542 | 543 | # run on subset of the data. If 0 then use all data 544 | num = config.trainedNum if train else config.testedNum 545 | # retainVal = True to retain same sample of validation across runs 546 | if (not train) and (not config.retainVal): 547 | random.shuffle(data) 548 | if num > 0: 549 | data = data[:num] 550 | # set number to match dataset size 551 | if train: 552 | config.trainedNum = len(data) 553 | else: 554 | config.testedNum = len(data) 555 | 556 | # bucket 557 | buckets = self.bucketData(data, noBucket = noBucket) 558 | 559 | # vectorize 560 | return [self.vectorizeData(bucket) for bucket in buckets] 561 | 562 | # Prepares all the tiers of a dataset. See prepareData method for further details. 563 | def prepareDataset(self, dataset, noBucket = False): 564 | if dataset is None: 565 | return None 566 | 567 | for tier in dataset: 568 | if dataset[tier] is not None: 569 | dataset[tier]["data"] = self.prepareData(dataset[tier]["instances"], 570 | train = dataset[tier]["train"], filterKey = tier, noBucket = noBucket) 571 | 572 | for tier in dataset: 573 | if dataset[tier] is not None: 574 | del dataset[tier]["instances"] 575 | 576 | return dataset 577 | 578 | # Initializes word embeddings to random uniform / random normal / GloVe. 
579 | def initializeWordEmbeddings(self, wordsDict = None, noPadding = False): 580 | # default dictionary to use for embeddings 581 | if wordsDict is None: 582 | wordsDict = self.questionDict 583 | 584 | # uniform initialization 585 | if config.wrdEmbUniform: 586 | lowInit = -1.0 * config.wrdEmbScale 587 | highInit = 1.0 * config.wrdEmbScale 588 | embeddings = np.random.uniform(low = lowInit, high = highInit, 589 | size = (wordsDict.getNumSymbols(), config.wrdEmbDim)) 590 | # normal initialization 591 | else: 592 | embeddings = config.wrdEmbScale * np.random.randn(wordsDict.getNumSymbols(), 593 | config.wrdEmbDim) 594 | 595 | # if wrdEmbRandom = False, use GloVE 596 | counter = 0 597 | if (not config.wrdEmbRandom): 598 | with open(config.wordVectorsFile, 'r') as inFile: 599 | for line in inFile: 600 | line = line.strip().split() 601 | word = line[0].lower() 602 | vector = [float(x) for x in line[1:]] 603 | index = wordsDict.sym2id.get(word) 604 | if index is not None: 605 | embeddings[index] = vector 606 | counter += 1 607 | 608 | print(counter) 609 | print(self.questionDict.sym2id) 610 | print(len(self.questionDict.sym2id)) 611 | print(self.answerDict.sym2id) 612 | print(len(self.answerDict.sym2id)) 613 | print(self.qaDict.sym2id) 614 | print(len(self.qaDict.sym2id)) 615 | 616 | if noPadding: 617 | return embeddings # no embedding for padding symbol 618 | else: 619 | return embeddings[1:] 620 | 621 | ''' 622 | Initializes words embeddings for question words and optionally for answer words 623 | (when config.ansEmbMod == "BOTH"). If config.ansEmbMod == "SHARED", tie embeddings for 624 | question and answer same symbols. 625 | ''' 626 | def initializeQAEmbeddings(self): 627 | # use same embeddings for questions and answers 628 | if config.ansEmbMod == "SHARED": 629 | qaEmbeddings = self.initializeWordEmbeddings(self.qaDict) 630 | ansMap = np.array([self.qaDict.sym2id[sym] for sym in self.answerDict.id2sym]) 631 | embeddings = {"qa": qaEmbeddings, "ansMap": ansMap} 632 | # use different embeddings for questions and answers 633 | else: 634 | qEmbeddings = self.initializeWordEmbeddings(self.questionDict) 635 | aEmbeddings = None 636 | if config.ansEmbMod == "BOTH": 637 | aEmbeddings = self.initializeWordEmbeddings(self.answerDict, noPadding = True) 638 | embeddings = {"q": qEmbeddings, "a": aEmbeddings} 639 | return embeddings 640 | 641 | ''' 642 | Preprocesses a given dataset into numpy arrays: 643 | 1. Reads the input data files into dictionary. 644 | 2. Saves the results jsons in files and loads them instead of parsing input if files exist/ 645 | 3. Initializes word embeddings to random / GloVe. 646 | 4. Optionally filters data according to given filters. 647 | 5. Encodes and vectorize the data into numpy arrays. 648 | 5. Buckets the data according to the instances length. 
649 | ''' 650 | def preprocessData(self, debug = False): 651 | # Read data into json and symbols' dictionaries 652 | print(bold("Loading data...")) 653 | start = time.time() 654 | mainDataset = self.readDataset(hasTrain = True) 655 | 656 | extraDataset = None 657 | if config.extra: 658 | # compositionalClevr doesn't have training dataset 659 | extraDataset = self.readDataset(suffix = "H", hasTrain = (not config.extraVal)) 660 | # extra dataset uses the same images 661 | if not config.extraVal: 662 | for tier in extraDataset: 663 | extraDataset[tier]["images"] = mainDataset[tier]["images"] 664 | 665 | print("took {:.2f} seconds".format(time.time() - start)) 666 | 667 | # Initialize word embeddings (random / glove) 668 | print(bold("Loading word vectors...")) 669 | start = time.time() 670 | embeddings = self.initializeQAEmbeddings() 671 | print("took {:.2f} seconds".format(time.time() - start)) 672 | 673 | # Prepare data: filter, bucket, and vectorize into numpy arrays 674 | print(bold("Vectorizing data...")) 675 | start = time.time() 676 | 677 | mainDataset = self.prepareDataset(mainDataset) 678 | # don't bucket for alternated data and also for humans data (small dataset) 679 | extraDataset = self.prepareDataset(extraDataset, 680 | noBucket = (not config.extraVal) or (not config.alterExtra)) 681 | 682 | data = {"main": mainDataset, "extra": extraDataset} 683 | print("took {:.2f} seconds".format(time.time() - start)) 684 | 685 | config.questionWordsNum = self.questionDict.getNumSymbols() 686 | config.answerWordsNum = self.answerDict.getNumSymbols() 687 | 688 | return data, embeddings, self.answerDict 689 | -------------------------------------------------------------------------------- /program_translator.py: -------------------------------------------------------------------------------- 1 | 2 | class ProgramTranslator(object): 3 | def __init__(self, programDict, maxArity): 4 | self.programDict = programDict 5 | self.maxArity = maxArity 6 | 7 | self.maxStack = 0 8 | 9 | def functionToKey(self, function, withValInputs = True): 10 | valInputs = "" 11 | if withValInputs: 12 | valInputs = "_" + ",".join(function["value_inputs"]) 13 | functionKey = function["function"] if "_" in function["function"] else \ 14 | "_".join([function["function"], function["function"]]) 15 | return str(len(function["inputs"])) + "_" + functionKey + valInputs 16 | 17 | def keyToFunction(self, key): 18 | assert key not in self.programDict.invalidSymbols 19 | function = {} 20 | parts = key.split("_") 21 | arity = int(parts[0]) 22 | function["function"] = "_".join([parts[1], parts[2]]) 23 | function["value_inputs"] = [] 24 | if len(parts) == 4: 25 | function["value_inputs"] = parts[3].split(",") 26 | function["inputs"] = [] 27 | return function, arity 28 | 29 | def keyToArity(self, key): 30 | if key in self.programDict.invalidSymbols: 31 | return 0 32 | return int(key.split("_")[0]) 33 | 34 | def keyToType(self, key): 35 | if key in self.programDict.invalidSymbols: 36 | return ["0", "0", "0"] 37 | return ["0:" + key.split("_")[0], "1:" + key.split("_")[1], "2:" + key.split("_")[2]] 38 | 39 | def programToPostfixProgram(self, program): 40 | newProgram = [] 41 | 42 | def programToPostfixAux(currIndex = -1): 43 | childrenIndices = program[currIndex]["inputs"] 44 | #[int(child) for child in program[currIndex]["inputs"]] 45 | childrenNewIndices = [] 46 | for child in childrenIndices: 47 | programToPostfixAux(child) 48 | childrenNewIndices.append(len(newProgram) - 1) 49 | program[currIndex]["inputs"] = childrenNewIndices 50 | 
newProgram.append(program[currIndex]) 51 | 52 | programToPostfixAux() 53 | return newProgram 54 | 55 | def programToSeq(self, program): 56 | return [self.functionToKey(function) for function in program] 57 | 58 | def programToInputs(self, program, offset = 0): 59 | inputs = [function["inputs"] for function in program] 60 | offsetedInputs = [[FuncInput + offset for FuncInput in FuncInputs] for FuncInputs in inputs] 61 | return offsetedInputs 62 | 63 | # def seqToProgram(self, seq, enforceValidPrograms = True): 64 | # program = [] 65 | 66 | # def seqToProgramAux(currIndex = len(seq) - 1): 67 | # if currIndex < 0: 68 | # program = None 69 | # return 70 | # currFunc, arity = self.keyToFunction(seq[currIndex]) 71 | # nextIndex = currIndex - 1 72 | # program.append(currFunc) 73 | # for _ in arity: 74 | # currFunc["inputs"].append(nextIndex) 75 | # nextIndex = seqToProgramAux(nextIndex) 76 | # currFunc["inputs"].reverse() 77 | # return nextIndex 78 | 79 | # if enforceValidPrograms: 80 | # seqToProgramAux() 81 | # if program is not None: 82 | # program.reverse() 83 | # else: 84 | # stack = [0] * self.maxArity 85 | # for i in range(len(seq)): 86 | # func, arity = self.keyToFunction(seq[i]) 87 | # func["inputs"] = stack[len(stack) - arity:] 88 | # newLength = max(len(stack) - arity, self.maxArity) 89 | # stack = stack[:newLength] + [i + self.maxArity] 90 | # self.maxStack = max(len(stack), self.maxStack) 91 | # program.append(func) 92 | 93 | # return program 94 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Compositional Attention Networks for Real-World Reasoning 2 |

3 | Drew A. Hudson & Christopher D. Manning 4 |

5 | 6 | ***Please note: We have updated the [GQA challenge](https://visualreasoning.net/challenge.html) deadline to be May 15. Best of Luck! :)*** 7 | 8 | This is the implementation of [Compositional Attention Networks for Machine Reasoning](https://arxiv.org/pdf/1803.03067.pdf) (ICLR 2018) on two visual reasoning datasets: [CLEVR dataset](http://cs.stanford.edu/people/jcjohns/clevr/) and the ***New*** [***GQA dataset***](https://visualreasoning.net) ([CVPR 2019](https://visualreasoning.net/gqaPaper.pdf)). We propose a fully differentiable model that learns to perform multi-step reasoning. 9 | See our [website](https://cs.stanford.edu/people/dorarad/mac/) and [blogpost](https://cs.stanford.edu/people/dorarad/mac/blog.html) for more information about the model! 10 | 11 | In particular, the implementation includes the MAC cell at [`mac_cell.py`](mac_cell.py). The code supports the standard cell as presented in the paper as well as additional extensions and variants. Run `python main.py -h` or see [`config.py`](config.py) for the complete list of options. 12 | 13 | The adaptation of MAC as well as several baselines for the GQA dataset are located at the **GQA** branch. 14 | 15 |
16 | 17 | 18 | 19 |
20 | 21 | ## Bibtex 22 | For MAC: 23 | ```bibtex 24 | @inproceedings{hudson2018compositional, 25 | title={Compositional Attention Networks for Machine Reasoning}, 26 | author={Hudson, Drew A and Manning, Christopher D}, 27 | booktitle={International Conference on Learning Representations (ICLR)}, 28 | year={2018} 29 | } 30 | ``` 31 | 32 | For the GQA dataset: 33 | ```bibtex 34 | @inproceedings{hudson2018gqa, 35 | title={GQA: A New Dataset for Real-World Visual Reasoning and Compositional Question Answering}, 36 | author={Hudson, Drew A and Manning, Christopher D}, 37 | booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)}, 38 | year={2019} 39 | } 40 | ``` 41 | 42 | ## Requirements 43 | - TensorFlow (originally developed with version 1.3, but it should work for later versions as well). 44 | - We have performed experiments on a Maxwell Titan X GPU and assume 12GB of GPU memory. 45 | - See [`requirements.txt`](requirements.txt) for the required Python packages and run `pip install -r requirements.txt` to install them. 46 | 47 | ## Pre-processing 48 | Before training the model, we first have to download the CLEVR dataset and extract features for the images: 49 | 50 | ### Dataset 51 | To download and unpack the data, run the following commands: 52 | ```bash 53 | wget https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip 54 | unzip CLEVR_v1.0.zip 55 | mv CLEVR_v1.0 CLEVR_v1 56 | mkdir CLEVR_v1/data 57 | mv CLEVR_v1/questions/* CLEVR_v1/data/ 58 | ``` 59 | The final command moves the dataset questions into the `data` directory, where we will put all the data files we use during training. 60 | 61 | ### Feature extraction 62 | Extract ResNet-101 features for the CLEVR train, val, and test images with the following commands: 63 | 64 | ```bash 65 | python extract_features.py --input_image_dir CLEVR_v1/images/train --output_h5_file CLEVR_v1/data/train.h5 --batch_size 32 66 | python extract_features.py --input_image_dir CLEVR_v1/images/val --output_h5_file CLEVR_v1/data/val.h5 --batch_size 32 67 | python extract_features.py --input_image_dir CLEVR_v1/images/test --output_h5_file CLEVR_v1/data/test.h5 --batch_size 32 68 | ``` 69 | 70 | ## Training 71 | To train the model, run the following command: 72 | ```bash 73 | python main.py --expName "clevrExperiment" --train --testedNum 10000 --epochs 25 --netLength 4 @configs/args.txt 74 | ``` 75 | 76 | First, the program preprocesses the CLEVR questions: it tokenizes them and maps them to integers to prepare them for the network. It then stores a JSON file with that information, along with word-to-integer dictionaries, in the `./CLEVR_v1/data` directory. 77 | 78 | Then, the program trains the model. Weights are saved by default to `./weights/{expName}` and statistics about the training are collected in `./results/{expName}`, where `expName` is the name we choose to give to the current experiment. 79 | 80 | ### Notes 81 | - The number of examples used for training and evaluation can be set with `--trainedNum` and `--testedNum`, respectively. 82 | - You can use the `-r` flag to restore and continue training a previously pre-trained model. 83 | - We recommend trying out different numbers of MAC cells in the network through the `--netLength` option to explore different lengths of the reasoning process. 84 | - Good lengths for CLEVR are in the range of 4-16 (using more cells tends to converge faster and achieve a bit higher accuracy, while a lower number of cells usually results in more easily interpretable attention maps).
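As a side note on the question preprocessing described under Training: it essentially builds a word-to-integer dictionary and encodes each question as a list of ids. The snippet below is a minimal, hypothetical sketch of that idea; the function names are illustrative, and the actual logic (including answer dictionaries, filtering, and bucketing) lives in `preprocess.py`.

```python
def buildVocab(questions, padSymbol="<PAD>", unkSymbol="<UNK>"):
    # Assign an integer id to every word seen in the questions;
    # reserve fixed ids for padding and unknown words.
    wordToIndex = {padSymbol: 0, unkSymbol: 1}
    for question in questions:
        for word in question.lower().replace("?", " ?").split():
            wordToIndex.setdefault(word, len(wordToIndex))
    return wordToIndex

def encodeQuestion(question, wordToIndex):
    # Map a question string to the list of integer ids fed to the network.
    return [wordToIndex.get(word, wordToIndex["<UNK>"])
            for word in question.lower().replace("?", " ?").split()]

# Example:
# vocab = buildVocab(["How many red cubes are there?"])
# encodeQuestion("How many red cubes are there?", vocab)  # -> [2, 3, 4, 5, 6, 7, 8]
```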
85 | 86 | ### Model variants 87 | We have explored several variants of our model. We provide a few examples in `configs/args2-4.txt`. For instance, you can run the first one with: 88 | ```bash 89 | python main.py --expName "experiment1" --train --testedNum 10000 --epochs 40 --netLength 6 @configs/args2.txt 90 | ``` 91 | - [`args2`](configs/args2.txt) uses a non-recurrent variant of the control unit that converges faster. 92 | - [`args3`](configs/args3.txt) incorporates self-attention into the write unit. 93 | - [`args4`](configs/args4.txt) adds control-based gating over the memory. 94 | 95 | See [`config.py`](config.py) for further available options (note that some of them are still at an experimental stage). 96 | 97 | ## Evaluation 98 | To evaluate the trained model and get predictions and attention maps, run the following: 99 | ```bash 100 | python main.py --expName "clevrExperiment" --finalTest --testedNum 10000 --netLength 16 -r --getPreds --getAtt @configs/args.txt 101 | ``` 102 | The command will restore the model we have trained and evaluate it on the validation set. JSON files with the predictions and the attention distributions produced by running the model are saved by default to `./preds/{expName}`. 103 | 104 | - If you are interested in getting attention maps (`--getAtt`), we advise you to limit the number of evaluated examples to 5,000-20,000 to avoid very large prediction files. 105 | 106 | ## Visualization 107 | After we evaluate the model with the command above, we can visualize the attention maps the model generated by running: 108 | ```bash 109 | python visualization.py --expName "clevrExperiment" --tier val 110 | ``` 111 | (The tier can be set to `train` or `test` as well.) The script supports filtering the visualized questions in various ways. See [`visualization.py`](visualization.py) for further details. 112 | 113 | To get more interpretable visualizations, it is highly recommended to reduce the number of cells to 4-8 (`--netLength`). Using more cells allows the network to learn more effective ways to approach the task, but these tend to be less interpretable than those of shorter networks (with fewer cells). 114 | 115 | Optionally, to make the image attention maps look a little nicer, you can do the following (using [ImageMagick](https://www.imagemagick.org)): 116 | ``` 117 | for x in preds/clevrExperiment/*Img*.png; do magick convert $x -brightness-contrast 20x35 $x; done; 118 | ``` 119 | 120 | Thank you for your interest in our model! Please contact me at dorarad@cs.stanford.edu for any questions, comments, or suggestions!
:-) 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | Pillow 4 | scipy 5 | torchvision 6 | h5py 7 | tensorflow 8 | tqdm 9 | termcolor 10 | matplotlib 11 | seaborn 12 | pandas 13 | -------------------------------------------------------------------------------- /visualization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas 3 | import argparse 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns; sns.set() 7 | from scipy.misc import imread, imresize 8 | from matplotlib.colors import Normalize, LinearSegmentedColormap 9 | 10 | flatten = lambda ll: [e for l in ll for e in l] 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | # experiment settings 15 | parser.add_argument("--tier", default = "val", choices = ["train", "val", "test"], type = str) 16 | parser.add_argument("--expName", default = "experiment", type = str) 17 | 18 | # plotting 19 | parser.add_argument("--cmap", default = "custom", type = str) # "gnuplot2", "GreysT" 20 | 21 | parser.add_argument("--trans", help = "transpose question attention", action = "store_true") 22 | parser.add_argument("--sa", action = "store_true") 23 | parser.add_argument("--gate", action = "store_true") 24 | 25 | # filtering 26 | parser.add_argument("--instances", nargs = "*", type = int) 27 | parser.add_argument("--maxNum", default = 0, type = int) 28 | 29 | parser.add_argument("--filter", default = [], nargs = "*", choices = ["mod", "length", "field"]) 30 | parser.add_argument("--filterMod", action = "store_true") 31 | parser.add_argument("--filterLength", type = int) # 19 32 | parser.add_argument("--filterField", type = str) 33 | parser.add_argument("--filterIn", action = "store_true") 34 | parser.add_argument("--filterList", nargs = "*") # ["how many", "more"], numbers 35 | 36 | args = parser.parse_args() 37 | 38 | isRight = lambda instance: instance["answer"] == instance["prediction"] 39 | isRightStr = lambda instance: "RIGHT" if isRight(instance) else "WRONG" 40 | 41 | # files 42 | # jsonFilename = "valHPredictions.json" if args.humans else "valPredictions.json" 43 | imagesDir = "./CLEVR_v1/images/{tier}".format( 44 | tier = args.tier) 45 | 46 | dataFile = "./preds/{expName}/{tier}Predictions-{expName}.json".format( 47 | tier = args.tier, 48 | expName = args.expName) 49 | 50 | inImgName = lambda index: "{dir}/CLEVR_{tier}_{index}.png".format( 51 | dir = imagesDir, 52 | index = ("000000%d" % index)[-6:], 53 | tier = args.tier) 54 | 55 | outImgAttName = lambda instance, j: "./preds/{expName}/{tier}{id}Img_{step}.png".format( 56 | expName = args.expName, 57 | tier = args.tier, 58 | id = instance["index"], 59 | step = j + 1) 60 | 61 | outTableAttName = lambda instance, name: "./preds/{expName}/{tier}{id}{tableName}_{right}{orientation}.png".format( 62 | expName = args.expName, 63 | tier = args.tier, 64 | id = instance["index"], 65 | tableName = name, 66 | right = isRightStr(instance), 67 | orientation = "_t" if args.trans else "") 68 | 69 | # plotting 70 | imageDims = (14,14) 71 | figureImageDims = (2,3) 72 | figureTableDims = (5,4) 73 | fontScale = 1 74 | 75 | # set transparent mask for low attention areas 76 | # cdict = plt.get_cmap("gnuplot2")._segmentdata 77 | cdict = {"red": ((0.0, 0.0, 0.0), (0.6, 0.8, 0.8), (1.0, 1, 1)), 78 | "green": ((0.0, 0.0, 0.0), (0.6, 0.8, 0.8), (1.0, 1, 1)), 79 
| "blue": ((0.0, 0.0, 0.0), (0.6, 0.8, 0.8), (1.0, 1, 1))} 80 | cdict["alpha"] = ((0.0, 0.35, 0.35), 81 | (1.0,0.65, 0.65)) 82 | plt.register_cmap(name = "custom", data = cdict) 83 | 84 | def savePlot(fig, fileName): 85 | plt.savefig(fileName, dpi = 720) 86 | plt.close(fig) 87 | del fig 88 | 89 | def filter(instance): 90 | if "length" in args.filter: 91 | if len(instance["question"].split(" ")) > args.filterLength: 92 | return True 93 | 94 | if "field" in args.filter: 95 | if args.filterIn: 96 | if not (instance[args.filterField] in args.filterList): 97 | return True 98 | else: 99 | if not any((l in instance[args.filterField]) for l in args.filterList): 100 | return True 101 | 102 | if "mod" in args.filter: 103 | if (not isRight(instance)) and args.filterMod: 104 | return True 105 | 106 | if isRight(instance) and (not args.filterMod): 107 | return True 108 | 109 | return False 110 | 111 | def showImgAtt(img, instance, step, ax): 112 | dx, dy = 0.05, 0.05 113 | x = np.arange(-1.5, 1.5, dx) 114 | y = np.arange(-1.0, 1.0, dy) 115 | X, Y = np.meshgrid(x, y) 116 | extent = np.min(x), np.max(x), np.min(y), np.max(y) 117 | 118 | ax.cla() 119 | 120 | img1 = ax.imshow(img, interpolation = "nearest", extent = extent) 121 | ax.imshow(np.array(instance["attentions"]["kb"][step]).reshape(imageDims), cmap = plt.get_cmap(args.cmap), 122 | interpolation = "bicubic", extent = extent) 123 | 124 | ax.set_axis_off() 125 | plt.axis("off") 126 | 127 | ax.set_aspect("auto") 128 | 129 | 130 | def showImgAtts(instance): 131 | img = imread(inImgName(instance["imageId"])) 132 | 133 | length = len(instance["attentions"]["kb"]) 134 | 135 | # show images 136 | for j in range(length): 137 | fig, ax = plt.subplots() 138 | fig.set_figheight(figureImageDims[0]) 139 | fig.set_figwidth(figureImageDims[1]) 140 | 141 | showImgAtt(img, instance, j, ax) 142 | 143 | plt.subplots_adjust(bottom = 0, top = 1, left = 0, right = 1) 144 | savePlot(fig, outImgAttName(instance, j)) 145 | 146 | def showTableAtt(instance, table, x, y, name): 147 | # if args.trans: 148 | # figureTableDims = (len(y) / 2 + 4, len(x) + 2) 149 | # else: 150 | # figureTableDims = (len(y) / 2, len(x) / 2) 151 | # xx = np.arange(0, len(x), 1) 152 | # yy = np.arange(0, len(y), 1) 153 | # extent2 = np.min(xx), np.max(xx), np.min(yy), np.max(yy) 154 | 155 | fig2, bx = plt.subplots(1, 1) # figsize = figureTableDims 156 | bx.cla() 157 | 158 | sns.set(font_scale = fontScale) 159 | 160 | if args.trans: 161 | table = np.transpose(table) 162 | x, y = y, x 163 | 164 | tableMap = pandas.DataFrame(data = table, index = x, columns = y) 165 | 166 | bx = sns.heatmap(tableMap, cmap = "Purples", cbar = False, linewidths = .5, linecolor = "gray", square = True) 167 | 168 | # x ticks 169 | if args.trans: 170 | bx.xaxis.tick_top() 171 | locs, labels = plt.xticks() 172 | if args.trans: 173 | plt.setp(labels, rotation = 0) 174 | else: 175 | plt.setp(labels, rotation = 60) 176 | 177 | # y ticks 178 | locs, labels = plt.yticks() 179 | plt.setp(labels, rotation = 0) 180 | 181 | plt.savefig(outTableAttName(instance, name), dpi = 720) 182 | 183 | def main(): 184 | with open(dataFile) as inFile: 185 | results = json.load(inFile) 186 | 187 | # print(args.exp) 188 | 189 | count = 0 190 | if args.instances is None: 191 | args.instances = range(len(results)) 192 | 193 | for i in args.instances: 194 | if filter(results[i]): 195 | continue 196 | 197 | if count > args.maxNum and args.maxNum > 0: 198 | break 199 | count += 1 200 | 201 | length = len(results[i]["attentions"]["kb"]) 202 | 
showImgAtts(results[i]) 203 | 204 | iterations = range(1, length + 1) 205 | questionList = results[i]["question"].split(" ") 206 | table = np.array(results[i]["attentions"]["question"])[:,:(len(questionList) + 1)] 207 | showTableAtt(results[i], table, iterations, questionList, "text") 208 | 209 | if args.sa: 210 | iterations = range(length) 211 | sa = np.zeros((length, length)) 212 | for r in range(length): 213 | for c in range(r + 1): 214 | sa[r][c] = results[i]["attentions"]["self"][r][c] 215 | 216 | showTableAtt(results[i], sa, iterations, iterations, "sa") 217 | 218 | print(i) 219 | print("id:", results[i]["index"]) 220 | print("img:", results[i]["imageId"]) 221 | print("Q:", results[i]["question"]) 222 | print("G:", results[i]["answer"]) 223 | print("P:", results[i]["prediction"]) 224 | print(isRightStr(results[i])) 225 | 226 | if args.gate: 227 | print(results[i]["attentions"]["gate"]) 228 | 229 | print("________________________________________________________________________") 230 | 231 | if __name__ == "__main__": 232 | main() 233 | --------------------------------------------------------------------------------
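As a closing illustration of what `showImgAtt` in visualization.py is doing: each image attention stored in the predictions file is a flat vector over the 14x14 grid of ResNet feature positions (`imageDims`), and the visualization reshapes it and draws it over the image with a semi-transparent colormap. Below is a minimal standalone sketch of that overlay idea, with placeholder file paths and a standard colormap instead of the custom transparent one the script registers; it is not the script's own API.

```python
import numpy as np
import matplotlib.pyplot as plt

def overlayAttention(image, attention, gridDims=(14, 14), alpha=0.5):
    # image: HxWx3 array; attention: flat vector with gridDims[0]*gridDims[1] entries.
    attMap = np.asarray(attention).reshape(gridDims)
    fig, ax = plt.subplots()
    ax.imshow(image, interpolation="nearest")
    # Stretch the coarse 14x14 attention grid over the full image, semi-transparently.
    ax.imshow(attMap, cmap="viridis", alpha=alpha, interpolation="bicubic",
              extent=(0, image.shape[1], image.shape[0], 0))
    ax.set_axis_off()
    return fig

# Usage (paths are placeholders):
# img = plt.imread("CLEVR_v1/images/val/CLEVR_val_000000.png")
# fig = overlayAttention(img, np.random.rand(196))
# fig.savefig("attention_overlay.png", dpi=300)
```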