├── .gitignore ├── README.md ├── birds ├── .gitignore ├── LICENSE.txt ├── README.md ├── custom_filelists │ └── CUB │ │ ├── base.json │ │ ├── novel.json │ │ └── val.json ├── exp │ └── README.md ├── fewshot │ ├── backbone.py │ ├── constants.py │ ├── data │ │ ├── __init__.py │ │ ├── additional_transforms.py │ │ ├── datamgr.py │ │ ├── dataset.py │ │ └── lang_utils.py │ ├── io_utils.py │ ├── models │ │ ├── __init__.py │ │ ├── language.py │ │ └── protonet.py │ ├── run_cl.py │ ├── test.py │ └── train.py ├── filelists │ └── CUB │ │ ├── download_CUB.sh │ │ ├── save_np.py │ │ └── write_CUB_filelist.py ├── run_l3.sh ├── run_lang_ablation.sh ├── run_lang_amount.sh ├── run_lsl.sh └── run_meta.sh └── shapeworld ├── .gitignore ├── LICENSE ├── README.md ├── analysis ├── analysis.Rproj └── metrics.Rmd ├── exp ├── README.md ├── l3 │ ├── args.json │ └── metrics.json ├── lsl │ ├── args.json │ └── metrics.json ├── lsl_color │ ├── args.json │ └── metrics.json ├── lsl_nocolor │ ├── args.json │ └── metrics.json ├── lsl_shuffle_captions │ ├── args.json │ └── metrics.json ├── lsl_shuffle_words │ ├── args.json │ └── metrics.json └── meta │ ├── args.json │ └── metrics.json ├── lsl ├── datasets.py ├── models.py ├── train.py ├── tre.py ├── utils.py └── vision.py ├── run_l3.sh ├── run_lang_ablation.sh ├── run_lsl.sh ├── run_lsl_img.sh └── run_meta.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | .Ruserdata 5 | 6 | sync_results.sh 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shaping Visual Representations with Language for Few-shot Classification 2 | 3 | Code and data for 4 | 5 | > Jesse Mu, Percy Liang, and Noah Goodman. Shaping Visual Representations with Language for Few-shot Classification. ACL 2020. https://arxiv.org/abs/1911.02683 6 | 7 | In addition, a CodaLab executable paper (Docker containers with code, data, and experiment runs) is available [here](https://bit.ly/lsl_acl20). There are some minor fixes for CodaLab compatibility on the `codalab` branch. 8 | 9 | The codebase is split into two repositories, `shapeworld` and `birds`, for the 10 | different tasks explored in this paper. Each has its own README, 11 | instructions, and license, since they were extended from different existing 12 | codebases.
13 | 14 | If you find this code useful, please cite 15 | 16 | ``` 17 | @inproceedings{mu2020shaping, 18 | author = {Jesse Mu and Percy Liang and Noah Goodman}, 19 | title = {Shaping Visual Representations with Language for Few-Shot Classification}, 20 | booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics}, 21 | year = {2020} 22 | } 23 | ``` 24 | -------------------------------------------------------------------------------- /birds/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | 4 | saves/* 5 | !saves/README.md 6 | 7 | filelists/CUB/* 8 | !filelists/CUB/download_cub.sh 9 | !filelists/CUB/*.py 10 | 11 | filelists/scenes/* 12 | !filelists/scenes/download_scenes.sh 13 | !filelists/scenes/*.py 14 | 15 | # Ignore codalab 16 | /checkpoints/ 17 | /features/ 18 | /args.json 19 | /results.json 20 | /reed-birds 21 | .Rproj.user 22 | 23 | .Rhistory 24 | 25 | *.out 26 | 27 | exp/* 28 | exp/*/*/checkpoints/*.tar 29 | exp/*/*/features/*.hdf5 30 | !exp/README.md 31 | 32 | test/* 33 | 34 | *.RData 35 | 36 | analysis/*.html 37 | -------------------------------------------------------------------------------- /birds/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright.
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. 
The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. 
indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. 
The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. 
Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | -------------------------------------------------------------------------------- /birds/README.md: -------------------------------------------------------------------------------- 1 | # LSL - Birds 2 | 3 | This codebase is built on top of [wyharveychen/CloserLookFewShot](https://github.com/wyharveychen/CloserLookFewShot) ([paper](https://openreview.net/pdf?id=HkxLXnAcFQ)) - thanks to them! 4 | 5 | ## Dependencies 6 | 7 | Tested with Python 3.7.4, torch 1.4.0, torchvision 0.4.1, numpy 1.16.2, PIL 8 | 5.4.1, torchfile 0.1.0, sklearn 0.20.3, pandas 0.25.2 9 | 10 | GloVe initialization depends on spaCy 2.2.2 and the spaCy `en_vectors_web_lg` 11 | model: 12 | 13 | ``` 14 | python -m spacy download en_vectors_web_lg 15 | ``` 16 | 17 | ## Data 18 | 19 | To download the data, cd to `filelists/CUB` and run `source download_CUB.sh`. This 20 | downloads the CUB-200-2011 dataset and also runs `python write_CUB_filelist.py`. 21 | 22 | `python write_CUB_filelist.py` saves a filelist (train/val/test) split 23 | to `./custom_filelists/CUB/{base,val,novel}.json`. 24 | 25 | Then run `python save_np.py`, which serializes the images as NumPy arrays 26 | (for speed). 27 | 28 | The language data is available from 29 | [reedscot/cvpr2016](https://github.com/reedscot/cvpr2016) ([GDrive link](https://drive.google.com/open?id=0B0ywwgffWnLLZW9uVHNjb2JmNlE)). Download it and unzip it to a `reed-birds` directory in the main directory (e.g., the path to the vocab file should be `./reed-birds/vocab_c10.t7`). 30 | 31 | ## Running 32 | 33 | To train and evaluate a model, run `fewshot/train.py` and `fewshot/test.py`, 34 | respectively. Alternatively, for CodaLab, the `fewshot/run_cl.py` script does 35 | both training and testing, with slightly friendlier argument names 36 | (see `fewshot/run_cl.py --help` for details). 37 | 38 | The shell scripts contain commands for running the various models: 39 | 40 | - `run_meta.sh`: Non-linguistic protonet baseline 41 | - `run_l3.sh`: Learning with latent language (Andreas et al., 2018) 42 | - `run_lsl.sh`: Ours (LSL) 43 | - `run_lang_ablation.sh`: Language ablation studies 44 | - `run_lang_amount.sh`: Varying the amount of language 45 | 46 | ## References 47 | 48 | (from the original CloserLookFewShot repo) 49 | 50 | Our testbed builds upon several existing publicly available codebases.
Specifically, we have modified and integrated the following code into this project: 51 | 52 | * Framework, Backbone, Method: Matching Network 53 | https://github.com/facebookresearch/low-shot-shrink-hallucinate 54 | * Omniglot dataset, Method: Prototypical Network 55 | https://github.com/jakesnell/prototypical-networks 56 | * Method: Relational Network 57 | https://github.com/floodsung/LearningToCompare_FSL 58 | * Method: MAML 59 | https://github.com/cbfinn/maml 60 | https://github.com/dragen1860/MAML-Pytorch 61 | https://github.com/katerakelly/pytorch-maml 62 | -------------------------------------------------------------------------------- /birds/exp/README.md: -------------------------------------------------------------------------------- 1 | # Exp 2 | 3 | Placeholder for model experiments. 4 | 5 | Use `fewshot/run_cl.py` with a `--log_dir` pointing to a directory within this 6 | folder. 7 | -------------------------------------------------------------------------------- /birds/fewshot/backbone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Backbone vision models. 3 | 4 | This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torchvision.models as models 13 | 14 | CONV4_HIDDEN_SIZES = [112896, 28224, 6400, 1600] 15 | 16 | 17 | class Identity(nn.Module): 18 | def __init__(self): 19 | super(Identity, self).__init__() 20 | 21 | def forward(self, x): 22 | return x 23 | 24 | 25 | # Basic ResNet model 26 | def init_layer(L): 27 | # Initialization using fan-in 28 | if isinstance(L, nn.Conv2d): 29 | n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels 30 | L.weight.data.normal_(0, math.sqrt(2.0 / float(n))) 31 | elif isinstance(L, nn.BatchNorm2d): 32 | L.weight.data.fill_(1) 33 | L.bias.data.fill_(0) 34 | 35 | 36 | class Flatten(nn.Module): 37 | def __init__(self): 38 | super(Flatten, self).__init__() 39 | 40 | def forward(self, x): 41 | return x.view(x.size(0), -1) 42 | 43 | 44 | class Linear_fw(nn.Linear): # used in MAML to forward input with fast weight 45 | def __init__(self, in_features, out_features): 46 | super(Linear_fw, self).__init__(in_features, out_features) 47 | self.weight.fast = None # Lazy hack to add fast weight link 48 | self.bias.fast = None 49 | 50 | def forward(self, x): 51 | if self.weight.fast is not None and self.bias.fast is not None: 52 | out = F.linear(x, self.weight.fast, self.bias.fast) 53 | else: 54 | out = super(Linear_fw, self).forward(x) 55 | return out 56 | 57 | 58 | class Conv2d_fw(nn.Conv2d): # used in MAML to forward input with fast weight 59 | def __init__( 60 | self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True 61 | ): 62 | super(Conv2d_fw, self).__init__( 63 | in_channels, 64 | out_channels, 65 | kernel_size, 66 | stride=stride, 67 | padding=padding, 68 | bias=bias, 69 | ) 70 | self.weight.fast = None 71 | if self.bias is not None: 72 | self.bias.fast = None 73 | 74 | def forward(self, x): 75 | if self.bias is None: 76 | if self.weight.fast is not None: 77 | out = F.conv2d( 78 | x, self.weight.fast, None, stride=self.stride, padding=self.padding 79 | ) 80 | else: 81 | out = super(Conv2d_fw, self).forward(x) 82 | else: 83 | if self.weight.fast is not None and self.bias.fast is not None: 84 | out = F.conv2d( 85 | x, 86 | self.weight.fast, 87 | self.bias.fast, 88 | stride=self.stride, 89 | padding=self.padding, 90 | ) 91 
| else: 92 | out = super(Conv2d_fw, self).forward(x) 93 | 94 | return out 95 | 96 | 97 | class BatchNorm2d_fw(nn.BatchNorm2d): # used in MAML to forward input with fast weight 98 | def __init__(self, num_features): 99 | super(BatchNorm2d_fw, self).__init__(num_features) 100 | self.weight.fast = None 101 | self.bias.fast = None 102 | 103 | def forward(self, x): 104 | running_mean = torch.zeros(x.data.size()[1]).cuda() 105 | running_var = torch.ones(x.data.size()[1]).cuda() 106 | if self.weight.fast is not None and self.bias.fast is not None: 107 | out = F.batch_norm( 108 | x, 109 | running_mean, 110 | running_var, 111 | self.weight.fast, 112 | self.bias.fast, 113 | training=True, 114 | momentum=1, 115 | ) 116 | # batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py 117 | else: 118 | out = F.batch_norm( 119 | x, 120 | running_mean, 121 | running_var, 122 | self.weight, 123 | self.bias, 124 | training=True, 125 | momentum=1, 126 | ) 127 | return out 128 | 129 | 130 | # Simple Conv Block 131 | class ConvBlock(nn.Module): 132 | maml = False # Default 133 | 134 | def __init__(self, indim, outdim, pool=True, padding=1): 135 | super(ConvBlock, self).__init__() 136 | self.indim = indim 137 | self.outdim = outdim 138 | if self.maml: 139 | self.C = Conv2d_fw(indim, outdim, 3, padding=padding) 140 | self.BN = BatchNorm2d_fw(outdim) 141 | else: 142 | self.C = nn.Conv2d(indim, outdim, 3, padding=padding) 143 | self.BN = nn.BatchNorm2d(outdim) 144 | self.relu = nn.ReLU(inplace=True) 145 | 146 | self.parametrized_layers = [self.C, self.BN, self.relu] 147 | if pool: 148 | self.pool = nn.MaxPool2d(2) 149 | self.parametrized_layers.append(self.pool) 150 | 151 | for layer in self.parametrized_layers: 152 | init_layer(layer) 153 | 154 | self.trunk = nn.Sequential(*self.parametrized_layers) 155 | 156 | def forward(self, x): 157 | out = self.trunk(x) 158 | return out 159 | 160 | 161 | # Simple ResNet Block 162 | class SimpleBlock(nn.Module): 163 | maml = False # Default 164 | 165 | def __init__(self, indim, outdim, half_res): 166 | super(SimpleBlock, self).__init__() 167 | self.indim = indim 168 | self.outdim = outdim 169 | if self.maml: 170 | self.C1 = Conv2d_fw( 171 | indim, 172 | outdim, 173 | kernel_size=3, 174 | stride=2 if half_res else 1, 175 | padding=1, 176 | bias=False, 177 | ) 178 | self.BN1 = BatchNorm2d_fw(outdim) 179 | self.C2 = Conv2d_fw(outdim, outdim, kernel_size=3, padding=1, bias=False) 180 | self.BN2 = BatchNorm2d_fw(outdim) 181 | else: 182 | self.C1 = nn.Conv2d( 183 | indim, 184 | outdim, 185 | kernel_size=3, 186 | stride=2 if half_res else 1, 187 | padding=1, 188 | bias=False, 189 | ) 190 | self.BN1 = nn.BatchNorm2d(outdim) 191 | self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False) 192 | self.BN2 = nn.BatchNorm2d(outdim) 193 | self.relu1 = nn.ReLU(inplace=True) 194 | self.relu2 = nn.ReLU(inplace=True) 195 | 196 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2] 197 | 198 | self.half_res = half_res 199 | 200 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 201 | if indim != outdim: 202 | if self.maml: 203 | self.shortcut = Conv2d_fw( 204 | indim, outdim, 1, 2 if half_res else 1, bias=False 205 | ) 206 | self.BNshortcut = BatchNorm2d_fw(outdim) 207 | else: 208 | self.shortcut = nn.Conv2d( 209 | indim, outdim, 1, 2 if half_res else 1, bias=False 210 | ) 211 | self.BNshortcut = nn.BatchNorm2d(outdim) 212 | 213 | self.parametrized_layers.append(self.shortcut) 214 | 
self.parametrized_layers.append(self.BNshortcut) 215 | self.shortcut_type = "1x1" 216 | else: 217 | self.shortcut_type = "identity" 218 | 219 | for layer in self.parametrized_layers: 220 | init_layer(layer) 221 | 222 | def forward(self, x): 223 | out = self.C1(x) 224 | out = self.BN1(out) 225 | out = self.relu1(out) 226 | out = self.C2(out) 227 | out = self.BN2(out) 228 | short_out = ( 229 | x if self.shortcut_type == "identity" else self.BNshortcut(self.shortcut(x)) 230 | ) 231 | out = out + short_out 232 | out = self.relu2(out) 233 | return out 234 | 235 | 236 | # Bottleneck block 237 | class BottleneckBlock(nn.Module): 238 | maml = False # Default 239 | 240 | def __init__(self, indim, outdim, half_res): 241 | super(BottleneckBlock, self).__init__() 242 | bottleneckdim = int(outdim / 4) 243 | self.indim = indim 244 | self.outdim = outdim 245 | if self.maml: 246 | self.C1 = Conv2d_fw(indim, bottleneckdim, kernel_size=1, bias=False) 247 | self.BN1 = BatchNorm2d_fw(bottleneckdim) 248 | self.C2 = Conv2d_fw( 249 | bottleneckdim, 250 | bottleneckdim, 251 | kernel_size=3, 252 | stride=2 if half_res else 1, 253 | padding=1, 254 | ) 255 | self.BN2 = BatchNorm2d_fw(bottleneckdim) 256 | self.C3 = Conv2d_fw(bottleneckdim, outdim, kernel_size=1, bias=False) 257 | self.BN3 = BatchNorm2d_fw(outdim) 258 | else: 259 | self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False) 260 | self.BN1 = nn.BatchNorm2d(bottleneckdim) 261 | self.C2 = nn.Conv2d( 262 | bottleneckdim, 263 | bottleneckdim, 264 | kernel_size=3, 265 | stride=2 if half_res else 1, 266 | padding=1, 267 | ) 268 | self.BN2 = nn.BatchNorm2d(bottleneckdim) 269 | self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False) 270 | self.BN3 = nn.BatchNorm2d(outdim) 271 | 272 | self.relu = nn.ReLU() 273 | self.parametrized_layers = [ 274 | self.C1, 275 | self.BN1, 276 | self.C2, 277 | self.BN2, 278 | self.C3, 279 | self.BN3, 280 | ] 281 | self.half_res = half_res 282 | 283 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 284 | if indim != outdim: 285 | if self.maml: 286 | self.shortcut = Conv2d_fw( 287 | indim, outdim, 1, stride=2 if half_res else 1, bias=False 288 | ) 289 | else: 290 | self.shortcut = nn.Conv2d( 291 | indim, outdim, 1, stride=2 if half_res else 1, bias=False 292 | ) 293 | 294 | self.parametrized_layers.append(self.shortcut) 295 | self.shortcut_type = "1x1" 296 | else: 297 | self.shortcut_type = "identity" 298 | 299 | for layer in self.parametrized_layers: 300 | init_layer(layer) 301 | 302 | def forward(self, x): 303 | 304 | short_out = x if self.shortcut_type == "identity" else self.shortcut(x) 305 | out = self.C1(x) 306 | out = self.BN1(out) 307 | out = self.relu(out) 308 | out = self.C2(out) 309 | out = self.BN2(out) 310 | out = self.relu(out) 311 | out = self.C3(out) 312 | out = self.BN3(out) 313 | out = out + short_out 314 | 315 | out = self.relu(out) 316 | return out 317 | 318 | 319 | class ConvNet(nn.Module): 320 | def __init__(self, depth, flatten=True): 321 | super(ConvNet, self).__init__() 322 | trunk = [] 323 | for i in range(depth): 324 | indim = 3 if i == 0 else 64 325 | outdim = 64 326 | B = ConvBlock(indim, outdim, pool=(i < 4)) # only pooling for fist 4 layers 327 | trunk.append(B) 328 | 329 | self.flatten = flatten 330 | if self.flatten: 331 | trunk.append(Flatten()) 332 | 333 | self.trunk = nn.Sequential(*trunk) 334 | self.final_feat_dim = 1600 335 | 336 | def forward(self, x): 337 | out = self.trunk(x) 338 | return out 339 | 340 | def forward_seq(self, 
x): 341 | hiddens = [] 342 | if self.flatten: 343 | seq = self.trunk[:-1] 344 | else: 345 | seq = self.trunk[:-1] 346 | for layer in seq: 347 | x = layer(x) 348 | hiddens.append(x) 349 | return hiddens 350 | 351 | 352 | class ConvNetNopool( 353 | nn.Module 354 | ): # Relation net use a 4 layer conv with pooling in only first two layers, else no pooling 355 | def __init__(self, depth): 356 | super(ConvNetNopool, self).__init__() 357 | trunk = [] 358 | for i in range(depth): 359 | indim = 3 if i == 0 else 64 360 | outdim = 64 361 | B = ConvBlock( 362 | indim, outdim, pool=(i in [0, 1]), padding=0 if i in [0, 1] else 1 363 | ) # only first two layer has pooling and no padding 364 | trunk.append(B) 365 | 366 | self.trunk = nn.Sequential(*trunk) 367 | self.final_feat_dim = [64, 19, 19] 368 | 369 | def forward(self, x): 370 | out = self.trunk(x) 371 | return out 372 | 373 | 374 | class ConvNetS(nn.Module): # For omniglot, only 1 input channel, output dim is 64 375 | def __init__(self, depth, flatten=True): 376 | super(ConvNetS, self).__init__() 377 | trunk = [] 378 | for i in range(depth): 379 | indim = 1 if i == 0 else 64 380 | outdim = 64 381 | B = ConvBlock(indim, outdim, pool=(i < 4)) # only pooling for fist 4 layers 382 | trunk.append(B) 383 | 384 | if flatten: 385 | trunk.append(Flatten()) 386 | 387 | self.trunk = nn.Sequential(*trunk) 388 | self.final_feat_dim = 64 389 | 390 | def forward(self, x): 391 | out = x[:, 0:1, :, :] # only use the first dimension 392 | out = self.trunk(out) 393 | return out 394 | 395 | 396 | class ConvNetSNopool(nn.Module): 397 | def __init__(self, depth): 398 | super(ConvNetSNopool, self).__init__() 399 | trunk = [] 400 | for i in range(depth): 401 | indim = 1 if i == 0 else 64 402 | outdim = 64 403 | B = ConvBlock( 404 | indim, outdim, pool=(i in [0, 1]), padding=0 if i in [0, 1] else 1 405 | ) # only first two layer has pooling and no padding 406 | trunk.append(B) 407 | 408 | self.trunk = nn.Sequential(*trunk) 409 | self.final_feat_dim = [64, 5, 5] 410 | 411 | def forward(self, x): 412 | out = x[:, 0:1, :, :] # only use the first dimension 413 | out = self.trunk(out) 414 | return out 415 | 416 | 417 | class ResNet(nn.Module): 418 | maml = False # Default 419 | 420 | def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True): 421 | # list_of_num_layers specifies number of layers in each stage 422 | # list_of_out_dims specifies number of output channel for each stage 423 | super(ResNet, self).__init__() 424 | assert len(list_of_num_layers) == 4, "Can have only four stages" 425 | if self.maml: 426 | conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 427 | bn1 = BatchNorm2d_fw(64) 428 | else: 429 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 430 | bn1 = nn.BatchNorm2d(64) 431 | 432 | relu = nn.ReLU() 433 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 434 | 435 | init_layer(conv1) 436 | init_layer(bn1) 437 | 438 | trunk = [conv1, bn1, relu, pool1] 439 | 440 | indim = 64 441 | for i in range(4): 442 | 443 | for j in range(list_of_num_layers[i]): 444 | half_res = (i >= 1) and (j == 0) 445 | B = block(indim, list_of_out_dims[i], half_res) 446 | trunk.append(B) 447 | indim = list_of_out_dims[i] 448 | 449 | if flatten: 450 | avgpool = nn.AvgPool2d(7) 451 | trunk.append(avgpool) 452 | trunk.append(Flatten()) 453 | self.final_feat_dim = indim 454 | else: 455 | self.final_feat_dim = [indim, 7, 7] 456 | 457 | self.trunk = nn.Sequential(*trunk) 458 | 459 | def forward(self, x): 460 | out = 
self.trunk(x) 461 | return out 462 | 463 | 464 | def Conv4(): 465 | return ConvNet(4) 466 | 467 | 468 | def Conv6(): 469 | return ConvNet(6) 470 | 471 | 472 | def Conv4NP(): 473 | return ConvNetNopool(4) 474 | 475 | 476 | def Conv6NP(): 477 | return ConvNetNopool(6) 478 | 479 | 480 | def Conv4S(): 481 | return ConvNetS(4) 482 | 483 | 484 | def Conv4SNP(): 485 | return ConvNetSNopool(4) 486 | 487 | 488 | def ResNet10(flatten=True): 489 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten) 490 | 491 | 492 | def ResNet18(flatten=True): 493 | return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten) 494 | 495 | 496 | def PretrainedResNet18(): 497 | rn18 = models.resnet18(pretrained=True) 498 | rn18.final_feat_dim = 512 499 | rn18.fc = Identity() # We don't use final fc 500 | return rn18 501 | 502 | 503 | def ResNet34(flatten=True): 504 | return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], flatten) 505 | 506 | 507 | def ResNet50(flatten=True): 508 | return ResNet(BottleneckBlock, [3, 4, 6, 3], [256, 512, 1024, 2048], flatten) 509 | 510 | 511 | def ResNet101(flatten=True): 512 | return ResNet(BottleneckBlock, [3, 4, 23, 3], [256, 512, 1024, 2048], flatten) 513 | -------------------------------------------------------------------------------- /birds/fewshot/constants.py: -------------------------------------------------------------------------------- 1 | DATA_DIR = "./custom_filelists/CUB/" 2 | LANG_DIR = "./reed-birds/" 3 | -------------------------------------------------------------------------------- /birds/fewshot/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import additional_transforms, datamgr, dataset 2 | -------------------------------------------------------------------------------- /birds/fewshot/data/additional_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
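# Descriptive note on this module: the ImageJitter transform defined below applies
# each configured PIL ImageEnhance operation (Brightness, Contrast, Color in
# datamgr.TransformLoader's defaults) with a random strength drawn uniformly from
# [1 - alpha, 1 + alpha] for that operation's alpha.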
6 | 7 | import torch 8 | from PIL import ImageEnhance 9 | 10 | transformtypedict = dict( 11 | Brightness=ImageEnhance.Brightness, 12 | Contrast=ImageEnhance.Contrast, 13 | Sharpness=ImageEnhance.Sharpness, 14 | Color=ImageEnhance.Color, 15 | ) 16 | 17 | 18 | class ImageJitter(object): 19 | def __init__(self, transformdict): 20 | self.transforms = [ 21 | (transformtypedict[k], transformdict[k]) for k in transformdict 22 | ] 23 | 24 | def __call__(self, img): 25 | out = img 26 | randtensor = torch.rand(len(self.transforms)) 27 | 28 | for i, (transformer, alpha) in enumerate(self.transforms): 29 | r = alpha * (randtensor[i] * 2.0 - 1.0) + 1 30 | out = transformer(out).enhance(r).convert("RGB") 31 | 32 | return out 33 | -------------------------------------------------------------------------------- /birds/fewshot/data/datamgr.py: -------------------------------------------------------------------------------- 1 | # This code is modified from 2 | # https://github.com/facebookresearch/low-shot-shrink-hallucinate 3 | 4 | from abc import abstractmethod 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | 9 | import data.additional_transforms as add_transforms 10 | from data.dataset import EpisodicBatchSampler, SetDataset, SimpleDataset 11 | 12 | 13 | class TransformLoader: 14 | def __init__( 15 | self, 16 | image_size, 17 | normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 18 | jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4), 19 | ): 20 | self.image_size = image_size 21 | self.normalize_param = normalize_param 22 | self.jitter_param = jitter_param 23 | 24 | def parse_transform(self, transform_type): 25 | if transform_type == "ImageJitter": 26 | method = add_transforms.ImageJitter(self.jitter_param) 27 | return method 28 | method = getattr(transforms, transform_type) 29 | if transform_type == "RandomResizedCrop": 30 | return method(self.image_size) 31 | elif transform_type == "CenterCrop": 32 | return method(self.image_size) 33 | elif transform_type == "Resize": 34 | return method([int(self.image_size * 1.15), int(self.image_size * 1.15)]) 35 | elif transform_type == "Normalize": 36 | return method(**self.normalize_param) 37 | else: 38 | return method() 39 | 40 | def get_composed_transform( 41 | self, 42 | aug=False, 43 | normalize=True, 44 | to_pil=True, 45 | confound_noise=0.0, 46 | confound_noise_class_weight=0.0, 47 | ): 48 | if aug: 49 | transform_list = [ 50 | "RandomResizedCrop", 51 | "ImageJitter", 52 | "RandomHorizontalFlip", 53 | "ToTensor", 54 | ] 55 | else: 56 | transform_list = ["Resize", "CenterCrop", "ToTensor"] 57 | 58 | if confound_noise != 0.0: 59 | transform_list.append( 60 | ("Noise", confound_noise, confound_noise_class_weight) 61 | ) 62 | 63 | if normalize: 64 | transform_list.append("Normalize") 65 | 66 | if to_pil: 67 | transform_list = ["ToPILImage"] + transform_list 68 | 69 | transform_funcs = [self.parse_transform(x) for x in transform_list] 70 | transform = transforms.Compose(transform_funcs) 71 | return transform 72 | 73 | def get_normalize(self): 74 | return self.parse_transform("Normalize") 75 | 76 | 77 | class DataManager: 78 | @abstractmethod 79 | def get_data_loader(self, data_file, aug): 80 | pass 81 | 82 | 83 | class SimpleDataManager(DataManager): 84 | def __init__(self, image_size, batch_size, num_workers=12): 85 | super(SimpleDataManager, self).__init__() 86 | self.batch_size = batch_size 87 | self.trans_loader = TransformLoader(image_size) 88 | self.num_workers = num_workers 89 | 90 | def 
get_data_loader( 91 | self, data_file, aug, lang_dir=None, normalize=True, to_pil=False 92 | ): # parameters that would change on train/val set 93 | if lang_dir is not None: 94 | raise NotImplementedError 95 | transform = self.trans_loader.get_composed_transform( 96 | aug, normalize=normalize, to_pil=to_pil 97 | ) 98 | dataset = SimpleDataset(data_file, transform) 99 | data_loader_params = dict( 100 | batch_size=self.batch_size, 101 | shuffle=True, 102 | num_workers=self.num_workers, 103 | pin_memory=True, 104 | ) 105 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 106 | 107 | return data_loader 108 | 109 | 110 | class SetDataManager(DataManager): 111 | def __init__( 112 | self, name, image_size, n_way, n_support, n_query, n_episode=100, args=None 113 | ): 114 | super(SetDataManager, self).__init__() 115 | self.name = name 116 | self.image_size = image_size 117 | self.n_way = n_way 118 | self.batch_size = n_support + n_query 119 | self.n_episode = n_episode 120 | self.args = args 121 | 122 | self.trans_loader = TransformLoader(image_size) 123 | 124 | def get_data_loader( 125 | self, 126 | data_file, 127 | aug, 128 | lang_dir=None, 129 | normalize=True, 130 | vocab=None, 131 | max_class=None, 132 | max_img_per_class=None, 133 | max_lang_per_class=None, 134 | ): 135 | transform = self.trans_loader.get_composed_transform(aug, normalize=normalize) 136 | 137 | dataset = SetDataset( 138 | self.name, 139 | data_file, 140 | self.batch_size, 141 | transform, 142 | args=self.args, 143 | lang_dir=lang_dir, 144 | vocab=vocab, 145 | max_class=max_class, 146 | max_img_per_class=max_img_per_class, 147 | max_lang_per_class=max_lang_per_class, 148 | ) 149 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_episode) 150 | data_loader_params = dict( 151 | batch_sampler=sampler, num_workers=self.args.n_workers, pin_memory=True, 152 | ) 153 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 154 | return data_loader 155 | -------------------------------------------------------------------------------- /birds/fewshot/data/dataset.py: -------------------------------------------------------------------------------- 1 | # This code is modified from 2 | # https://github.com/facebookresearch/low-shot-shrink-hallucinate 3 | 4 | import glob 5 | import json 6 | import os 7 | 8 | import numpy as np 9 | import torch 10 | import torchvision.transforms as transforms 11 | from numpy import random 12 | from PIL import Image 13 | import torchfile 14 | 15 | from . 
import lang_utils 16 | 17 | 18 | CUB_IMAGES_PATH = "CUB_200_2011/images" 19 | 20 | 21 | def identity(x): 22 | return x 23 | 24 | 25 | def load_image(image_path): 26 | img = Image.open(image_path).convert("RGB") 27 | return img 28 | 29 | 30 | class SimpleDataset: 31 | def __init__(self, data_file, transform, target_transform=identity): 32 | with open(data_file, "r") as f: 33 | self.meta = json.load(f) 34 | self.transform = transform 35 | self.target_transform = target_transform 36 | 37 | def __getitem__(self, i): 38 | image_path = os.path.join(self.meta["image_names"][i]) 39 | img = load_image(image_path) 40 | img = self.transform(img) 41 | target = self.target_transform(self.meta["image_labels"][i]) 42 | return img, target 43 | 44 | def __len__(self): 45 | return len(self.meta["image_names"]) 46 | 47 | 48 | class SetDataset: 49 | def __init__( 50 | self, 51 | name, 52 | data_file, 53 | batch_size, 54 | transform, 55 | args=None, 56 | lang_dir=None, 57 | vocab=None, 58 | max_class=None, 59 | max_img_per_class=None, 60 | max_lang_per_class=None, 61 | ): 62 | self.name = name 63 | with open(data_file, "r") as f: 64 | self.meta = json.load(f) 65 | 66 | self.args = args 67 | self.max_class = max_class 68 | self.max_img_per_class = max_img_per_class 69 | self.max_lang_per_class = max_lang_per_class 70 | 71 | if not (1 <= args.n_caption <= 10): 72 | raise ValueError("Invalid # captions {}".format(args.n_caption)) 73 | 74 | self.cl_list = np.unique(self.meta["image_labels"]).tolist() 75 | 76 | if self.max_class is not None: 77 | if self.max_class > len(self.cl_list): 78 | raise ValueError( 79 | "max_class set to {} but only {} classes in {}".format( 80 | self.max_class, len(self.cl_list), data_file 81 | ) 82 | ) 83 | self.cl_list = self.cl_list[: self.max_class] 84 | 85 | if args.language_filter not in ["all", "color", "nocolor"]: 86 | raise NotImplementedError( 87 | "language_filter = {}".format(args.language_filter) 88 | ) 89 | 90 | self.sub_meta_lang = {} 91 | self.sub_meta_lang_length = {} 92 | self.sub_meta_lang_mask = {} 93 | self.sub_meta = {} 94 | 95 | for cl in self.cl_list: 96 | self.sub_meta[cl] = [] 97 | self.sub_meta_lang[cl] = [] 98 | self.sub_meta_lang_length[cl] = [] 99 | self.sub_meta_lang_mask[cl] = [] 100 | 101 | # Load language and mapping from image names -> lang idx 102 | self.lang = {} 103 | self.lang_lengths = {} 104 | self.lang_masks = {} 105 | self.image_name_idx = {} 106 | for cln, label_name in enumerate(self.meta["label_names"]): 107 | # Use the numeric class id instead of label name due to 108 | # inconsistencies 109 | digits = label_name.split(".")[0] 110 | matching_names = [ 111 | x 112 | for x in os.listdir(os.path.join(lang_dir, "word_c10")) 113 | if x.startswith(digits) 114 | ] 115 | assert len(matching_names) == 1, matching_names 116 | label_file = os.path.join(lang_dir, "word_c10", matching_names[0]) 117 | lang_tensor = torch.from_numpy(torchfile.load(label_file)).long() 118 | # Make words last dim 119 | lang_tensor = lang_tensor.transpose(2, 1) 120 | lang_tensor = lang_tensor - 1 # XXX: Decrement language by 1 upon load 121 | 122 | if ( 123 | self.args.language_filter == "color" 124 | or self.args.language_filter == "nocolor" 125 | ): 126 | lang_tensor = lang_utils.filter_language( 127 | lang_tensor, self.args.language_filter, vocab 128 | ) 129 | 130 | if self.args.shuffle_lang: 131 | lang_tensor = lang_utils.shuffle_language(lang_tensor) 132 | 133 | lang_lengths = lang_utils.get_lang_lengths(lang_tensor) 134 | 135 | # Add start and end of sentence tokens to 
language 136 | lang_tensor, lang_lengths = lang_utils.add_sos_eos( 137 | lang_tensor, lang_lengths, vocab 138 | ) 139 | lang_masks = lang_utils.get_lang_masks( 140 | lang_lengths, max_len=lang_tensor.shape[2] 141 | ) 142 | 143 | self.lang[label_name] = lang_tensor 144 | self.lang_lengths[label_name] = lang_lengths 145 | self.lang_masks[label_name] = lang_masks 146 | 147 | # Give images their numeric ids according to alphabetical order 148 | if self.name == "CUB": 149 | img_dir = os.path.join(lang_dir, "text_c10", label_name, "*.txt") 150 | sorted_imgs = sorted( 151 | [ 152 | os.path.splitext(os.path.basename(i))[0] 153 | for i in glob.glob(img_dir) 154 | ] 155 | ) 156 | for i, img_fname in enumerate(sorted_imgs): 157 | self.image_name_idx[img_fname] = i 158 | 159 | for x, y in zip(self.meta["image_names"], self.meta["image_labels"]): 160 | if y in self.sub_meta: 161 | self.sub_meta[y].append(x) 162 | label_name = self.meta["label_names"][y] 163 | 164 | image_basename = os.path.splitext(os.path.basename(x))[0] 165 | if self.name == "CUB": 166 | image_lang_idx = self.image_name_idx[image_basename] 167 | else: 168 | image_lang_idx = int(image_basename[-1]) 169 | 170 | captions = self.lang[label_name][image_lang_idx] 171 | lengths = self.lang_lengths[label_name][image_lang_idx] 172 | masks = self.lang_masks[label_name][image_lang_idx] 173 | 174 | self.sub_meta_lang[y].append(captions) 175 | self.sub_meta_lang_length[y].append(lengths) 176 | self.sub_meta_lang_mask[y].append(masks) 177 | else: 178 | assert self.max_class is not None 179 | 180 | if self.args.scramble_lang: 181 | # For each class, shuffle captions for each image 182 | ( 183 | self.sub_meta_lang, 184 | self.sub_meta_lang_length, 185 | self.sub_meta_lang_mask, 186 | ) = lang_utils.shuffle_lang_class( 187 | self.sub_meta_lang, self.sub_meta_lang_length, self.sub_meta_lang_mask 188 | ) 189 | 190 | if self.args.scramble_lang_class: 191 | raise NotImplementedError 192 | 193 | if self.args.scramble_all: 194 | # Shuffle captions completely randomly 195 | ( 196 | self.sub_meta_lang, 197 | self.sub_meta_lang_length, 198 | self.sub_meta_lang_mask, 199 | ) = lang_utils.shuffle_all_class( 200 | self.sub_meta_lang, self.sub_meta_lang_length, self.sub_meta_lang_mask 201 | ) 202 | 203 | if self.max_img_per_class is not None: 204 | # Trim number of images available per class 205 | for cl in self.sub_meta.keys(): 206 | self.sub_meta[cl] = self.sub_meta[cl][: self.max_img_per_class] 207 | self.sub_meta_lang[cl] = self.sub_meta_lang[cl][ 208 | : self.max_img_per_class 209 | ] 210 | self.sub_meta_lang_length[cl] = self.sub_meta_lang_length[cl][ 211 | : self.max_img_per_class 212 | ] 213 | self.sub_meta_lang_mask[cl] = self.sub_meta_lang_mask[cl][ 214 | : self.max_img_per_class 215 | ] 216 | 217 | if self.max_lang_per_class is not None: 218 | # Trim language available for each class; recycle language if not enough 219 | for cl in self.sub_meta.keys(): 220 | self.sub_meta_lang[cl] = lang_utils.recycle_lang( 221 | self.sub_meta_lang[cl], self.max_lang_per_class 222 | ) 223 | self.sub_meta_lang_length[cl] = lang_utils.recycle_lang( 224 | self.sub_meta_lang_length[cl], self.max_lang_per_class 225 | ) 226 | self.sub_meta_lang_mask[cl] = lang_utils.recycle_lang( 227 | self.sub_meta_lang_mask[cl], self.max_lang_per_class 228 | ) 229 | 230 | self.sub_dataloader = [] 231 | sub_data_loader_params = dict( 232 | batch_size=batch_size, 233 | shuffle=True, 234 | num_workers=0, # use main thread only or may receive multiple batches 235 | pin_memory=False, 236 | ) 
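# Descriptive note on the episodic setup here: each class gets its own SubDataset
# wrapped in a single-process DataLoader (num_workers=0, shuffle=True);
# __getitem__ below pulls one batch of (n_support + n_query) images from the
# selected class's loader, while EpisodicBatchSampler (end of this file) chooses
# the n_way class indices that form each few-shot episode.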
237 | for i, cl in enumerate(self.cl_list): 238 | sub_dataset = SubDataset( 239 | self.name, 240 | self.sub_meta[cl], 241 | cl, 242 | sub_meta_lang=self.sub_meta_lang[cl], 243 | sub_meta_lang_length=self.sub_meta_lang_length[cl], 244 | sub_meta_lang_mask=self.sub_meta_lang_mask[cl], 245 | transform=transform, 246 | n_caption=self.args.n_caption, 247 | args=self.args, 248 | max_lang_per_class=self.max_lang_per_class, 249 | ) 250 | self.sub_dataloader.append( 251 | torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params) 252 | ) 253 | 254 | def __getitem__(self, i): 255 | return next(iter(self.sub_dataloader[i])) 256 | 257 | def __len__(self): 258 | return len(self.sub_dataloader) 259 | 260 | 261 | class SubDataset: 262 | def __init__( 263 | self, 264 | name, 265 | sub_meta, 266 | cl, 267 | sub_meta_lang=None, 268 | sub_meta_lang_length=None, 269 | sub_meta_lang_mask=None, 270 | transform=transforms.ToTensor(), 271 | target_transform=identity, 272 | n_caption=10, 273 | args=None, 274 | max_lang_per_class=None, 275 | ): 276 | self.name = name 277 | self.sub_meta = sub_meta 278 | self.sub_meta_lang = sub_meta_lang 279 | self.sub_meta_lang_length = sub_meta_lang_length 280 | self.sub_meta_lang_mask = sub_meta_lang_mask 281 | self.cl = cl 282 | self.transform = transform 283 | self.target_transform = target_transform 284 | if not (1 <= n_caption <= 10): 285 | raise ValueError("Invalid # captions {}".format(n_caption)) 286 | self.n_caption = n_caption 287 | cl_path = os.path.split(self.sub_meta[0])[0] 288 | self.img = dict(np.load(os.path.join(cl_path, "img.npz"))) 289 | 290 | # Used if sampling from class 291 | self.args = args 292 | self.max_lang_per_class = max_lang_per_class 293 | 294 | def __getitem__(self, i): 295 | image_path = self.sub_meta[i] 296 | img = self.img[image_path] 297 | img = self.transform(img) 298 | target = self.target_transform(self.cl) 299 | 300 | if self.n_caption == 1: 301 | lang_idx = 0 302 | else: 303 | lang_idx = random.randint(min(self.n_caption, len(self.sub_meta_lang[i]))) 304 | 305 | if self.args.sample_class_lang: 306 | # Sample from all language, rather than the ith image 307 | if self.max_lang_per_class is None: 308 | max_i = len(self.sub_meta_lang) 309 | else: 310 | max_i = min(self.max_lang_per_class, len(self.sub_meta_lang)) 311 | which_img_lang_i = random.randint(0, max_i) 312 | else: 313 | which_img_lang_i = i 314 | 315 | lang = self.sub_meta_lang[which_img_lang_i][lang_idx] 316 | lang_length = self.sub_meta_lang_length[which_img_lang_i][lang_idx] 317 | lang_mask = self.sub_meta_lang_mask[which_img_lang_i][lang_idx] 318 | 319 | return img, target, (lang, lang_length, lang_mask) 320 | 321 | def __len__(self): 322 | return len(self.sub_meta) 323 | 324 | 325 | class EpisodicBatchSampler(object): 326 | def __init__(self, n_classes, n_way, n_episodes): 327 | self.n_classes = n_classes 328 | self.n_way = n_way 329 | self.n_episodes = n_episodes 330 | 331 | def __len__(self): 332 | return self.n_episodes 333 | 334 | def __iter__(self): 335 | for i in range(self.n_episodes): 336 | yield torch.randperm(self.n_classes)[: self.n_way] 337 | -------------------------------------------------------------------------------- /birds/fewshot/data/lang_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for processing language datasets 3 | """ 4 | 5 | import os 6 | from collections import defaultdict 7 | 8 | import numpy as np 9 | import torch 10 | from numpy import random 11 | import torchfile 12 | 13 | 
SOS_TOKEN = "" 14 | EOS_TOKEN = "" 15 | PAD_TOKEN = "" 16 | 17 | COLOR_WORDS = set( 18 | [ 19 | "amaranth", 20 | "charcoal", 21 | "amber", 22 | "amethyst", 23 | "apricot", 24 | "aquamarine", 25 | "azure", 26 | "baby blue", 27 | "beige", 28 | "black", 29 | "blue", 30 | "blush", 31 | "bronze", 32 | "brown", 33 | "burgundy", 34 | "byzantium", 35 | "carmine", 36 | "cerise", 37 | "cerulean", 38 | "champagne", 39 | "chartreuse", 40 | "chocolate", 41 | "cobalt", 42 | "coffee", 43 | "copper", 44 | "coral", 45 | "crimson", 46 | "cyan", 47 | "desert", 48 | "electric", 49 | "emerald", 50 | "erin", 51 | "gold", 52 | "gray", 53 | "grey", 54 | "green", 55 | "harlequin", 56 | "indigo", 57 | "ivory", 58 | "jade", 59 | "jungle", 60 | "lavender", 61 | "lemon", 62 | "lilac", 63 | "lime", 64 | "magenta", 65 | "magenta", 66 | "maroon", 67 | "mauve", 68 | "navy", 69 | "ochre", 70 | "olive", 71 | "orange", 72 | "orange", 73 | "orchid", 74 | "peach", 75 | "pear", 76 | "periwinkle", 77 | "persian", 78 | "pink", 79 | "plum", 80 | "prussian", 81 | "puce", 82 | "purple", 83 | "raspberry", 84 | "red", 85 | "red", 86 | "rose", 87 | "ruby", 88 | "salmon", 89 | "sangria", 90 | "sapphire", 91 | "scarlet", 92 | "silver", 93 | "slate", 94 | "spring", 95 | "spring", 96 | "tan", 97 | "taupe", 98 | "teal", 99 | "turquoise", 100 | "ultramarine", 101 | "violet", 102 | "viridian", 103 | "white", 104 | "yellow", 105 | "reddish", 106 | "yellowish", 107 | "greenish", 108 | "orangeish", 109 | "orangish", 110 | "blackish", 111 | "pinkish", 112 | "dark", 113 | "light", 114 | "bright", 115 | "greyish", 116 | "grayish", 117 | "brownish", 118 | "beigish", 119 | "aqua", 120 | ] 121 | ) 122 | 123 | 124 | def filter_language(lang_tensor, language_filter, vocab): 125 | """ 126 | Filter language, keeping or discarding color words 127 | 128 | :param lang_tensor: torch.Tensor of shape (n_imgs, lang_per_img, 129 | max_lang_len); language to be filtered 130 | :param language_filter: either 'color' or 'nocolor'; what language to 131 | filter out 132 | :param vocab: the vocabulary (so we know what indexes to remove) 133 | 134 | :returns: torch.Tensor of same shape as `lang_tensor` with either color or 135 | non-color words removed 136 | """ 137 | assert language_filter in ["color", "nocolor"] 138 | 139 | cw = set(vocab[cw] for cw in COLOR_WORDS if cw in vocab) 140 | 141 | new_lang_tensor = torch.ones_like(lang_tensor) 142 | for bird_caps_i in range(lang_tensor.shape[0]): 143 | bird_caps = lang_tensor[bird_caps_i] 144 | new_bird_caps = torch.ones_like(bird_caps) 145 | for bird_cap_i in range(bird_caps.shape[0]): 146 | bird_cap = bird_caps[bird_cap_i] 147 | new_bird_cap = torch.ones_like(bird_cap) 148 | new_w_i = 0 149 | for w in bird_cap: 150 | is_cw = w.item() in cw 151 | if (language_filter == "color" and is_cw) or ( 152 | language_filter == "nocolor" and not is_cw 153 | ): 154 | new_bird_cap[new_w_i] = w 155 | new_w_i += 1 156 | if new_bird_cap[0].item() == 1: 157 | # FIXME: Here we're just choosing an arbitrary randomly 158 | # mispelled token; make a proper UNK token. 
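# (This branch fires only when the filter removed every word from the
# caption; a single placeholder index is substituted so the caption is
# not left empty.)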
159 | new_bird_cap[0] = 5724 160 | new_bird_caps[bird_cap_i] = new_bird_cap 161 | new_lang_tensor[bird_caps_i] = new_bird_caps 162 | return new_lang_tensor 163 | 164 | 165 | def shuffle_language(lang_tensor): 166 | """ 167 | Scramble words in language 168 | 169 | :param lang_tensor: torch.Tensor of shape (n_img, lang_per_img, max_lang_len) 170 | 171 | :returns: torch.Tensor of same shape, but with words randomly scrambled 172 | """ 173 | new_lang_tensor = torch.ones_like(lang_tensor) 174 | for bird_caps_i in range(lang_tensor.shape[0]): 175 | bird_caps = lang_tensor[bird_caps_i] 176 | new_bird_caps = torch.ones_like(bird_caps) 177 | for bird_cap_i in range(bird_caps.shape[0]): 178 | bird_cap = bird_caps[bird_cap_i] 179 | new_bird_cap = torch.ones_like(bird_cap) 180 | bird_cap_list = [] 181 | for w in bird_cap.numpy(): 182 | if w != 1: 183 | bird_cap_list.append(w) 184 | else: 185 | break 186 | random.shuffle(bird_cap_list) 187 | bird_cap_shuf = torch.tensor( 188 | bird_cap_list, dtype=new_bird_cap.dtype, requires_grad=False 189 | ) 190 | new_bird_cap[: len(bird_cap_list)] = bird_cap_shuf 191 | new_bird_caps[bird_cap_i] = new_bird_cap 192 | new_lang_tensor[bird_caps_i] = new_bird_caps 193 | return new_lang_tensor 194 | 195 | 196 | def get_lang_lengths(lang_tensor): 197 | """ 198 | Get lengths of each caption 199 | 200 | :param lang_tensor: torch.Tensor of shape (n_img, lang_per_img, max_len) 201 | :returns: torch.Tensor of shape (n_img, lang_per_img) 202 | """ 203 | max_lang_len = lang_tensor.shape[2] 204 | n_pad = torch.sum(lang_tensor == 0, dim=2) 205 | lang_lengths = max_lang_len - n_pad 206 | return lang_lengths 207 | 208 | 209 | def get_lang_masks(lang_lengths, max_len=32): 210 | """ 211 | Given lang lengths, convert to masks 212 | 213 | :param lang_lengths: torch.tensor of shape (n_imgs, lang_per_img) 214 | 215 | returns: torch.BoolTensor of shape (n_imgs, lang_per_img, max_len), binary 216 | mask with 0s in token spots and 1s in padding spots 217 | """ 218 | mask = torch.ones(lang_lengths.shape + (max_len,), dtype=torch.bool) 219 | for i in range(lang_lengths.shape[0]): 220 | for j in range(lang_lengths.shape[1]): 221 | this_ll = lang_lengths[i, j] 222 | mask[i, j, :this_ll] = 0 223 | return mask 224 | 225 | 226 | def add_sos_eos(lang_tensor, lang_lengths, vocab): 227 | """ 228 | Pad language tensors 229 | 230 | :param lang: torch.Tensor of shape (n_imgs, n_lang_per_img, max_len) 231 | :param lang_lengths: torch.Tensor of shape (n_imgs, n_lang_per_img) 232 | :param vocab: dictionary from words -> idxs 233 | 234 | :returns: (lang, lang_lengths) where lang has SOS and EOS tokens added, and 235 | lang_lengths have all been increased by 2 (to account for SOS/EOS) 236 | """ 237 | sos_idx = vocab[SOS_TOKEN] 238 | eos_idx = vocab[EOS_TOKEN] 239 | lang_tensor_padded = torch.zeros( 240 | lang_tensor.shape[0], 241 | lang_tensor.shape[1], 242 | lang_tensor.shape[2] + 2, 243 | dtype=torch.int64, 244 | ) 245 | lang_tensor_padded[:, :, 0] = sos_idx 246 | lang_tensor_padded[:, :, 1:-1] = lang_tensor 247 | for i in range(lang_tensor_padded.shape[0]): 248 | for j in range(lang_tensor_padded.shape[1]): 249 | ll = lang_lengths[i, j] 250 | lang_tensor_padded[ 251 | i, j, ll + 1 252 | ] = eos_idx # + 1 accounts for sos token already there 253 | return lang_tensor_padded, lang_lengths + 2 254 | 255 | 256 | def shuffle_lang_class(lang, lang_length, lang_mask): 257 | """ 258 | For each class, shuffle captions across images 259 | 260 | :param lang: dict from class -> list of languages for that class 261 | 
:param lang_length: dict from class -> list of language lengths for that class 262 | :param lang_mask: list of language masks 263 | 264 | :returns: (new_lang, new_lang_length, new_lang_mask): tuple of new language 265 | dictionaries representing the modified language 266 | """ 267 | new_lang = {} 268 | new_lang_length = {} 269 | new_lang_mask = {} 270 | for y in lang: 271 | # FIXME: Make this seedable 272 | img_range = np.arange(len(lang[y])) 273 | random.shuffle(img_range) 274 | nlang = [] 275 | nlang_length = [] 276 | nlang_mask = [] 277 | for lang_i in img_range: 278 | nlang.append(lang[y][lang_i]) 279 | nlang_length.append(lang_length[y][lang_i]) 280 | nlang_mask.append(lang_mask[y][lang_i]) 281 | new_lang[y] = nlang 282 | new_lang_length[y] = nlang_length 283 | new_lang_mask[y] = nlang_mask 284 | return new_lang, new_lang_length, new_lang_mask 285 | 286 | 287 | def shuffle_all_class(lang, lang_length, lang_mask): 288 | """ 289 | Shuffle captions completely randomly across all images and classes 290 | 291 | :param lang: dict from class -> list of languages for that class 292 | :param lang_length: dict from class -> list of language lengths for that class 293 | :param lang_mask: list of language masks 294 | 295 | :returns: (new_lang, new_lang_length, new_lang_mask): tuple of new language 296 | dictionaries representing the modified language 297 | """ 298 | lens = [[(m, j) for j in range(len(lang[m]))] for m in lang.keys()] 299 | lens = [item for sublist in lens for item in sublist] 300 | shuffled_lens = lens[:] 301 | random.shuffle(shuffled_lens) 302 | new_lang = defaultdict(list) 303 | new_lang_length = defaultdict(list) 304 | new_lang_mask = defaultdict(list) 305 | for (m, _), (new_m, new_i) in zip(lens, shuffled_lens): 306 | new_lang[m].append(lang[new_m][new_i]) 307 | new_lang_length[m].append(lang_length[new_m][new_i]) 308 | new_lang_mask[m].append(lang_mask[new_m][new_i]) 309 | assert all(len(new_lang[m]) == len(lang[m]) for m in lang.keys()) 310 | return dict(new_lang), dict(new_lang_length), dict(new_lang_mask) 311 | 312 | 313 | def load_vocab(lang_dir): 314 | """ 315 | Load torch-serialized vocabulary from the lang dir 316 | 317 | :param: lang_dir: str, path to language directory 318 | :returns: dictionary from words -> idxs 319 | """ 320 | vocab = torchfile.load(os.path.join(lang_dir, "vocab_c10.t7")) 321 | vocab = {k: v - 1 for k, v in vocab.items()} # Decrement vocab 322 | vocab = {k.decode("utf-8"): v for k, v in vocab.items()} # Unicode 323 | # Add SOS/EOS tokens 324 | sos_idx = len(vocab) 325 | vocab[SOS_TOKEN] = sos_idx 326 | eos_idx = len(vocab) 327 | vocab[EOS_TOKEN] = eos_idx 328 | return vocab 329 | 330 | 331 | def glove_init(vocab, emb_size=300): 332 | """ 333 | Initialize vocab with glove vectors. 
Requires spacy and en_vectors_web_lg 334 | spacy model 335 | 336 | :param vocab: dict from words -> idxs 337 | :param emb_size: int, size of embeddings (should be 300 for spacy glove 338 | vectors) 339 | 340 | :returns: torch.FloatTensor of size (len(vocab), emb_size), with glove 341 | embedding if exists, else zeros 342 | """ 343 | import spacy 344 | 345 | try: 346 | nlp = spacy.load("en_vectors_web_lg", disable=["tagger", "parser", "ner"]) 347 | except OSError: 348 | # Try loading for current directory (codalab) 349 | nlp = spacy.load( 350 | "./en_vectors_web_lg/en_vectors_web_lg-2.1.0/", 351 | disable=["tagger", "parser", "ner"], 352 | ) 353 | 354 | vecs = np.zeros((len(vocab), emb_size), dtype=np.float32) 355 | vec_ids_sort = sorted(vocab.items(), key=lambda x: x[1]) 356 | sos_idx = vocab[SOS_TOKEN] 357 | eos_idx = vocab[EOS_TOKEN] 358 | pad_idx = vocab[PAD_TOKEN] 359 | for vec, vecid in vec_ids_sort: 360 | if vecid in (pad_idx, sos_idx, eos_idx): 361 | v = np.zeros(emb_size, dtype=np.float32) 362 | else: 363 | v = nlp(vec)[0].vector 364 | vecs[vecid] = v 365 | vecs = torch.as_tensor(vecs) 366 | return vecs 367 | 368 | 369 | def get_special_indices(vocab): 370 | """ 371 | Get indices of special items from vocab. 372 | :param vocab: dictionary from words -> idxs 373 | :returns: dictionary from {sos_index, eos_index, pad_index} -> tokens 374 | """ 375 | return { 376 | name: vocab[token] 377 | for name, token in [ 378 | ("sos_index", SOS_TOKEN), 379 | ("eos_index", EOS_TOKEN), 380 | ("pad_index", PAD_TOKEN), 381 | ] 382 | } 383 | 384 | 385 | def recycle_lang(langs, max_lang): 386 | """ 387 | Given a limited amount of language, reuse `max_lang` times 388 | :param langs: list of languages 389 | :param max_lang: how long the full language tensor should be 390 | 391 | :returns: new_langs, a list of length `max_lang` created by cycling through 392 | `langs` 393 | """ 394 | new_langs = [] 395 | for i in range(len(langs)): 396 | new_langs.append(langs[i % max_lang]) 397 | return new_langs 398 | -------------------------------------------------------------------------------- /birds/fewshot/io_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains argument parsers and utilities for saving and loading metrics and 3 | models. 
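model_dict maps backbone names (Conv4, ResNet18, ...) to the constructors in
backbone.py, and parse_args(script) builds the argument parser shared by the
train and test scripts, plus script-specific options.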
4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | 10 | import numpy as np 11 | 12 | import backbone 13 | 14 | 15 | model_dict = dict( 16 | Conv4=backbone.Conv4, 17 | Conv4NP=backbone.Conv4NP, 18 | Conv4S=backbone.Conv4S, 19 | Conv6=backbone.Conv6, 20 | ResNet10=backbone.ResNet10, 21 | ResNet18=backbone.ResNet18, 22 | PretrainedResNet18=backbone.PretrainedResNet18, 23 | ResNet34=backbone.ResNet34, 24 | ResNet50=backbone.ResNet50, 25 | ResNet101=backbone.ResNet101, 26 | ) 27 | 28 | 29 | def parse_args(script): 30 | parser = argparse.ArgumentParser( 31 | description="few-shot script %s" % (script), 32 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 33 | ) 34 | parser.add_argument( 35 | "--checkpoint_dir", 36 | required=True, 37 | help="Specify checkpoint dir (if none, automatically generate)", 38 | ) 39 | parser.add_argument("--model", default="Conv4", help="Choice of backbone") 40 | parser.add_argument("--lsl", action="store_true") 41 | parser.add_argument( 42 | "--l3", action="store_true", help="Use l3 (do not need to --lsl)" 43 | ) 44 | parser.add_argument("--l3_n_infer", type=int, default=10, help="Number to sample") 45 | parser.add_argument( 46 | "--rnn_type", choices=["gru", "lstm"], default="gru", help="Language RNN type" 47 | ) 48 | parser.add_argument( 49 | "--rnn_num_layers", default=1, type=int, help="Language RNN num layers" 50 | ) 51 | parser.add_argument( 52 | "--rnn_dropout", default=0.0, type=float, help="Language RNN dropout" 53 | ) 54 | parser.add_argument( 55 | "--lang_supervision", 56 | default="class", 57 | choices=["instance", "class"], 58 | help="At what level to supervise with language?", 59 | ) 60 | parser.add_argument("--glove_init", action="store_true") 61 | parser.add_argument( 62 | "--freeze_emb", action="store_true", help="Freeze LM word embedding layer" 63 | ) 64 | 65 | langparser = parser.add_argument_group("language settings") 66 | langparser.add_argument( 67 | "--shuffle_lang", action="store_true", help="Shuffle words in caption" 68 | ) 69 | langparser.add_argument( 70 | "--scramble_lang", 71 | action="store_true", 72 | help="Scramble captions -> images mapping in a class", 73 | ) 74 | langparser.add_argument( 75 | "--sample_class_lang", 76 | action="store_true", 77 | help="Sample language randomly from class, rather than getting lang assoc. 
w/ img", 78 | ) 79 | langparser.add_argument( 80 | "--scramble_all", 81 | action="store_true", 82 | help="Scramble captions -> images mapping across all classes", 83 | ) 84 | langparser.add_argument( 85 | "--scramble_lang_class", 86 | action="store_true", 87 | help="Scramble captions -> images mapping across all classes, but keep classes consistent", 88 | ) 89 | langparser.add_argument( 90 | "--language_filter", 91 | default="all", 92 | choices=["all", "color", "nocolor"], 93 | help="What language to use", 94 | ) 95 | 96 | parser.add_argument( 97 | "--lang_hidden_size", type=int, default=200, help="Language decoder hidden size" 98 | ) 99 | parser.add_argument( 100 | "--lang_emb_size", type=int, default=300, help="Language embedding hidden size" 101 | ) 102 | parser.add_argument( 103 | "--lang_lambda", type=float, default=5, help="Weight on language loss" 104 | ) 105 | 106 | parser.add_argument( 107 | "--n_caption", 108 | type=int, 109 | default=1, 110 | choices=list(range(1, 11)), 111 | help="How many captions to use for pretraining", 112 | ) 113 | parser.add_argument( 114 | "--max_class", type=int, default=None, help="Max number of training classes" 115 | ) 116 | parser.add_argument( 117 | "--max_img_per_class", 118 | type=int, 119 | default=None, 120 | help="Max number of images per training class", 121 | ) 122 | parser.add_argument( 123 | "--max_lang_per_class", 124 | type=int, 125 | default=None, 126 | help="Max number of language per training class (recycled among images)", 127 | ) 128 | parser.add_argument( 129 | "--train_n_way", default=5, type=int, help="class num to classify for training" 130 | ) 131 | parser.add_argument( 132 | "--test_n_way", 133 | default=5, 134 | type=int, 135 | help="class num to classify for testing (validation) ", 136 | ) 137 | parser.add_argument( 138 | "--n_shot", 139 | default=1, 140 | type=int, 141 | help="number of labeled data in each class, same as n_support", 142 | ) 143 | parser.add_argument( 144 | "--n_workers", 145 | default=4, 146 | type=int, 147 | help="Use this many workers for loading data", 148 | ) 149 | parser.add_argument( 150 | "--debug", action="store_true", help="Inspect generated language" 151 | ) 152 | parser.add_argument( 153 | "--seed", type=int, default=None, help="random seed (torch only; not numpy)" 154 | ) 155 | 156 | if script == "train": 157 | parser.add_argument( 158 | "--n", default=1, type=int, help="Train run number (used for metrics)" 159 | ) 160 | parser.add_argument( 161 | "--optimizer", 162 | default="adam", 163 | choices=["adam", "amsgrad", "rmsprop"], 164 | help="Optimizer", 165 | ) 166 | parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate") 167 | parser.add_argument( 168 | "--rnn_lr_scale", 169 | default=1.0, 170 | type=float, 171 | help="Scale the RNN lr by this amount of the original lr", 172 | ) 173 | parser.add_argument("--save_freq", default=50, type=int, help="Save frequency") 174 | parser.add_argument("--start_epoch", default=0, type=int, help="Starting epoch") 175 | parser.add_argument( 176 | "--stop_epoch", default=600, type=int, help="Stopping epoch" 177 | ) # for meta-learning methods, each epoch contains 100 episodes 178 | parser.add_argument( 179 | "--resume", 180 | action="store_true", 181 | help="continue from previous trained model with largest epoch", 182 | ) 183 | elif script == "test": 184 | parser.add_argument( 185 | "--split", 186 | default="novel", 187 | choices=["base", "val", "novel"], 188 | help="which split to evaluate on", 189 | ) 190 | parser.add_argument( 191 | 
"--save_iter", 192 | default=-1, 193 | type=int, 194 | help="saved feature from the model trained in x epoch, use the best model if x is -1", 195 | ) 196 | parser.add_argument( 197 | "--save_embeddings", 198 | action="store_true", 199 | help="Save embeddings from language model, then exit (requires --lsl)", 200 | ) 201 | parser.add_argument( 202 | "--embeddings_file", 203 | default="./embeddings.txt", 204 | help="File to save embeddings to", 205 | ) 206 | parser.add_argument( 207 | "--embeddings_metadata", 208 | default="./embeddings_metadata.txt", 209 | help="File to save embedding metadata to (currently just words)", 210 | ) 211 | parser.add_argument( 212 | "--record_file", 213 | default="./record/results.txt", 214 | help="Where to write results to", 215 | ) 216 | else: 217 | raise ValueError("Unknown script") 218 | 219 | args = parser.parse_args() 220 | 221 | if "save_embeddings" in args and (args.save_embeddings and not args.lsl): 222 | parser.error("Must set --lsl to save embeddings") 223 | 224 | if args.glove_init and not (args.lsl or args.l3): 225 | parser.error("Must set --lsl to init with glove") 226 | 227 | return args 228 | 229 | 230 | def get_assigned_file(checkpoint_dir, num): 231 | assign_file = os.path.join(checkpoint_dir, "{:d}.tar".format(num)) 232 | return assign_file 233 | 234 | 235 | def get_resume_file(checkpoint_dir): 236 | filelist = glob.glob(os.path.join(checkpoint_dir, "*.tar")) 237 | if len(filelist) == 0: 238 | return None 239 | 240 | filelist = [x for x in filelist if os.path.basename(x) != "best_model.tar"] 241 | epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist]) 242 | max_epoch = np.max(epochs) 243 | resume_file = os.path.join(checkpoint_dir, "{:d}.tar".format(max_epoch)) 244 | return resume_file 245 | 246 | 247 | def get_best_file(checkpoint_dir): 248 | best_file = os.path.join(checkpoint_dir, "best_model.tar") 249 | if os.path.isfile(best_file): 250 | return best_file 251 | else: 252 | return get_resume_file(checkpoint_dir) 253 | -------------------------------------------------------------------------------- /birds/fewshot/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import language, protonet 2 | -------------------------------------------------------------------------------- /birds/fewshot/models/language.py: -------------------------------------------------------------------------------- 1 | """ 2 | Language encoders/decoders. 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.nn.utils.rnn as rnn_utils 10 | 11 | 12 | class TextProposal(nn.Module): 13 | r"""Reverse proposal model, estimating: 14 | argmax_lambda log q(w_i|x_1, y_1, ..., x_n, y_n; lambda) 15 | approximation to the distribution of descriptions. 
16 | Because they use only positive labels, it actually simplifies to 17 | argmax_lambda log q(w_i|x_1, ..., x_4; lambda) 18 | https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/model.py 19 | """ 20 | 21 | def __init__( 22 | self, 23 | embedding_module, 24 | input_size=1600, 25 | hidden_size=512, 26 | project_input=False, 27 | rnn="gru", 28 | num_layers=1, 29 | dropout=0.2, 30 | vocab=None, 31 | sos_index=0, 32 | eos_index=0, 33 | pad_index=0, 34 | ): 35 | super(TextProposal, self).__init__() 36 | self.embedding = embedding_module 37 | self.embedding_dim = embedding_module.embedding_dim 38 | self.vocab_size = embedding_module.num_embeddings 39 | self.input_size = input_size 40 | self.hidden_size = hidden_size 41 | self.project_input = project_input 42 | self.num_layers = num_layers 43 | self.rnn_type = rnn 44 | if self.project_input: 45 | self.proj_h = nn.Linear(self.input_size, self.hidden_size) 46 | if self.rnn_type == "lstm": 47 | self.proj_c = nn.Linear(self.input_size, self.hidden_size) 48 | 49 | if rnn == "gru": 50 | RNN = nn.GRU 51 | elif rnn == "lstm": 52 | RNN = nn.LSTM 53 | else: 54 | raise ValueError("Unknown RNN model {}".format(rnn)) 55 | 56 | # Init the RNN 57 | self.rnn = None 58 | self.rnn = RNN( 59 | self.embedding_dim, 60 | hidden_size, 61 | num_layers=num_layers, 62 | dropout=dropout if num_layers > 1 else 0.0, 63 | batch_first=True, 64 | ) 65 | self.dropout = nn.Dropout(p=dropout) 66 | 67 | # Projection from RNN hidden size to output vocab 68 | self.outputs2vocab = nn.Linear(hidden_size, self.vocab_size) 69 | self.vocab = vocab 70 | # Get sos/eos/pad indices 71 | self.sos_index = sos_index 72 | self.eos_index = eos_index 73 | self.pad_index = pad_index 74 | self.rev_vocab = {v: k for k, v in vocab.items()} 75 | 76 | def forward(self, feats, seq, length): 77 | # feats is from example images 78 | batch_size = seq.size(0) 79 | 80 | if self.project_input: 81 | feats_h = self.proj_h(feats) 82 | if self.rnn_type == "lstm": 83 | feats_c = self.proj_c(feats) 84 | else: 85 | feats_h = feats 86 | feats_c = feats 87 | 88 | if batch_size > 1: 89 | sorted_lengths, sorted_idx = torch.sort(length, descending=True) 90 | seq = seq[sorted_idx] 91 | feats_h = feats_h[sorted_idx] 92 | if self.rnn_type == "lstm": 93 | feats_c = feats_c[sorted_idx] 94 | 95 | # Construct hidden states by expanding to number of layers 96 | feats_h = feats_h.unsqueeze(0).expand(self.num_layers, -1, -1).contiguous() 97 | if self.rnn_type == "lstm": 98 | feats_c = feats_c.unsqueeze(0).expand(self.num_layers, -1, -1).contiguous() 99 | hidden = (feats_h, feats_c) 100 | else: 101 | hidden = feats_h 102 | 103 | # embed your sequences 104 | embed_seq = self.embedding(seq) 105 | 106 | # shape = (seq_len, batch, hidden_dim) 107 | packed_input = rnn_utils.pack_padded_sequence( 108 | embed_seq, sorted_lengths, batch_first=True 109 | ) 110 | packed_output, _ = self.rnn(packed_input, hidden) 111 | output = rnn_utils.pad_packed_sequence(packed_output, batch_first=True) 112 | output = output[0].contiguous() 113 | 114 | if batch_size > 1: 115 | _, reversed_idx = torch.sort(sorted_idx) 116 | output = output[reversed_idx] 117 | 118 | max_length = output.size(1) 119 | output_2d = output.view(batch_size * max_length, self.hidden_size) 120 | output_2d_dropout = self.dropout(output_2d) 121 | outputs_2d = self.outputs2vocab(output_2d_dropout) 122 | outputs = outputs_2d.view(batch_size, max_length, self.vocab_size) 123 | 124 | return outputs 125 | 126 | def sample(self, feats, greedy=False, 
to_text=False): 127 | """Generate from image features using greedy search.""" 128 | with torch.no_grad(): 129 | if self.project_input: 130 | feats_h = self.proj_h(feats) 131 | states = feats_h 132 | if self.rnn_type == "lstm": 133 | feats_c = self.proj_c(feats) 134 | states = (feats_h, feats_c) 135 | else: 136 | states = feats 137 | 138 | batch_size = states.size(0) 139 | 140 | # initialize hidden states using image features 141 | states = states.unsqueeze(0) 142 | 143 | # first input is SOS token 144 | inputs = np.array([self.sos_index for _ in range(batch_size)]) 145 | inputs = torch.from_numpy(inputs) 146 | inputs = inputs.unsqueeze(1) 147 | inputs = inputs.to(feats.device) 148 | 149 | # save SOS as first generated token 150 | inputs_npy = inputs.squeeze(1).cpu().numpy() 151 | sampled_ids = [[w] for w in inputs_npy] 152 | 153 | # compute embeddings 154 | inputs = self.embedding(inputs) 155 | 156 | # Here, we use the same as max caption length 157 | for i in range(32): # like in jacobs repo 158 | outputs, states = self.rnn(inputs, states) # outputs: (L=1,B,H) 159 | outputs = outputs.squeeze(1) # outputs: (B,H) 160 | outputs = self.outputs2vocab(outputs) # outputs: (B,V) 161 | 162 | if greedy: 163 | predicted = outputs.max(1)[1] 164 | predicted = predicted.unsqueeze(1) 165 | else: 166 | outputs = F.softmax(outputs, dim=1) 167 | predicted = torch.multinomial(outputs, 1) 168 | 169 | predicted_npy = predicted.squeeze(1).cpu().numpy() 170 | predicted_lst = predicted_npy.tolist() 171 | 172 | for w, so_far in zip(predicted_lst, sampled_ids): 173 | if so_far[-1] != self.eos_index: 174 | so_far.append(w) 175 | 176 | inputs = predicted 177 | inputs = self.embedding(inputs) # inputs: (L=1,B,E) 178 | 179 | sampled_lengths = [len(text) for text in sampled_ids] 180 | sampled_lengths = np.array(sampled_lengths) 181 | 182 | max_length = max(sampled_lengths) 183 | padded_ids = np.ones((batch_size, max_length)) * self.pad_index 184 | 185 | for i in range(batch_size): 186 | padded_ids[i, : sampled_lengths[i]] = sampled_ids[i] 187 | 188 | sampled_lengths = torch.from_numpy(sampled_lengths).long() 189 | sampled_ids = torch.from_numpy(padded_ids).long() 190 | 191 | if to_text: 192 | sampled_text = self.to_text(sampled_ids) 193 | return sampled_text, sampled_lengths 194 | return sampled_ids, sampled_lengths 195 | 196 | def to_text(self, sampled_ids): 197 | texts = [] 198 | for sample in sampled_ids.numpy(): 199 | texts.append(" ".join([self.rev_vocab[v] for v in sample if v != 0])) 200 | return np.array(texts, dtype=np.unicode_) 201 | 202 | 203 | class TextRep(nn.Module): 204 | r"""Deterministic Bowman et. al. model to form 205 | text representation. 206 | 207 | Again, this uses 512 hidden dimensions. 
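forward() packs the padded captions by length, runs them through the RNN, and
returns the last layer's final hidden state as a fixed-size embedding for each
caption.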
208 | """ 209 | 210 | def __init__( 211 | self, embedding_module, hidden_size=512, rnn="gru", num_layers=1, dropout=0.2 212 | ): 213 | super(TextRep, self).__init__() 214 | self.embedding = embedding_module 215 | self.embedding_dim = embedding_module.embedding_dim 216 | if rnn == "gru": 217 | RNN = nn.GRU 218 | elif rnn == "lstm": 219 | RNN = nn.LSTM 220 | else: 221 | raise ValueError("Unknown RNN model {}".format(rnn)) 222 | self.rnn = RNN( 223 | self.embedding_dim, 224 | hidden_size, 225 | num_layers=num_layers, 226 | dropout=dropout if num_layers > 1 else 0.0, 227 | ) 228 | self.hidden_size = hidden_size 229 | 230 | def forward(self, seq, length): 231 | batch_size = seq.size(0) 232 | 233 | if batch_size > 1: 234 | sorted_lengths, sorted_idx = torch.sort(length, descending=True) 235 | seq = seq[sorted_idx] 236 | 237 | # reorder from (B,L,D) to (L,B,D) 238 | seq = seq.transpose(0, 1) 239 | 240 | # embed your sequences 241 | embed_seq = self.embedding(seq) 242 | 243 | packed = rnn_utils.pack_padded_sequence( 244 | embed_seq, 245 | sorted_lengths.data.tolist() if batch_size > 1 else length.data.tolist(), 246 | ) 247 | 248 | _, hidden = self.rnn(packed) 249 | hidden = hidden[-1, ...] 250 | 251 | if batch_size > 1: 252 | _, reversed_idx = torch.sort(sorted_idx) 253 | hidden = hidden[reversed_idx] 254 | 255 | return hidden 256 | -------------------------------------------------------------------------------- /birds/fewshot/run_cl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run everything on codalab. 3 | """ 4 | 5 | import json 6 | import os 7 | from subprocess import check_call 8 | 9 | 10 | if __name__ == "__main__": 11 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 12 | 13 | parser = ArgumentParser( 14 | description="Run everything on codalab", 15 | formatter_class=ArgumentDefaultsHelpFormatter, 16 | ) 17 | 18 | cl_parser = parser.add_argument_group( 19 | "Codalab args", "args to control high level codalab eval" 20 | ) 21 | cl_parser.add_argument( 22 | "--no_train", action="store_true", help="Don't run the train command" 23 | ) 24 | cl_parser.add_argument( 25 | "--log_dir", default="./test/", help="Where to save metrics/models" 26 | ) 27 | cl_parser.add_argument("--n", default=1, type=int, help="Number of runs") 28 | 29 | fparser = parser.add_argument_group( 30 | "Few shot args", "args to pass to few shot scripts" 31 | ) 32 | fparser.add_argument("--model", default="Conv4") 33 | fparser.add_argument("--lsl", action="store_true") 34 | fparser.add_argument("--l3", action="store_true") 35 | fparser.add_argument("--l3_n_infer", type=int, default=10) 36 | fparser.add_argument("--rnn_type", choices=["gru", "lstm"], default="gru") 37 | fparser.add_argument("--rnn_num_layers", default=1, type=int) 38 | fparser.add_argument("--rnn_dropout", default=0.0, type=float) 39 | fparser.add_argument( 40 | "--language_filter", default="all", choices=["all", "color", "nocolor"] 41 | ) 42 | fparser.add_argument( 43 | "--lang_supervision", default="instance", choices=["instance", "class"] 44 | ) 45 | fparser.add_argument("--glove_init", action="store_true") 46 | fparser.add_argument("--freeze_emb", action="store_true") 47 | fparser.add_argument("--scramble_lang", action="store_true") 48 | fparser.add_argument("--sample_class_lang", action="store_true") 49 | fparser.add_argument("--scramble_all", action="store_true") 50 | fparser.add_argument("--shuffle_lang", action="store_true") 51 | fparser.add_argument("--scramble_lang_class", 
action="store_true") 52 | fparser.add_argument("--n_caption", choices=list(range(1, 11)), type=int, default=1) 53 | fparser.add_argument("--max_class", type=int, default=None) 54 | fparser.add_argument("--max_img_per_class", type=int, default=None) 55 | fparser.add_argument("--max_lang_per_class", type=int, default=None) 56 | fparser.add_argument("--lang_lambda", type=float, default=0.25) 57 | fparser.add_argument( 58 | "--save_freq", type=int, default=10000 59 | ) # In CL script, by default, never save, just keep best model 60 | fparser.add_argument("--lang_emb_size", type=int, default=300) 61 | fparser.add_argument("--lang_hidden_size", type=int, default=200) 62 | fparser.add_argument("--lr", type=float, default=1e-3) 63 | fparser.add_argument("--rnn_lr_scale", default=1.0, type=float) 64 | fparser.add_argument( 65 | "--optimizer", default="adam", choices=["adam", "amsgrad", "rmsprop"] 66 | ) 67 | fparser.add_argument("--n_way", type=int, default=5) 68 | fparser.add_argument( 69 | "--test_n_way", 70 | type=int, 71 | default=None, 72 | help="Specify to change n_way eval at test", 73 | ) 74 | fparser.add_argument("--n_shot", type=int, default=1) 75 | fparser.add_argument("--epochs", type=int, default=600) 76 | fparser.add_argument("--n_workers", type=int, default=4) 77 | fparser.add_argument("--resume", action="store_true") 78 | fparser.add_argument("--debug", action="store_true") 79 | fparser.add_argument("--seed", default=None, type=int) 80 | 81 | args = parser.parse_args() 82 | 83 | if args.test_n_way is None: 84 | args.test_n_way = args.n_way 85 | 86 | args.cl_dir = os.path.join(args.log_dir, "checkpoints") 87 | args.cl_record_file = os.path.join(args.log_dir, "results_novel.json") 88 | args.cl_args_file = os.path.join(args.log_dir, "args.json") 89 | 90 | os.makedirs(args.log_dir, exist_ok=True) 91 | if os.path.exists(args.cl_record_file): 92 | os.remove(args.cl_record_file) 93 | 94 | # Save arg metadata to root directory 95 | # Only save if training a model 96 | print("==== RUN_CL: PARAMS ====") 97 | argsv = vars(args) 98 | print(argsv) 99 | if not args.no_train: 100 | with open(args.cl_args_file, "w") as fout: 101 | json.dump(argsv, fout, sort_keys=True, indent=4, separators=(",", ": ")) 102 | 103 | # Train 104 | for i in range(1, args.n + 1): 105 | if not args.no_train: 106 | print("==== RUN_CL ({}/{}): TRAIN ====".format(i, args.n)) 107 | train_cmd = [ 108 | "python3.7", 109 | "fewshot/train.py", 110 | "--model", 111 | args.model, 112 | "--n_shot", 113 | args.n_shot, 114 | "--train_n_way", 115 | args.n_way, 116 | "--test_n_way", 117 | args.test_n_way, 118 | "--stop_epoch", 119 | args.epochs, 120 | "--rnn_type", 121 | args.rnn_type, 122 | "--rnn_num_layers", 123 | args.rnn_num_layers, 124 | "--rnn_dropout", 125 | args.rnn_dropout, 126 | "--language_filter", 127 | args.language_filter, 128 | "--lang_lambda", 129 | args.lang_lambda, 130 | "--lang_hidden_size", 131 | args.lang_hidden_size, 132 | "--lang_supervision", 133 | args.lang_supervision, 134 | "--lang_emb_size", 135 | args.lang_emb_size, 136 | "--n_caption", 137 | args.n_caption, 138 | "--stop_epoch", 139 | args.epochs, 140 | "--checkpoint_dir", 141 | args.cl_dir, 142 | "--save_freq", 143 | args.save_freq, 144 | "--n", 145 | i, 146 | "--lr", 147 | args.lr, 148 | "--rnn_lr_scale", 149 | args.rnn_lr_scale, 150 | "--optimizer", 151 | args.optimizer, 152 | "--n_workers", 153 | args.n_workers, 154 | "--l3_n_infer", 155 | args.l3_n_infer, 156 | ] 157 | if args.seed is not None: 158 | train_cmd.extend(["--seed", args.seed]) 159 | if 
args.max_class is not None: 160 | train_cmd.extend(["--max_class", args.max_class]) 161 | if args.max_img_per_class is not None: 162 | train_cmd.extend(["--max_img_per_class", args.max_img_per_class]) 163 | if args.max_lang_per_class is not None: 164 | train_cmd.extend(["--max_lang_per_class", args.max_lang_per_class]) 165 | if args.lsl: 166 | train_cmd.append("--lsl") 167 | if args.l3: 168 | train_cmd.append("--l3") 169 | if args.glove_init: 170 | train_cmd.append("--glove_init") 171 | if args.freeze_emb: 172 | train_cmd.append("--freeze_emb") 173 | if args.shuffle_lang: 174 | train_cmd.append("--shuffle_lang") 175 | if args.scramble_lang: 176 | train_cmd.append("--scramble_lang") 177 | if args.sample_class_lang: 178 | train_cmd.append("--sample_class_lang") 179 | if args.scramble_all: 180 | train_cmd.append("--scramble_all") 181 | if args.scramble_lang_class: 182 | train_cmd.append("--scramble_lang_class") 183 | if args.resume: 184 | train_cmd.append("--resume") 185 | if args.debug: 186 | train_cmd.append("--debug") 187 | train_cmd = [str(x) for x in train_cmd] 188 | check_call(train_cmd) 189 | 190 | print("==== RUN_CL ({}/{}): TEST NOVEL ====".format(i, args.n)) 191 | test_cmd = [ 192 | "python3.7", 193 | "fewshot/test.py", 194 | "--model", 195 | args.model, 196 | "--n_shot", 197 | args.n_shot, 198 | "--test_n_way", 199 | args.test_n_way, 200 | "--rnn_type", 201 | args.rnn_type, 202 | "--rnn_num_layers", 203 | args.rnn_num_layers, 204 | "--rnn_dropout", 205 | args.rnn_dropout, 206 | "--language_filter", 207 | args.language_filter, 208 | "--lang_lambda", 209 | args.lang_lambda, 210 | "--lang_hidden_size", 211 | args.lang_hidden_size, 212 | "--lang_supervision", 213 | args.lang_supervision, 214 | "--lang_emb_size", 215 | args.lang_emb_size, 216 | "--checkpoint_dir", 217 | args.cl_dir, 218 | "--split", 219 | "novel", 220 | "--n_workers", 221 | args.n_workers, 222 | "--record_file", 223 | args.cl_record_file, 224 | "--l3_n_infer", 225 | args.l3_n_infer, 226 | ] 227 | if args.seed is not None: 228 | test_cmd.extend(["--seed", args.seed]) 229 | if args.lsl: 230 | test_cmd.append("--lsl") 231 | if args.l3: 232 | test_cmd.append("--l3") 233 | if args.debug: 234 | test_cmd.append("--debug") 235 | test_cmd = [str(x) for x in test_cmd] 236 | check_call(test_cmd) 237 | -------------------------------------------------------------------------------- /birds/fewshot/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test script. 
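Loads the best (or a specific saved-epoch) checkpoint, evaluates episodic
few-shot accuracy over 600 test episodes, and appends the result as a JSON
line to --record_file; with --save_embeddings it instead dumps the language
embedding matrix and exits.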
3 | """ 4 | 5 | import json 6 | import os 7 | import sys 8 | import time 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim 14 | import torch.utils.data.sampler 15 | 16 | import constants 17 | from data import lang_utils 18 | from data.datamgr import SetDataManager, TransformLoader 19 | from io_utils import get_assigned_file, get_best_file, model_dict, parse_args 20 | from models.language import TextProposal, TextRep 21 | from models.protonet import ProtoNet 22 | 23 | 24 | if __name__ == "__main__": 25 | args = parse_args("test") 26 | 27 | if args.seed is not None: 28 | torch.manual_seed(args.seed) 29 | 30 | acc_all = [] 31 | 32 | vocab = lang_utils.load_vocab(constants.LANG_DIR) 33 | 34 | l3_model = None 35 | lang_model = None 36 | if args.lsl or args.l3: 37 | embedding_model = nn.Embedding(len(vocab), args.lang_emb_size) 38 | lang_model = TextProposal( 39 | embedding_model, 40 | input_size=1600, 41 | hidden_size=args.lang_hidden_size, 42 | project_input=1600 != args.lang_hidden_size, 43 | rnn=args.rnn_type, 44 | num_layers=args.rnn_num_layers, 45 | dropout=args.rnn_dropout, 46 | vocab=vocab, 47 | **lang_utils.get_special_indices(vocab), 48 | ) 49 | 50 | if args.l3: 51 | l3_model = TextRep( 52 | embedding_model, 53 | hidden_size=args.lang_hidden_size, 54 | rnn=args.rnn_type, 55 | num_layers=args.rnn_num_layers, 56 | dropout=args.rnn_dropout, 57 | ) 58 | l3_model = l3_model.cuda() 59 | 60 | embedding_model = embedding_model.cuda() 61 | lang_model = lang_model.cuda() 62 | 63 | model = ProtoNet( 64 | model_dict[args.model], 65 | n_way=args.test_n_way, 66 | n_support=args.n_shot, 67 | # Language options 68 | lsl=args.lsl, 69 | language_model=lang_model, 70 | lang_supervision=args.lang_supervision, 71 | l3=args.l3, 72 | l3_model=l3_model, 73 | l3_n_infer=args.l3_n_infer, 74 | ) 75 | 76 | model = model.cuda() 77 | 78 | if args.save_iter != -1: 79 | modelfile = get_assigned_file(args.checkpoint_dir, args.save_iter) 80 | else: 81 | modelfile = get_best_file(args.checkpoint_dir) 82 | 83 | if modelfile is not None: 84 | tmp = torch.load(modelfile) 85 | model.load_state_dict( 86 | tmp["state"], 87 | # If language was used for pretraining, ignore 88 | # the language model component here. 
If we want to use language, 89 | # make sure the model is loaded 90 | strict=args.lsl, 91 | ) 92 | 93 | if args.save_embeddings: 94 | if args.lsl: 95 | weights = model.language_model.embedding.weight.detach().cpu().numpy() 96 | vocab_srt = sorted(list(vocab.items()), key=lambda x: x[1]) 97 | vocab_srt = [v[0] for v in vocab_srt] 98 | with open(args.embeddings_file, "w") as fout: 99 | fout.write("\n".join(vocab_srt)) 100 | fout.write("\n") 101 | np.savetxt(args.embeddings_metadata, weights, fmt="%f", delimiter="\t") 102 | sys.exit(0) 103 | 104 | # Run the test loop for 600 iterations 105 | ITER_NUM = 600 106 | N_QUERY = 15 107 | 108 | test_datamgr = SetDataManager( 109 | "CUB", 110 | 84, 111 | n_query=N_QUERY, 112 | n_way=args.test_n_way, 113 | n_support=args.n_shot, 114 | n_episode=ITER_NUM, 115 | args=args, 116 | ) 117 | test_loader = test_datamgr.get_data_loader( 118 | os.path.join(constants.DATA_DIR, f"{args.split}.json"), 119 | aug=False, 120 | lang_dir=constants.LANG_DIR, 121 | normalize=False, 122 | vocab=vocab, 123 | ) 124 | normalizer = TransformLoader(84).get_normalize() 125 | 126 | model.eval() 127 | 128 | acc_all = model.test_loop( 129 | test_loader, 130 | normalizer=normalizer, 131 | verbose=True, 132 | return_all=True, 133 | # Debug on first loop only 134 | debug=args.debug, 135 | debug_dir=os.path.split(args.checkpoint_dir)[0], 136 | ) 137 | acc_mean = np.mean(acc_all) 138 | acc_std = np.std(acc_all) 139 | print( 140 | "%d Test Acc = %4.2f%% +- %4.2f%%" 141 | % (ITER_NUM, acc_mean, 1.96 * acc_std / np.sqrt(ITER_NUM)) 142 | ) 143 | 144 | with open(args.record_file, "a") as f: 145 | timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime()) 146 | acc_ci = 1.96 * acc_std / np.sqrt(ITER_NUM) 147 | f.write( 148 | json.dumps( 149 | { 150 | "time": timestamp, 151 | "split": args.split, 152 | "setting": args.checkpoint_dir, 153 | "iter_num": ITER_NUM, 154 | "acc": acc_mean, 155 | "acc_ci": acc_ci, 156 | "acc_all": list(acc_all), 157 | "acc_std": acc_std, 158 | }, 159 | sort_keys=True, 160 | ) 161 | ) 162 | f.write("\n") 163 | -------------------------------------------------------------------------------- /birds/fewshot/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train script. 3 | """ 4 | 5 | import json 6 | import os 7 | from collections import defaultdict 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim 12 | from tqdm import tqdm 13 | 14 | import constants 15 | from data import lang_utils 16 | from data.datamgr import SetDataManager 17 | from io_utils import get_resume_file, model_dict, parse_args 18 | from models.language import TextProposal, TextRep 19 | from models.protonet import ProtoNet 20 | 21 | 22 | def get_optimizer(model, args): 23 | """ 24 | Get the optimizer for the model based on arguments. Specifically, if 25 | needed, we split up training into (1) main parameters, (2) RNN-specific 26 | parameters, with different learning rates if specified. 
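Parameters whose names start with "language_model." form the RNN group and use
lr = rnn_lr_scale * lr; all other trainable parameters use the base learning
rate.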
27 | 28 | :param model: nn.Module to train 29 | :param args: argparse.Namespace - other args passed to the script 30 | 31 | :returns: a torch.optim.Optimizer 32 | """ 33 | # Get params 34 | main_params = {"params": []} 35 | rnn_params = {"params": [], "lr": args.rnn_lr_scale * args.lr} 36 | for name, param in model.named_parameters(): 37 | if not param.requires_grad: 38 | continue 39 | if name.startswith("language_model."): 40 | # Scale RNN learning rate 41 | rnn_params["params"].append(param) 42 | else: 43 | main_params["params"].append(param) 44 | if args.lsl and not rnn_params["params"]: 45 | print("Warning: --lsl is set but no RNN parameters found") 46 | params_to_optimize = [main_params, rnn_params] 47 | 48 | # Define optimizer 49 | if args.optimizer == "adam": 50 | optimizer = torch.optim.Adam(params_to_optimize, lr=args.lr) 51 | elif args.optimizer == "amsgrad": 52 | optimizer = torch.optim.Adam(params_to_optimize, lr=args.lr, amsgrad=True) 53 | elif args.optimizer == "rmsprop": 54 | optimizer = torch.optim.RMSprop(params_to_optimize, lr=args.lr) 55 | else: 56 | raise NotImplementedError("optimizer = {}".format(args.optimizer)) 57 | return optimizer 58 | 59 | 60 | def train( 61 | base_loader, 62 | val_loader, 63 | model, 64 | start_epoch, 65 | stop_epoch, 66 | args, 67 | metrics_fname="metrics.json", 68 | ): 69 | """ 70 | Main training script. 71 | 72 | :param base_loader: torch.utils.data.DataLoader for training set, generated 73 | by data.datamgr.SetDataManager 74 | :param val_loader: torch.utils.data.DataLoader for validation set, 75 | generated by data.datamgr.SetDataManager 76 | :param model: nn.Module to train 77 | :param start_epoch: which epoch we started at 78 | :param stop_epoch: which epoch to end at 79 | :param args: other arguments passed to the script 80 | "param metrics_fname": where to save metrics 81 | """ 82 | optimizer = get_optimizer(model, args) 83 | 84 | max_val_acc = 0 85 | best_epoch = 0 86 | 87 | val_accs = [] 88 | val_losses = [] 89 | all_metrics = defaultdict(list) 90 | for epoch in tqdm( 91 | range(start_epoch, stop_epoch), total=stop_epoch - start_epoch, desc="Train" 92 | ): 93 | model.train() 94 | metric = model.train_loop(epoch, base_loader, optimizer, args) 95 | for m, val in metric.items(): 96 | all_metrics[m].append(val) 97 | model.eval() 98 | 99 | os.makedirs(args.checkpoint_dir, exist_ok=True) 100 | 101 | val_acc, val_loss = model.test_loop(val_loader,) 102 | val_accs.append(val_acc) 103 | val_losses.append(val_loss) 104 | if val_acc > max_val_acc: 105 | best_epoch = epoch 106 | tqdm.write("best model! 
save...") 107 | max_val_acc = val_acc 108 | outfile = os.path.join(args.checkpoint_dir, "best_model.tar") 109 | torch.save({"epoch": epoch, "state": model.state_dict()}, outfile) 110 | 111 | if epoch and (epoch % args.save_freq == 0) or (epoch == stop_epoch - 1): 112 | outfile = os.path.join(args.checkpoint_dir, "{:d}.tar".format(epoch)) 113 | torch.save({"epoch": epoch, "state": model.state_dict()}, outfile) 114 | tqdm.write("") 115 | 116 | # Save metrics 117 | metrics = { 118 | "train_acc": all_metrics["train_acc"], 119 | "current_train_acc": all_metrics["train_acc"][-1], 120 | "train_loss": all_metrics["train_loss"], 121 | "current_train_loss": all_metrics["train_loss"][-1], 122 | "cls_loss": all_metrics["cls_loss"], 123 | "current_cls_loss": all_metrics["cls_loss"][-1], 124 | "lang_loss": all_metrics["lang_loss"], 125 | "current_lang_loss": all_metrics["lang_loss"][-1], 126 | "current_epoch": epoch, 127 | "val_acc": val_accs, 128 | "val_loss": val_losses, 129 | "current_val_loss": val_losses[-1], 130 | "current_val_acc": val_acc, 131 | "best_epoch": best_epoch, 132 | "best_val_acc": max_val_acc, 133 | } 134 | with open(os.path.join(args.checkpoint_dir, metrics_fname), "w") as fout: 135 | json.dump(metrics, fout, sort_keys=True, indent=4, separators=(",", ": ")) 136 | 137 | # Save a copy to current metrics too 138 | if ( 139 | metrics_fname != "metrics.json" 140 | and metrics_fname.startswith("metrics_") 141 | and metrics_fname.endswith(".json") 142 | ): 143 | metrics["n"] = int(metrics_fname[8]) 144 | with open(os.path.join(args.checkpoint_dir, "metrics.json"), "w") as fout: 145 | json.dump( 146 | metrics, fout, sort_keys=True, indent=4, separators=(",", ": ") 147 | ) 148 | 149 | # If didn't train, save model anyways 150 | if stop_epoch == 0: 151 | outfile = os.path.join(args.checkpoint_dir, "best_model.tar") 152 | torch.save({"epoch": stop_epoch, "state": model.state_dict()}, outfile) 153 | 154 | 155 | if __name__ == "__main__": 156 | args = parse_args("train") 157 | 158 | if args.seed is not None: 159 | torch.manual_seed(args.seed) 160 | # I don't seed the np rng since dataset loading uses multiprocessing with 161 | # random choices. 
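# (forked DataLoader workers inherit the same numpy RNG state, so per-worker
# re-seeding would be required for determinism; see:)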
162 | # https://github.com/numpy/numpy/issues/9650 163 | # Unavoidable undeterminism here for now 164 | 165 | base_file = os.path.join(constants.DATA_DIR, "base.json") 166 | val_file = os.path.join(constants.DATA_DIR, "val.json") 167 | 168 | # Load language 169 | vocab = lang_utils.load_vocab(constants.LANG_DIR) 170 | 171 | l3_model = None 172 | lang_model = None 173 | if args.lsl or args.l3: 174 | if args.glove_init: 175 | vecs = lang_utils.glove_init(vocab, emb_size=args.lang_emb_size) 176 | embedding_model = nn.Embedding( 177 | len(vocab), args.lang_emb_size, _weight=vecs if args.glove_init else None 178 | ) 179 | if args.freeze_emb: 180 | embedding_model.weight.requires_grad = False 181 | 182 | lang_input_size = 1600 183 | lang_model = TextProposal( 184 | embedding_model, 185 | input_size=lang_input_size, 186 | hidden_size=args.lang_hidden_size, 187 | project_input=lang_input_size != args.lang_hidden_size, 188 | rnn=args.rnn_type, 189 | num_layers=args.rnn_num_layers, 190 | dropout=args.rnn_dropout, 191 | vocab=vocab, 192 | **lang_utils.get_special_indices(vocab) 193 | ) 194 | 195 | if args.l3: 196 | l3_model = TextRep( 197 | embedding_model, 198 | hidden_size=args.lang_hidden_size, 199 | rnn=args.rnn_type, 200 | num_layers=args.rnn_num_layers, 201 | dropout=args.rnn_dropout, 202 | ) 203 | l3_model = l3_model.cuda() 204 | 205 | embedding_model = embedding_model.cuda() 206 | lang_model = lang_model.cuda() 207 | 208 | # if test_n_way is smaller than train_n_way, reduce n_query to keep batch 209 | # size small 210 | n_query = max(1, int(16 * args.test_n_way / args.train_n_way)) 211 | 212 | train_few_shot_args = dict(n_way=args.train_n_way, n_support=args.n_shot) 213 | base_datamgr = SetDataManager( 214 | "CUB", 84, n_query=n_query, **train_few_shot_args, args=args 215 | ) 216 | print("Loading train data") 217 | 218 | base_loader = base_datamgr.get_data_loader( 219 | base_file, 220 | aug=True, 221 | lang_dir=constants.LANG_DIR, 222 | normalize=True, 223 | vocab=vocab, 224 | # Maximum training data restrictions only apply at train time 225 | max_class=args.max_class, 226 | max_img_per_class=args.max_img_per_class, 227 | max_lang_per_class=args.max_lang_per_class, 228 | ) 229 | 230 | val_datamgr = SetDataManager( 231 | "CUB", 232 | 84, 233 | n_query=n_query, 234 | n_way=args.test_n_way, 235 | n_support=args.n_shot, 236 | args=args, 237 | ) 238 | print("Loading val data\n") 239 | val_loader = val_datamgr.get_data_loader( 240 | val_file, aug=False, lang_dir=constants.LANG_DIR, normalize=True, vocab=vocab, 241 | ) 242 | # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor 243 | 244 | model = ProtoNet( 245 | model_dict[args.model], 246 | **train_few_shot_args, 247 | # Language options 248 | lsl=args.lsl, 249 | language_model=lang_model, 250 | lang_supervision=args.lang_supervision, 251 | l3=args.l3, 252 | l3_model=l3_model, 253 | l3_n_infer=args.l3_n_infer 254 | ) 255 | 256 | model = model.cuda() 257 | 258 | os.makedirs(args.checkpoint_dir, exist_ok=True) 259 | 260 | start_epoch = args.start_epoch 261 | stop_epoch = args.stop_epoch 262 | 263 | if args.resume: 264 | resume_file = get_resume_file(args.checkpoint_dir) 265 | if resume_file is not None: 266 | tmp = torch.load(resume_file) 267 | start_epoch = tmp["epoch"] + 1 268 | model.load_state_dict(tmp["state"]) 269 | 270 | metrics_fname = "metrics_{}.json".format(args.n) 271 | 272 | train( 273 | base_loader, 274 | val_loader, 275 | model, 276 | start_epoch, 277 | stop_epoch, 278 | args, 279 | 
metrics_fname=metrics_fname, 280 | ) 281 | -------------------------------------------------------------------------------- /birds/filelists/CUB/download_CUB.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz 3 | tar -zxvf CUB_200_2011.tgz 4 | python write_CUB_filelist.py 5 | -------------------------------------------------------------------------------- /birds/filelists/CUB/save_np.py: -------------------------------------------------------------------------------- 1 | """ 2 | For each class, load images and save as numpy arrays. 3 | """ 4 | 5 | import os 6 | 7 | import numpy as np 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | if __name__ == "__main__": 12 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 13 | 14 | parser = ArgumentParser( 15 | description="Save numpy", formatter_class=ArgumentDefaultsHelpFormatter 16 | ) 17 | 18 | parser.add_argument( 19 | "--cub_dir", default="CUB_200_2011/images", help="Directory to load/cache" 20 | ) 21 | parser.add_argument( 22 | "--original_cub_dir", 23 | default="CUB_200_2011/images", 24 | help="Original CUB directory if you want the image keys to be different (in case --cub_dir has changed)", 25 | ) 26 | parser.add_argument("--filelist_prefix", default="./filelists/CUB/") 27 | 28 | args = parser.parse_args() 29 | 30 | for bird_class in tqdm(os.listdir(args.cub_dir), desc="Classes"): 31 | bird_imgs_np = {} 32 | class_dir = os.path.join(args.cub_dir, bird_class) 33 | bird_imgs = sorted([x for x in os.listdir(class_dir) if x != "img.npz"]) 34 | for bird_img in bird_imgs: 35 | bird_img_fname = os.path.join(class_dir, bird_img) 36 | img = Image.open(bird_img_fname).convert("RGB") 37 | img_np = np.asarray(img) 38 | 39 | full_bird_img_fname = os.path.join( 40 | args.filelist_prefix, args.original_cub_dir, bird_class, bird_img 41 | ) 42 | 43 | bird_imgs_np[full_bird_img_fname] = img_np 44 | 45 | np_fname = os.path.join(class_dir, "img.npz") 46 | np.savez_compressed(np_fname, **bird_imgs_np) 47 | -------------------------------------------------------------------------------- /birds/filelists/CUB/write_CUB_filelist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os import listdir 3 | from os.path import isfile, isdir, join 4 | import os 5 | import json 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | if __name__ == '__main__': 11 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 12 | 13 | parser = ArgumentParser() 14 | 15 | parser.add_argument('--seed', type=int, default=0, help='Random seed') 16 | parser.add_argument('--savedir', type=str, default='../../custom_filelists/CUB/', 17 | help='Directory to save filelists') 18 | 19 | args = parser.parse_args() 20 | 21 | random = np.random.RandomState(args.seed) 22 | 23 | filelist_path = './filelists/CUB/' 24 | data_path = 'CUB_200_2011/images' 25 | dataset_list = ['base', 'val', 'novel'] 26 | 27 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))] 28 | folder_list.sort() 29 | label_dict = dict(zip(folder_list, range(0, len(folder_list)))) 30 | 31 | classfile_list_all = [] 32 | 33 | # Load attributes 34 | attrs = pd.read_csv('./CUB_200_2011/attributes/image_attribute_labels.txt', 35 | sep=' ', 36 | header=None, 37 | names=['image_id', 'attribute_id', 'is_present', 'certainty_id', 'time']) 38 | # Zero out attributes 
with certainty < 3 39 | attrs['is_present'] = np.where(attrs['certainty_id'] < 3, 0, attrs['is_present']) 40 | # Get image names 41 | image_names = pd.read_csv('./CUB_200_2011/images.txt', sep=' ', 42 | header=None, 43 | names=['image_id', 'image_name']) 44 | attrs = attrs.merge(image_names, on='image_id') 45 | attrs['is_present'] = attrs['is_present'].astype(str) 46 | attrs = attrs.groupby('image_name')['is_present'].apply(lambda col: ''.join(col)) 47 | attrs = dict(zip(attrs.index, attrs)) 48 | attrs = {os.path.basename(k): v for k, v in attrs.items()} 49 | 50 | for i, folder in enumerate(folder_list): 51 | folder_path = join(data_path, folder) 52 | classfile_list_all.append([ 53 | join(filelist_path, folder_path, cf) for cf in listdir(folder_path) 54 | if (isfile(join(folder_path, cf)) and cf[0] != '.' and not cf.endswith('.npz')) 55 | ]) 56 | random.shuffle(classfile_list_all[i]) 57 | 58 | for dataset in dataset_list: 59 | file_list = [] 60 | label_list = [] 61 | for i, classfile_list in enumerate(classfile_list_all): 62 | if 'base' in dataset: 63 | if (i % 2 == 0): 64 | file_list.extend(classfile_list) 65 | label_list.extend(np.repeat( 66 | i, len(classfile_list)).tolist()) 67 | if 'val' in dataset: 68 | if (i % 4 == 1): 69 | file_list.extend(classfile_list) 70 | label_list.extend(np.repeat( 71 | i, len(classfile_list)).tolist()) 72 | if 'novel' in dataset: 73 | if (i % 4 == 3): 74 | file_list.extend(classfile_list) 75 | label_list.extend(np.repeat( 76 | i, len(classfile_list)).tolist()) 77 | 78 | # Get attributes 79 | attribute_list = [ 80 | attrs[os.path.basename(f)] for f in file_list if not f.endswith('.npz') 81 | ] 82 | 83 | djson = { 84 | 'label_names': folder_list, 85 | 'image_names': file_list, 86 | 'image_labels': label_list, 87 | 'image_attributes': attribute_list, 88 | } 89 | 90 | os.makedirs(args.savedir, exist_ok=True) 91 | with open(os.path.join(args.savedir, dataset + '.json'), 'w') as fout: 92 | json.dump(djson, fout) 93 | 94 | print("%s -OK" % dataset) 95 | -------------------------------------------------------------------------------- /birds/run_l3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/l3 --l3 --glove_init --lang_lambda 5 --max_lang_per_class 20 --sample_class_lang 4 | -------------------------------------------------------------------------------- /birds/run_lang_ablation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Color 4 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_ablation/lsl_color --lsl --glove_init --lang_lambda 5 --language_filter color --max_lang_per_class 20 --sample_class_lang 5 | 6 | # Nocolor 7 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_ablation/lsl_nocolor --lsl --glove_init --lang_lambda 5 --language_filter nocolor --max_lang_per_class 20 --sample_class_lang 8 | 9 | # Shuffled words 10 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_ablation/lsl_shuffled_words --lsl --glove_init --lang_lambda 5 --shuffle_lang --max_lang_per_class 20 --sample_class_lang 11 | 12 | # Shuffled captions 13 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_ablation/lsl_shuffled_captions --lsl --glove_init --lang_lambda 5 --scramble_all --max_lang_per_class 20 --sample_class_lang 14 | -------------------------------------------------------------------------------- /birds/run_lang_amount.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for max_lang_per_class in 1 5 10 20 30 40 50 60; do 4 | # LSL 5 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_amount/lsl_max_lang_$max_lang_per_class --lsl --glove_init --lang_lambda 5 --max_lang_per_class $max_lang_per_class --sample_class_lang 6 | 7 | # L3 8 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_amount/l3_max_lang_$max_lang_per_class --l3 --glove_init --max_lang_per_class $max_lang_per_class --sample_class_lang 9 | done 10 | -------------------------------------------------------------------------------- /birds/run_lsl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Standard LSL 4 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/lsl --lsl --glove_init --lang_lambda 5 --max_lang_per_class 20 --sample_class_lang 5 | -------------------------------------------------------------------------------- /birds/run_meta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/meta 4 | -------------------------------------------------------------------------------- /shapeworld/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | notebooks/ 107 | tmp/ 108 | .vscode 109 | viz/ 110 | 111 | # History files 112 | .Rhistory 113 | .Rapp.history 114 | 115 | # Session Data files 116 | .RData 117 | 118 | # Example code in package build process 119 | *-Ex.R 120 | 121 | # Output files from R CMD build 122 | /*.tar.gz 123 | 124 | # Output files from R CMD check 125 | /*.Rcheck/ 126 | 127 | # RStudio files 128 | .Rproj.user/ 129 | 130 | # produced vignettes 131 | vignettes/*.html 132 | vignettes/*.pdf 133 | 134 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 135 | .httr-oauth 136 | 137 | # knitr and R markdown default cache directories 138 | /*_cache/ 139 | /cache/ 140 | # R markdown files directories 141 | /*_files/ 142 | 143 | # Temporary files created by R markdown 144 | *.utf8.md 145 | *.knit.md 146 | .Rproj.user 147 | 148 | *.nb.html 149 | notebooks/ 150 | 151 | exp/* 152 | -------------------------------------------------------------------------------- /shapeworld/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Mike Wu, Jesse Mu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /shapeworld/README.md: -------------------------------------------------------------------------------- 1 | # LSL - ShapeWorld experiments 2 | 3 | This code is graciously adapted from code written by [Mike Wu](https://www.mikehwu.com/). 4 | 5 | ## Dependencies 6 | 7 | Tested with Python 3.7.4, torch 1.3.0, torchvision 0.4.1, sklearn 0.21.3, and numpy 1.17.2. 8 | 9 | ## Data 10 | 11 | Download data [here](http://nlp.stanford.edu/data/muj/shapeworld_4k.tar.gz) 12 | (~850 MB). Untar, and set `DATA_DIR` in `datasets.py` to 13 | point to the folder *containing* the ShapeWorld folder you just unzipped. 14 | 15 | This code works with Jacob Andreas' [original ShapeWorld data 16 | files](http://people.eecs.berkeley.edu/~jda/data/shapeworld.tar.gz) if you replace 17 | every `.npz` file with `.npy` in `datasets.py` and remove the `['arr_0']` indexing after each `np.load`. 18 | Results are similar, but with higher variance on test accuracies. 19 | 20 | For more details on the dataset (and how to reproduce it), check 21 | [jacobandreas/l3](https://github.com/jacobandreas/l3) and the accompanying 22 | [paper](https://arxiv.org/abs/1711.00482). 23 | 24 | ## Running 25 | 26 | The models can be run with the scripts in this directory: 27 | 28 | - `run_l3.sh` - L3 29 | - `run_lsl.sh` - LSL (ours) 30 | - `run_lsl_img.sh` - LSL, but decoding captions from the image embeddings 31 | instead of the concept (not reported) 32 | - `run_meta.sh` - meta-learning baseline 33 | - `run_lang_ablation.sh` - language ablation studies 34 | 35 | They will output results in the `exp/` directory (paper runs are already present there). 36 | 37 | To change the backbone, use `--backbone conv4` or `--backbone ResNet18`. ResNet18 may need a reduced batch size (we use batch size 32). 38 | 39 | ## Analysis 40 | 41 | `analysis/metrics.Rmd` contains `R` code for reproducing the plots in the 42 | paper. 43 | -------------------------------------------------------------------------------- /shapeworld/analysis/analysis.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | -------------------------------------------------------------------------------- /shapeworld/analysis/metrics.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Metrics analysis" 3 | output: html_notebook 4 | --- 5 | 6 | ```{r setup} 7 | library(tidyverse) 8 | library(cowplot) 9 | library(jsonlite) 10 | theme_set(theme_cowplot()) 11 | ``` 12 | 13 | # Standard eval 14 | 15 | ```{r} 16 | MODELS <- c( 17 | 'l3' = 'L3', 18 | 'lsl' = 'LSL', 19 | 'meta' = 'Meta' 20 | ) 21 | ``` 22 | 23 | ```{r} 24 | metrics <- sapply(names(MODELS), function(model) { 25 | metrics_file <- paste0('../exp/', model, '/metrics.json') 26 | metrics_df <- as.data.frame(read_json(metrics_file, simplifyVector = TRUE)) %>% 27 | tbl_df %>% 28 | select(train_acc, val_acc, val_same_acc, test_acc, test_same_acc) %>% 29 | mutate(avg_val_acc = (val_acc + val_same_acc) / 2, 30 | avg_test_acc = (test_acc + test_same_acc) / 2) 31 | metrics_df %>% 32 | mutate(epoch = 1:nrow(metrics_df)) %>% 33 | mutate(model = MODELS[model]) 34 | }, simplify = FALSE) %>% 35 | do.call(rbind, .)
%>% 36 | mutate(model = factor(model)) 37 | 38 | metrics_long <- metrics %>% 39 | gather('metric', 'value', -epoch, -model) %>% 40 | mutate(metric = factor(metric, levels = c('train_acc', 'avg_val_acc', 'val_acc', 'val_same_acc', 'avg_test_acc', 'test_acc', 'test_same_acc'))) 41 | ``` 42 | 43 | ```{r fig.width=3.5, fig.height=2} 44 | metric_names <- c('train_acc' = 'Train', 'avg_val_acc' = 'Val', 'avg_test_acc' = 'Test') 45 | ggplot(metrics_long %>% filter(metric %in% c('train_acc', 'avg_val_acc', 'avg_test_acc')) %>% rename(Model = model), aes(x = epoch, y = value, color = Model)) + 46 | geom_line() + 47 | facet_wrap(~ metric, labeller = as_labeller(metric_names)) + 48 | xlab('Epoch') + 49 | ylab('Accuracy') 50 | ``` -------------------------------------------------------------------------------- /shapeworld/exp/README.md: -------------------------------------------------------------------------------- 1 | # Experiment results folder 2 | -------------------------------------------------------------------------------- /shapeworld/exp/l3/args.json: -------------------------------------------------------------------------------- 1 | {"exp_dir": "exp/vigil/l3", "predict_concept_hyp": false, "predict_image_hyp": false, "infer_hyp": true, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 523, "language_filter": null, "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 1.0, "save_checkpoint": false, "cuda": true, "predict_hyp": false, "use_hyp": true, "encode_hyp": true, "decode_hyp": true} 2 | -------------------------------------------------------------------------------- /shapeworld/exp/l3/metrics.json: -------------------------------------------------------------------------------- 1 | {"train_acc": [0.5023333333333333, 0.5304444444444445, 0.545, 0.573, 0.662, 0.6244444444444445, 0.6083333333333333, 0.7182222222222222, 0.5692222222222222, 0.7284444444444444, 0.7607777777777778, 0.7855555555555556, 0.7686666666666667, 0.7953333333333333, 0.8346666666666667, 0.8088888888888889, 0.8265555555555556, 0.8394444444444444, 0.8614444444444445, 0.7967777777777778, 0.8381111111111111, 0.862, 0.86, 0.8982222222222223, 0.878, 0.8752222222222222, 0.9072222222222223, 0.8603333333333333, 0.9007777777777778, 0.8971111111111111, 0.9114444444444444, 0.8907777777777778, 0.9074444444444445, 0.9084444444444445, 0.8948888888888888, 0.8741111111111111, 0.8701111111111111, 0.9282222222222222, 0.891, 0.9207777777777778, 0.932, 0.9013333333333333, 0.9011111111111111, 0.9132222222222223, 0.9084444444444445, 0.9397777777777778, 0.9152222222222223, 0.9025555555555556, 0.9347777777777778, 0.916], "val_acc": [0.508, 0.5, 0.52, 0.486, 0.536, 0.524, 0.524, 0.538, 0.53, 0.546, 0.598, 0.568, 0.572, 0.56, 0.62, 0.608, 0.582, 0.588, 0.594, 0.61, 0.618, 0.626, 0.632, 0.58, 0.622, 0.606, 0.598, 0.65, 0.64, 0.624, 0.6, 0.632, 0.62, 0.618, 0.608, 0.63, 0.634, 0.596, 0.63, 0.616, 0.628, 0.614, 0.612, 0.632, 0.602, 0.628, 0.648, 0.652, 0.614, 0.622], "val_same_acc": [0.498, 0.494, 0.52, 0.546, 0.508, 0.556, 0.52, 0.532, 0.532, 0.57, 0.582, 
0.582, 0.594, 0.582, 0.602, 0.614, 0.596, 0.566, 0.594, 0.636, 0.644, 0.612, 0.63, 0.654, 0.6, 0.632, 0.616, 0.69, 0.648, 0.66, 0.662, 0.648, 0.638, 0.674, 0.646, 0.69, 0.676, 0.614, 0.664, 0.67, 0.626, 0.674, 0.67, 0.668, 0.66, 0.64, 0.692, 0.694, 0.61, 0.642], "val_tre": [0.25633602127432825, 0.29470103612542153, 0.3173871642053127, 0.32591839861869815, 0.33455874407291414, 0.3444766681492329, 0.3517530900835991, 0.3546512795686722, 0.3629691895842552, 0.37140207839012146, 0.3772867166697979, 0.38848341038823125, 0.39334050261974335, 0.41172417044639587, 0.41499094939231873, 0.42147668239474295, 0.4337600200772285, 0.44199954602122304, 0.45299749591946603, 0.4578404571712017, 0.4647700753211975, 0.471630163282156, 0.4751690610051155, 0.4885770261287689, 0.493298515021801, 0.4948560266792774, 0.4987256095409393, 0.5102960558831692, 0.5090047884583473, 0.510729871481657, 0.5190574961602687, 0.5180712405443192, 0.5254904857873917, 0.5273345997929573, 0.5298865022361279, 0.5335929306447506, 0.5398195767402649, 0.5402879483401776, 0.5459680155813694, 0.5467929720878602, 0.5468201187252998, 0.5507783033549786, 0.5473250425457954, 0.5527532941997051, 0.5607433926463127, 0.5589669682085514, 0.5609841901659965, 0.5578600949645043, 0.5605509672164917, 0.563900175690651], "val_tre_std": [0.0008983922636733366, 0.0012393818670073098, 0.002705310356240471, 0.004119969486292012, 0.004415815432919283, 0.004880650390603528, 0.008861252828759958, 0.009476131311003444, 0.009022034651017733, 0.01143623739051007, 0.017494352103482703, 0.017762730295086442, 0.021317160074330523, 0.023186749583176555, 0.022208067779347494, 0.02261467769505962, 0.021649295420653976, 0.02183235733200545, 0.02308747601341338, 0.024546957098724283, 0.027185148108284336, 0.02425879832978174, 0.025027398060597974, 0.024505227514392356, 0.025782036503912583, 0.02413022162730152, 0.02493159414026219, 0.02113466113404446, 0.02219720340775683, 0.023728708218024107, 0.023085432824290267, 0.022559277009476516, 0.021082738316435375, 0.024827830885288418, 0.022480336133616313, 0.022559989002424452, 0.02072817771662128, 0.01973465625196693, 0.018798697877029417, 0.019722927256330926, 0.023218060008900417, 0.020582324265838555, 0.02138545180626131, 0.020373384768127337, 0.01968535742858713, 0.021034499364920817, 0.019202968664080812, 0.020055075653409696, 0.019806053228292417, 0.021419522391281927], "test_acc": [0.49675, 0.5085, 0.515, 0.53325, 0.539, 0.52825, 0.541, 0.5665, 0.5155, 0.57875, 0.58325, 0.579, 0.5835, 0.58175, 0.60325, 0.61875, 0.60825, 0.6055, 0.60975, 0.632, 0.638, 0.6335, 0.64175, 0.62575, 0.63975, 0.644, 0.632, 0.65275, 0.6365, 0.6455, 0.63525, 0.64425, 0.63775, 0.646, 0.64475, 0.682, 0.66075, 0.62875, 0.6405, 0.6415, 0.6535, 0.65975, 0.6585, 0.6655, 0.654, 0.62825, 0.64875, 0.67, 0.6345, 0.648], "test_same_acc": [0.4985, 0.51025, 0.51725, 0.5185, 0.5445, 0.54175, 0.53425, 0.5655, 0.522, 0.57925, 0.5795, 0.57975, 0.59125, 0.59425, 0.6055, 0.62575, 0.613, 0.59075, 0.6055, 0.636, 0.63925, 0.63925, 0.64675, 0.618, 0.63275, 0.64875, 0.6225, 0.63025, 0.64375, 0.64175, 0.642, 0.6535, 0.6365, 0.64925, 0.6525, 0.66225, 0.66975, 0.6335, 0.65175, 0.63025, 0.63575, 0.65775, 0.665, 0.66, 0.657, 0.61475, 0.64725, 0.662, 0.6175, 0.63975], "test_acc_ci": [0.010955877061536423, 0.010956264336984343, 0.010956360212081153, 0.010956483478830012, 0.010955374028159649, 0.010956264336984343, 0.010955877061536423, 0.010956675224349538, 0.010955697287113908, 0.0109567245297979, 0.010956683784339127, 0.010955659277286785, 0.010955659277286785, 
0.010953376733769135, 0.010954702824443708, 0.010951121805550332, 0.010950020034228245, 0.010948714963267377, 0.01094729390705758, 0.010922705819616081, 0.010919382294228, 0.01092869042417709, 0.010911272425798927, 0.010930996998787437, 0.010903721834624864, 0.010910012236330214, 0.01091202021840479, 0.010886866255809107, 0.010906127598740076, 0.010874212600804713, 0.010906391433323625, 0.010866250967876408, 0.0108775585496011, 0.010902360982259758, 0.010861949694777408, 0.010840536777386764, 0.010819544733808119, 0.010893535577207014, 0.0108775585496011, 0.010881481312303026, 0.010882125315741406, 0.010857547206753235, 0.010835660275076688, 0.010835249299739252, 0.010826880986115069, 0.01087854868407891, 0.010847660668175189, 0.010826880986115069, 0.010896743733221406, 0.010856057197338035], "best_epoch": 48, "best_val_acc": 0.652, "best_val_same_acc": 0.694, "best_val_tre": 0.5578600949645043, "best_val_tre_std": 0.020055075653409696, "best_test_acc": 0.67, "best_test_same_acc": 0.662, "best_test_acc_ci": 0.010826880986115069, "lowest_val_tre": 0.25633602127432825, "lowest_val_tre_std": 0.0008983922636733366, "has_same": true} -------------------------------------------------------------------------------- /shapeworld/exp/lsl/args.json: -------------------------------------------------------------------------------- 1 | {"exp_dir": "exp/lsl", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 27140, "language_filter": null, "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true} 2 | -------------------------------------------------------------------------------- /shapeworld/exp/lsl/metrics.json: -------------------------------------------------------------------------------- 1 | {"train_acc": [0.5018888888888889, 0.5018888888888889, 0.5592222222222222, 0.6594444444444445, 0.7117777777777777, 0.731, 0.7423333333333333, 0.7373333333333333, 0.7316666666666667, 0.7376666666666667, 0.7403333333333333, 0.7303333333333333, 0.7413333333333333, 0.7401111111111112, 0.7433333333333333, 0.7292222222222222, 0.7354444444444445, 0.7391111111111112, 0.7423333333333333, 0.7348888888888889, 0.7391111111111112, 0.742, 0.745, 0.7401111111111112, 0.7373333333333333, 0.7472222222222222, 0.7374444444444445, 0.714, 0.7431111111111111, 0.742, 0.7372222222222222, 0.7266666666666667, 0.7314444444444445, 0.7421111111111112, 0.7452222222222222, 0.7514444444444445, 0.756, 0.7337777777777778, 0.742, 0.7391111111111112, 0.7505555555555555, 0.7317777777777777, 0.7293333333333333, 0.7471111111111111, 0.7375555555555555, 0.7288888888888889, 0.7352222222222222, 0.7392222222222222, 0.7356666666666667, 0.7437777777777778], "val_acc": [0.508, 0.508, 0.556, 0.6, 0.598, 0.612, 0.608, 0.614, 0.618, 0.618, 0.636, 0.65, 0.656, 0.636, 0.656, 0.66, 0.668, 0.664, 0.654, 0.646, 0.656, 0.654, 0.66, 0.642, 0.65, 0.636, 0.644, 0.652, 0.652, 0.662, 
0.654, 0.638, 0.642, 0.656, 0.646, 0.666, 0.65, 0.666, 0.676, 0.67, 0.664, 0.668, 0.65, 0.668, 0.662, 0.656, 0.658, 0.65, 0.664, 0.652], "val_same_acc": [0.496, 0.496, 0.528, 0.558, 0.59, 0.592, 0.61, 0.628, 0.638, 0.624, 0.644, 0.656, 0.666, 0.664, 0.676, 0.656, 0.658, 0.662, 0.654, 0.64, 0.676, 0.682, 0.674, 0.674, 0.678, 0.686, 0.684, 0.674, 0.664, 0.684, 0.694, 0.668, 0.696, 0.71, 0.698, 0.7, 0.69, 0.678, 0.68, 0.694, 0.696, 0.706, 0.702, 0.698, 0.656, 0.692, 0.682, 0.676, 0.68, 0.692], "val_tre": [0.23315593773126603, 0.3193679120838642, 0.4544656649827957, 0.6009624934494495, 0.6937140434384346, 0.7510887870192527, 0.7626189323961735, 0.7726548528671264, 0.7821315121650696, 0.7757832805812359, 0.777156808435917, 0.7755475530028343, 0.803969556093216, 0.7948731675744056, 0.8089491415023804, 0.8003908374905586, 0.79257049202919, 0.7978935792148113, 0.7970038573145867, 0.7915700083374977, 0.8110601940453053, 0.8116079128980637, 0.8045996649861336, 0.8138662700951099, 0.8021689679026603, 0.8072018190920353, 0.8115154400467872, 0.7561183496713638, 0.7906409577429294, 0.7920836460888385, 0.7970328702330589, 0.7686191479265689, 0.7872865536510945, 0.8017647383213043, 0.7931095496416092, 0.7974768998026848, 0.8039976232647896, 0.7989465886652469, 0.7998599541783333, 0.7994992446899414, 0.7942561790943146, 0.7959962756037712, 0.7920953050553798, 0.799251141756773, 0.7694647860527039, 0.7820107011795044, 0.7913144878745079, 0.7904651002883911, 0.7779531913697719, 0.7935415287613868], "val_tre_std": [0.015967926829254738, 0.041638463279630916, 0.10019827073131268, 0.14631836833822712, 0.2240175849950952, 0.2284532373612079, 0.19094390253159696, 0.16651079912256633, 0.1366198991062094, 0.1531613575168732, 0.15832926271828068, 0.1355022313367676, 0.1371991570253776, 0.15200430030145073, 0.1289743287693288, 0.11889629150246303, 0.1299349850688645, 0.1439676228356837, 0.15671544266725537, 0.13916557490241635, 0.1254910527876445, 0.12804438160833864, 0.1348672604639425, 0.12044834017050647, 0.12676008889430834, 0.11461523484218875, 0.12100130478055854, 0.16064779698548673, 0.15085657104567965, 0.12438815929620986, 0.12298713167116866, 0.1302325547150731, 0.13835974209656793, 0.13045702495739636, 0.12149457598918338, 0.12178719797982543, 0.10589989430645402, 0.10695060304222803, 0.11247619327198298, 0.12526679163881363, 0.11674564159243002, 0.12351190996349973, 0.13275382956110487, 0.11753092526746389, 0.12913086971281487, 0.13101288147859969, 0.11308533867954541, 0.1238697781895269, 0.12148779349419715, 0.12295913230178925], "test_acc": [0.496, 0.496, 0.532, 0.558, 0.57675, 0.601, 0.61175, 0.6135, 0.62775, 0.61775, 0.635, 0.6445, 0.65125, 0.646, 0.65825, 0.6525, 0.65325, 0.663, 0.65875, 0.65725, 0.66775, 0.67025, 0.66625, 0.664, 0.654, 0.67625, 0.6745, 0.66525, 0.66375, 0.6715, 0.66725, 0.652, 0.661, 0.68025, 0.66775, 0.67025, 0.67625, 0.6735, 0.6715, 0.66925, 0.6675, 0.672, 0.6785, 0.68, 0.6535, 0.6715, 0.67175, 0.667, 0.6665, 0.67025], "test_same_acc": [0.49675, 0.49675, 0.53075, 0.56775, 0.59, 0.61025, 0.61975, 0.6265, 0.628, 0.632, 0.65075, 0.64925, 0.65575, 0.654, 0.67175, 0.66725, 0.66675, 0.6685, 0.676, 0.66425, 0.6755, 0.68, 0.67975, 0.6695, 0.6645, 0.6785, 0.67675, 0.66475, 0.66725, 0.67275, 0.6695, 0.654, 0.66725, 0.6735, 0.66875, 0.66725, 0.67725, 0.6755, 0.67725, 0.6745, 0.67025, 0.672, 0.6735, 0.67425, 0.65825, 0.66575, 0.673, 0.67475, 0.667, 0.6755], "test_acc_ci": [0.010956445129323426, 0.010956445129323426, 0.01093514040247655, 0.010869758131939963, 0.010803330145000428, 
0.010709462226081894, 0.01065909238103789, 0.010636499424152667, 0.010592344504257542, 0.010609518073262093, 0.01049988332539343, 0.010473347031721758, 0.010427627129409645, 0.010452057213773755, 0.010342947113854927, 0.010381525923334921, 0.010380601138662442, 0.010337187015213566, 0.010324606027441192, 0.010375035512589824, 0.010291046665518283, 0.01026269256807759, 0.010279985126448383, 0.01032946120993249, 0.010386137775299342, 0.010244120270763858, 0.010258588775844122, 0.010342947113854929, 0.010339110307468433, 0.010287035875162245, 0.010316795164772585, 0.010431155170929057, 0.010349630191309977, 0.010248270856435975, 0.010317774391190184, 0.010313852562815702, 0.010249306410069903, 0.010267803560158325, 0.010268823261302872, 0.010289042924241059, 0.010312870053815038, 0.010288039813297768, 0.010255502172005037, 0.010246197238335546, 0.01041069513033952, 0.010314834250819303, 0.010285025516345352, 0.01029703806183482, 0.01032946120993249, 0.010280994861727876], "best_epoch": 50, "best_val_acc": 0.652, "best_val_same_acc": 0.692, "best_val_tre": 0.7935415287613868, "best_val_tre_std": 0.12295913230178925, "best_test_acc": 0.67025, "best_test_same_acc": 0.6755, "best_test_acc_ci": 0.010280994861727876, "lowest_val_tre": 0.23315593773126603, "lowest_val_tre_std": 0.015967926829254738, "has_same": true} -------------------------------------------------------------------------------- /shapeworld/exp/lsl_color/args.json: -------------------------------------------------------------------------------- 1 | {"exp_dir": "exp/lsl_color", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 19626, "language_filter": "color", "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true} 2 | -------------------------------------------------------------------------------- /shapeworld/exp/lsl_color/metrics.json: -------------------------------------------------------------------------------- 1 | {"train_acc": [0.5018888888888889, 0.583, 0.7208888888888889, 0.7158888888888889, 0.7237777777777777, 0.7438888888888889, 0.7416666666666667, 0.7338888888888889, 0.7194444444444444, 0.7316666666666667, 0.7446666666666667, 0.7274444444444444, 0.7347777777777778, 0.7404444444444445, 0.7362222222222222, 0.7321111111111112, 0.7328888888888889, 0.7116666666666667, 0.7422222222222222, 0.7313333333333333, 0.7286666666666667, 0.7286666666666667, 0.737, 0.7194444444444444, 0.7313333333333333, 0.7153333333333334, 0.7331111111111112, 0.731, 0.7397777777777778, 0.7474444444444445, 0.7346666666666667, 0.6967777777777778, 0.725, 0.7418888888888889, 0.739, 0.7433333333333333, 0.7352222222222222, 0.7178888888888889, 0.7315555555555555, 0.7408888888888889, 0.7422222222222222, 0.7294444444444445, 0.734, 0.7356666666666667, 0.733, 0.7255555555555555, 0.7358888888888889, 0.7362222222222222, 0.7361111111111112, 0.7376666666666667], "val_acc": 
[0.508, 0.546, 0.602, 0.62, 0.608, 0.602, 0.596, 0.584, 0.606, 0.604, 0.61, 0.606, 0.616, 0.632, 0.592, 0.624, 0.59, 0.616, 0.614, 0.598, 0.624, 0.612, 0.628, 0.644, 0.632, 0.604, 0.636, 0.616, 0.638, 0.644, 0.63, 0.59, 0.616, 0.626, 0.63, 0.634, 0.648, 0.614, 0.614, 0.632, 0.62, 0.62, 0.622, 0.622, 0.622, 0.628, 0.626, 0.622, 0.634, 0.606], "val_same_acc": [0.496, 0.564, 0.588, 0.602, 0.638, 0.626, 0.624, 0.636, 0.636, 0.634, 0.646, 0.642, 0.648, 0.648, 0.656, 0.642, 0.652, 0.626, 0.638, 0.636, 0.652, 0.66, 0.664, 0.676, 0.638, 0.628, 0.666, 0.65, 0.668, 0.662, 0.66, 0.624, 0.65, 0.668, 0.658, 0.666, 0.654, 0.642, 0.648, 0.68, 0.66, 0.654, 0.666, 0.652, 0.642, 0.662, 0.658, 0.646, 0.652, 0.64], "val_tre": [0.2928761223256588, 0.4909118238091469, 0.7093970262408257, 0.715656586676836, 0.6946930028498173, 0.7530982179045678, 0.7181237662732601, 0.7177722745537758, 0.7063919916450977, 0.7170701187551022, 0.7364676860868931, 0.7208060621917248, 0.707125977486372, 0.7182179855704307, 0.7051682096123696, 0.7211604817807674, 0.7094560405910015, 0.6675868063867092, 0.7088536221086978, 0.691121347218752, 0.7120009001791477, 0.6966433502137661, 0.705272063344717, 0.6911544860005379, 0.6795149468779564, 0.6730550227761268, 0.7166172302365303, 0.7124211880266667, 0.7084594193398952, 0.7274633083939552, 0.7160733468532562, 0.6490570981502533, 0.6786254914104939, 0.7213582392036915, 0.7089775923788547, 0.7222011366784573, 0.7164143627583981, 0.6961094659864903, 0.7130949632823467, 0.7321216926574707, 0.7255135977864265, 0.7286434670984745, 0.714307487398386, 0.7320686898827553, 0.7254576664865017, 0.7269528369903564, 0.7190864905118942, 0.7051229429543018, 0.7255788600146771, 0.7198136151731014], "val_tre_std": [0.028786290772653104, 0.15274888657758995, 0.29936371313875904, 0.3729337300219388, 0.2503601848631358, 0.34096694884861295, 0.29398723347551503, 0.28951914089576947, 0.29075415453489756, 0.2656755109916801, 0.2658799580899334, 0.24215173114625418, 0.22278727569008222, 0.2368307482422677, 0.23235498110132288, 0.2072491654701949, 0.2702409115734452, 0.2236184128671816, 0.22949241611567125, 0.27024222622417127, 0.2521140637309202, 0.22977934149599544, 0.2523585546825122, 0.2574569513302243, 0.225301321506544, 0.24954425982964126, 0.21412330289882422, 0.17697085712411145, 0.20319146035285715, 0.14752056211130987, 0.22221709431604936, 0.23647851952807025, 0.1968089448080926, 0.2107266092599971, 0.1918881409223694, 0.1775148583175112, 0.18596143152343486, 0.22458536367131698, 0.2174860410072009, 0.1585495236993038, 0.17890690690619446, 0.19183306641129158, 0.23175687856064675, 0.1876181925216055, 0.21742256350551445, 0.17311539572774087, 0.20677547101924948, 0.20651109041952515, 0.1764090059441681, 0.2239317514391409], "test_acc": [0.496, 0.55175, 0.59975, 0.60625, 0.6155, 0.62075, 0.621, 0.616, 0.63225, 0.624, 0.6355, 0.6345, 0.63875, 0.631, 0.6295, 0.6435, 0.63125, 0.61675, 0.63425, 0.6255, 0.63925, 0.63825, 0.63875, 0.6365, 0.62925, 0.61675, 0.637, 0.64575, 0.63775, 0.642, 0.63775, 0.60725, 0.621, 0.64625, 0.6385, 0.645, 0.63325, 0.62075, 0.62625, 0.642, 0.6445, 0.6385, 0.628, 0.635, 0.63825, 0.64125, 0.6415, 0.63175, 0.6405, 0.6355], "test_same_acc": [0.49675, 0.545, 0.6035, 0.603, 0.62375, 0.628, 0.62875, 0.634, 0.63175, 0.63975, 0.64275, 0.64375, 0.64825, 0.64525, 0.64525, 0.6505, 0.6435, 0.63, 0.6425, 0.63475, 0.63825, 0.635, 0.64575, 0.6405, 0.6345, 0.63675, 0.64825, 0.6515, 0.64825, 0.65975, 0.6395, 0.61725, 0.6275, 0.656, 0.648, 0.6565, 0.6465, 0.62625, 0.64025, 0.65375, 0.64675, 0.6485, 
0.6375, 0.64575, 0.63975, 0.64125, 0.653, 0.64925, 0.644, 0.6445], "test_acc_ci": [0.010956445129323426, 0.01090533192855105, 0.010728031831229574, 0.01071417487359036, 0.010638527645866931, 0.010612338039841879, 0.010609518073262093, 0.010608811196359372, 0.010568017562438093, 0.010568766924143751, 0.0105240373121191, 0.0105240373121191, 0.010495789705877305, 0.010530360715420671, 0.010535070864587241, 0.010472504867509013, 0.010535070864587243, 0.010617941807237173, 0.010528784496411491, 0.010579177945231616, 0.010526414379431392, 0.010539753280882573, 0.01050395749170283, 0.010527995229387216, 0.010568766924143751, 0.010598835638290652, 0.010501515325031668, 0.010461485297120815, 0.010499066158473331, 0.010446007842562392, 0.010527205190451784, 0.010677052495305058, 0.010613041146980444, 0.01044427234166531, 0.010497429489522661, 0.010446874407568035, 0.010519262307637119, 0.010617243971483373, 0.01056048194390294, 0.010466587748491624, 0.01048172558178638, 0.010495789705877305, 0.010563505359846228, 0.010516063496949559, 0.010524830440439406, 0.010510435751551883, 0.01047081818615432, 0.010515261858365679, 0.01050395749170283, 0.01051846376615901], "best_epoch": 49, "best_val_acc": 0.634, "best_val_same_acc": 0.652, "best_val_tre": 0.7255788600146771, "best_val_tre_std": 0.1764090059441681, "best_test_acc": 0.6405, "best_test_same_acc": 0.644, "best_test_acc_ci": 0.01050395749170283, "lowest_val_tre": 0.2928761223256588, "lowest_val_tre_std": 0.028786290772653104, "has_same": true} -------------------------------------------------------------------------------- /shapeworld/exp/lsl_nocolor/args.json: -------------------------------------------------------------------------------- 1 | {"exp_dir": "exp/lsl_nocolor", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 12126, "language_filter": "nocolor", "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true} 2 | -------------------------------------------------------------------------------- /shapeworld/exp/lsl_nocolor/metrics.json: -------------------------------------------------------------------------------- 1 | {"train_acc": [0.5018888888888889, 0.5067777777777778, 0.6631111111111111, 0.7022222222222222, 0.7382222222222222, 0.7432222222222222, 0.7513333333333333, 0.7447777777777778, 0.7344444444444445, 0.7235555555555555, 0.7504444444444445, 0.7398888888888889, 0.7394444444444445, 0.7472222222222222, 0.7412222222222222, 0.7385555555555555, 0.7404444444444445, 0.7384444444444445, 0.7408888888888889, 0.7381111111111112, 0.7401111111111112, 0.7348888888888889, 0.7365555555555555, 0.7384444444444445, 0.7391111111111112, 0.7394444444444445, 0.7406666666666667, 0.6986666666666667, 0.7175555555555555, 0.7244444444444444, 0.7346666666666667, 0.7128888888888889, 0.7334444444444445, 0.7233333333333334, 0.7347777777777778, 0.755, 0.7448888888888889, 
0.7327777777777778, 0.735, 0.7406666666666667, 0.7381111111111112, 0.73, 0.723, 0.7344444444444445, 0.7138888888888889, 0.7175555555555555, 0.7363333333333333, 0.7253333333333334, 0.732, 0.7188888888888889], "val_acc": [0.508, 0.514, 0.574, 0.592, 0.616, 0.622, 0.646, 0.632, 0.632, 0.626, 0.656, 0.644, 0.646, 0.636, 0.634, 0.65, 0.666, 0.63, 0.642, 0.654, 0.638, 0.638, 0.63, 0.656, 0.624, 0.646, 0.664, 0.628, 0.634, 0.666, 0.652, 0.652, 0.664, 0.658, 0.656, 0.658, 0.668, 0.664, 0.662, 0.658, 0.654, 0.654, 0.676, 0.668, 0.66, 0.664, 0.676, 0.67, 0.676, 0.662], "val_same_acc": [0.496, 0.496, 0.54, 0.568, 0.602, 0.602, 0.628, 0.63, 0.61, 0.616, 0.614, 0.642, 0.648, 0.632, 0.648, 0.638, 0.654, 0.65, 0.64, 0.652, 0.648, 0.658, 0.646, 0.654, 0.66, 0.65, 0.658, 0.62, 0.614, 0.642, 0.648, 0.634, 0.676, 0.648, 0.65, 0.67, 0.636, 0.652, 0.668, 0.666, 0.648, 0.65, 0.646, 0.632, 0.642, 0.656, 0.662, 0.644, 0.66, 0.65], "val_tre": [0.22808635076880454, 0.3463576367199421, 0.5833195138275623, 0.6936306474804879, 0.7367003444433212, 0.7617647814750671, 0.7925184162259102, 0.7806620350778103, 0.7809249736964703, 0.7618764308989048, 0.818745451271534, 0.7906182879209518, 0.8013895196616649, 0.8116443670988083, 0.7867026439905167, 0.7947599971592426, 0.8079485739171505, 0.8126739595234395, 0.8161718902587891, 0.795997316300869, 0.7976047129631042, 0.7890185303986073, 0.7939921219050884, 0.7666756071150302, 0.7984343985021114, 0.7855962689816952, 0.801971653997898, 0.7558098890483379, 0.7291252992153168, 0.7725845266282558, 0.7714625637233258, 0.726695333570242, 0.7933139767348766, 0.778794666916132, 0.8004181494414806, 0.7938655286729336, 0.7891755924224854, 0.7856353769600392, 0.7903081247210503, 0.7963232119977475, 0.7696316801607609, 0.7731537413299083, 0.7955225827097893, 0.7584440202414989, 0.7596146790981293, 0.7835552912950515, 0.8013896183669567, 0.7690100253224373, 0.7866484541594982, 0.7879996562898159], "val_tre_std": [0.01763054536880929, 0.07339547584061559, 0.14941119368980174, 0.17516040573277214, 0.2652266204437812, 0.2193908594978303, 0.16969182617088535, 0.19375694303927413, 0.17319022049681249, 0.20916169809256696, 0.14892911600273737, 0.1340359423374055, 0.16795556135270545, 0.14105599331769797, 0.16546023395342999, 0.1476439768847445, 0.16017509978664035, 0.18668384015479503, 0.1601517815583093, 0.13709783292609365, 0.15121442764377227, 0.15114236661248562, 0.17220838823519422, 0.15274215369714983, 0.11693935023993977, 0.14640735146688152, 0.12193917723328035, 0.1359785426625704, 0.16940021888863163, 0.15788490043491232, 0.15381887433112926, 0.16534711985816136, 0.11927814837344929, 0.13760702346949716, 0.13427542712608062, 0.1268173391039503, 0.1446572349584176, 0.13625064891802438, 0.14825136077253176, 0.11018452950263992, 0.1267134292548735, 0.14440096747254103, 0.13868456348510233, 0.14252962444754882, 0.15683233923197384, 0.14174051286031775, 0.12008535457174874, 0.1366931669427143, 0.11642106634144173, 0.14012168432103977], "test_acc": [0.496, 0.5, 0.55975, 0.582, 0.59375, 0.603, 0.61475, 0.6135, 0.61525, 0.623, 0.63175, 0.63575, 0.63725, 0.642, 0.6355, 0.64925, 0.64175, 0.64325, 0.64125, 0.6435, 0.64475, 0.645, 0.64175, 0.641, 0.6405, 0.6515, 0.65, 0.63525, 0.6285, 0.64125, 0.64475, 0.62175, 0.64925, 0.6335, 0.642, 0.6515, 0.641, 0.648, 0.64825, 0.64875, 0.63925, 0.63875, 0.64875, 0.64, 0.63025, 0.641, 0.6475, 0.6325, 0.64525, 0.64975], "test_same_acc": [0.49675, 0.50075, 0.574, 0.59875, 0.60675, 0.61575, 0.62725, 0.62475, 0.62325, 0.63475, 0.62975, 0.629, 0.63525, 0.63325, 
0.63475, 0.6385, 0.64025, 0.635, 0.644, 0.642, 0.639, 0.636, 0.636, 0.636, 0.62875, 0.63775, 0.639, 0.6275, 0.61925, 0.63, 0.637, 0.61675, 0.64225, 0.6265, 0.63475, 0.645, 0.641, 0.63475, 0.64175, 0.64525, 0.6385, 0.64475, 0.64475, 0.63, 0.62125, 0.63925, 0.64475, 0.63075, 0.64125, 0.6455], "test_acc_ci": [0.010956445129323426, 0.010956730008167355, 0.010858287988761166, 0.010776265539224384, 0.01073424240398455, 0.010691371283510595, 0.010631057887153093, 0.010641221468744789, 0.010640549134678153, 0.010586523071664038, 0.010575476863361764, 0.010565764897624544, 0.01054208410076015, 0.010533503898127867, 0.010549035062358783, 0.010493324187161809, 0.010512047555067468, 0.0105240373121191, 0.010501515325031666, 0.010500699714185716, 0.010506392663368098, 0.010515261858365677, 0.010525622796152014, 0.010527995229387216, 0.01055210442148271, 0.010488372095891478, 0.010489199395092077, 0.010571756742465981, 0.010615145947978057, 0.010545953426403656, 0.01051285229359164, 0.010640549134678153, 0.010480891249674332, 0.010579915878682589, 0.010528784496411491, 0.01046404006048811, 0.010512047555067468, 0.010509628686203663, 0.010485885513393706, 0.010472504867509013, 0.010525622796152014, 0.01050720283365178, 0.010474188411877075, 0.010549803552673384, 0.01060455408715991, 0.010517664450669406, 0.010478383558396543, 0.01057026335868104, 0.010497429489522661, 0.010468282278954603], "best_epoch": 47, "best_val_acc": 0.676, "best_val_same_acc": 0.662, "best_val_tre": 0.8013896183669567, "best_val_tre_std": 0.12008535457174874, "best_test_acc": 0.6475, "best_test_same_acc": 0.64475, "best_test_acc_ci": 0.010478383558396543, "lowest_val_tre": 0.22808635076880454, "lowest_val_tre_std": 0.01763054536880929, "has_same": true} -------------------------------------------------------------------------------- /shapeworld/exp/lsl_shuffle_captions/args.json: -------------------------------------------------------------------------------- 1 | {"exp_dir": "exp/lsl_shuffle_captions", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 5974, "language_filter": null, "shuffle_words": false, "shuffle_captions": true, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true} 2 | -------------------------------------------------------------------------------- /shapeworld/exp/lsl_shuffle_captions/metrics.json: -------------------------------------------------------------------------------- 1 | {"train_acc": [0.5071111111111111, 0.501, 0.5021111111111111, 0.5645555555555556, 0.6984444444444444, 0.736, 0.7335555555555555, 0.7392222222222222, 0.7366666666666667, 0.7367777777777778, 0.7414444444444445, 0.7402222222222222, 0.7327777777777778, 0.7438888888888889, 0.7367777777777778, 0.7236666666666667, 0.7268888888888889, 0.7038888888888889, 0.7293333333333333, 0.7333333333333333, 0.7113333333333334, 0.6705555555555556, 0.7296666666666667, 0.7068888888888889, 0.7284444444444444, 
0.6974444444444444, 0.7138888888888889, 0.7088888888888889, 0.7338888888888889, 0.696, 0.7006666666666667, 0.7127777777777777, 0.665, 0.6768888888888889, 0.726, 0.729, 0.7235555555555555, 0.6732222222222223, 0.7087777777777777, 0.6856666666666666, 0.7183333333333334, 0.6925555555555556, 0.7375555555555555, 0.7103333333333334, 0.7212222222222222, 0.7194444444444444, 0.7066666666666667, 0.7207777777777777, 0.6978888888888889, 0.697], "val_acc": [0.508, 0.508, 0.508, 0.54, 0.566, 0.554, 0.542, 0.556, 0.56, 0.578, 0.56, 0.546, 0.588, 0.564, 0.572, 0.58, 0.56, 0.572, 0.582, 0.592, 0.564, 0.55, 0.566, 0.562, 0.566, 0.568, 0.556, 0.566, 0.564, 0.54, 0.526, 0.562, 0.538, 0.566, 0.536, 0.55, 0.544, 0.522, 0.562, 0.538, 0.564, 0.542, 0.552, 0.55, 0.548, 0.55, 0.544, 0.546, 0.532, 0.55], "val_same_acc": [0.496, 0.496, 0.496, 0.526, 0.532, 0.544, 0.546, 0.534, 0.544, 0.552, 0.554, 0.552, 0.56, 0.566, 0.562, 0.568, 0.572, 0.58, 0.58, 0.55, 0.57, 0.554, 0.56, 0.566, 0.554, 0.566, 0.542, 0.522, 0.552, 0.514, 0.544, 0.552, 0.512, 0.54, 0.584, 0.582, 0.562, 0.534, 0.554, 0.554, 0.546, 0.538, 0.57, 0.55, 0.576, 0.564, 0.554, 0.574, 0.546, 0.564], "val_tre": [0.22047525447607041, 0.2645132866203785, 0.36374358320236205, 0.5220565661787987, 0.7167646891772746, 0.8003292497396469, 0.7903731316030026, 0.8116375431418419, 0.7994353666901588, 0.7786327120959758, 0.8107034644186497, 0.813195310741663, 0.7700015254616738, 0.8113658610582352, 0.7617959608137608, 0.7662459227144718, 0.7579508483409881, 0.75589797565341, 0.7805133513510227, 0.7729206332266331, 0.7705274262428283, 0.707739619165659, 0.7790121876001358, 0.7549196770191192, 0.7852870355248451, 0.7266519095599652, 0.7746121415793896, 0.7715683530569076, 0.7978267765641213, 0.7414474821686745, 0.7512542349100113, 0.7495681802928448, 0.6815140551030636, 0.7298193091452122, 0.793954403668642, 0.8172262015938759, 0.7798480809032917, 0.7092128167152405, 0.7335295847356319, 0.7364193704426288, 0.786985212892294, 0.7563742088973522, 0.8047082170248031, 0.7670680065453053, 0.7921611217856407, 0.8007868621647358, 0.768711370229721, 0.7819871633350849, 0.7565498836040497, 0.7600137263834477], "val_tre_std": [0.014784197950176159, 0.02652211147234335, 0.04997577132833228, 0.09928714512269435, 0.14315726153916938, 0.17905031274741376, 0.3028480897883202, 0.24420796565707084, 0.2854344250567698, 0.20748745243769307, 0.16452433254581292, 0.19958105132499038, 0.19181740719145532, 0.18428430107650812, 0.2091974202371749, 0.2142233483216199, 0.16718732490884747, 0.17163516532810213, 0.1424397817852027, 0.13710612071647124, 0.19242126031642523, 0.16510685661208033, 0.16180709629451084, 0.14522477034842432, 0.13575573003599034, 0.20591636374702857, 0.17189024005737322, 0.2046076913501664, 0.11446174753005237, 0.19160172738826214, 0.1564791769508523, 0.22557180313439085, 0.18962065680253348, 0.2226163961731732, 0.14819683260006034, 0.12277138594077944, 0.18841027141919195, 0.21871579157749851, 0.21611514095686526, 0.2226257086464708, 0.154523188492496, 0.18645111716967322, 0.12550173796621417, 0.17387476825860076, 0.12255660439773507, 0.18413762094468844, 0.14161339259122727, 0.1449093187359931, 0.16743659468361968, 0.16200972926336132], "test_acc": [0.496, 0.496, 0.496, 0.5205, 0.54575, 0.55475, 0.5495, 0.5565, 0.55175, 0.55725, 0.55975, 0.56525, 0.5665, 0.57075, 0.56825, 0.571, 0.5675, 0.57, 0.5795, 0.5745, 0.56625, 0.536, 0.5745, 0.5555, 0.5655, 0.559, 0.564, 0.54425, 0.562, 0.55275, 0.54975, 0.55675, 0.53125, 0.532, 0.56625, 0.569, 0.55375, 0.5315, 0.552, 0.5415, 
0.56625, 0.547, 0.56625, 0.564, 0.5645, 0.562, 0.559, 0.5555, 0.552, 0.56075], "test_same_acc": [0.49675, 0.49675, 0.49675, 0.5275, 0.55825, 0.555, 0.5605, 0.55775, 0.56325, 0.563, 0.57075, 0.57375, 0.57425, 0.577, 0.57625, 0.58325, 0.57325, 0.57275, 0.57775, 0.5835, 0.581, 0.54825, 0.577, 0.564, 0.57325, 0.552, 0.575, 0.56125, 0.569, 0.55025, 0.55575, 0.55525, 0.5335, 0.542, 0.56825, 0.571, 0.57475, 0.54475, 0.5535, 0.55675, 0.56825, 0.557, 0.57425, 0.55975, 0.56625, 0.574, 0.5645, 0.56825, 0.566, 0.566], "test_acc_ci": [0.010956445129323426, 0.010956445129323426, 0.010956445129323426, 0.010944103654479887, 0.01089731798196235, 0.01089054591133406, 0.010890243110234041, 0.01088498873894112, 0.010884040552570539, 0.010877227105143801, 0.01086303449720657, 0.010850369300166697, 0.010847660668175189, 0.010836480101807737, 0.010841739988927054, 0.010825601269531176, 0.01084766066817519, 0.010844522747538268, 0.010820418481827538, 0.010819106793076773, 0.010837297097149038, 0.010917778052189695, 0.010830262341582499, 0.01087821933900489, 0.010850753426461914, 0.01088902493109461, 0.010850369300166697, 0.010895586881279043, 0.01086231199837309, 0.010898458127184778, 0.010895586881279043, 0.010887795589558063, 0.01093374057090596, 0.010926692372351296, 0.010857175760182755, 0.010848825742908767, 0.010865896391347562, 0.010924835092433891, 0.010895586881279043, 0.010903721834624864, 0.010857175760182755, 0.010897317981962348, 0.010848049731979476, 0.01087251336154042, 0.010862673599159416, 0.010854932298268838, 0.010872854610795638, 0.01087251336154042, 0.010880184915708004, 0.010868363686492783], "best_epoch": 13, "best_val_acc": 0.588, "best_val_same_acc": 0.56, "best_val_tre": 0.7700015254616738, "best_val_tre_std": 0.19181740719145532, "best_test_acc": 0.5665, "best_test_same_acc": 0.57425, "best_test_acc_ci": 0.010847660668175189, "lowest_val_tre": 0.22047525447607041, "lowest_val_tre_std": 0.014784197950176159, "has_same": true} -------------------------------------------------------------------------------- /shapeworld/exp/lsl_shuffle_words/args.json: -------------------------------------------------------------------------------- 1 | {"exp_dir": "exp/lsl_shuffle_words", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 27075, "language_filter": null, "shuffle_words": true, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true} 2 | -------------------------------------------------------------------------------- /shapeworld/exp/lsl_shuffle_words/metrics.json: -------------------------------------------------------------------------------- 1 | {"train_acc": [0.5034444444444445, 0.5135555555555555, 0.6038888888888889, 0.6691111111111111, 0.6986666666666667, 0.703, 0.7335555555555555, 0.737, 0.7327777777777778, 0.7257777777777777, 0.7482222222222222, 0.7406666666666667, 0.7463333333333333, 0.7477777777777778, 
0.7313333333333333, 0.747, 0.7385555555555555, 0.7303333333333333, 0.7364444444444445, 0.7438888888888889, 0.7351111111111112, 0.7403333333333333, 0.7396666666666667, 0.7406666666666667, 0.7463333333333333, 0.7496666666666667, 0.7395555555555555, 0.74, 0.7352222222222222, 0.7293333333333333, 0.7483333333333333, 0.7407777777777778, 0.7425555555555555, 0.7446666666666667, 0.7411111111111112, 0.758, 0.7344444444444445, 0.7464444444444445, 0.7391111111111112, 0.7352222222222222, 0.7374444444444445, 0.737, 0.744, 0.743, 0.7445555555555555, 0.7363333333333333, 0.7406666666666667, 0.7398888888888889, 0.7377777777777778, 0.7367777777777778], "val_acc": [0.508, 0.52, 0.568, 0.61, 0.62, 0.62, 0.63, 0.64, 0.644, 0.652, 0.656, 0.648, 0.648, 0.646, 0.654, 0.648, 0.66, 0.65, 0.634, 0.656, 0.648, 0.654, 0.66, 0.668, 0.66, 0.658, 0.648, 0.652, 0.652, 0.638, 0.658, 0.652, 0.66, 0.656, 0.656, 0.664, 0.66, 0.66, 0.646, 0.652, 0.654, 0.65, 0.656, 0.656, 0.662, 0.652, 0.654, 0.658, 0.668, 0.65], "val_same_acc": [0.496, 0.5, 0.542, 0.562, 0.566, 0.572, 0.61, 0.6, 0.62, 0.62, 0.62, 0.622, 0.654, 0.638, 0.646, 0.654, 0.646, 0.666, 0.66, 0.656, 0.65, 0.652, 0.666, 0.674, 0.674, 0.668, 0.67, 0.678, 0.676, 0.652, 0.676, 0.684, 0.68, 0.678, 0.67, 0.686, 0.68, 0.682, 0.678, 0.674, 0.672, 0.68, 0.682, 0.678, 0.664, 0.664, 0.668, 0.682, 0.688, 0.666], "val_tre": [0.27602067959308624, 0.40746447333693503, 0.5171051873266697, 0.6125595195889473, 0.6366198836266994, 0.6640413806140423, 0.7124214672148228, 0.7231019198596478, 0.7267296032011509, 0.7232981150150299, 0.736225801974535, 0.7374791561663151, 0.7319063286185264, 0.7416931557953358, 0.7135514880418777, 0.7375768918097019, 0.7422215968072414, 0.7392168205976486, 0.7333029879629612, 0.7477176522910595, 0.7457649510502815, 0.7435495406389236, 0.7467366912662983, 0.7510937738120556, 0.7458743584156037, 0.7536228042840958, 0.7528315424025058, 0.7484392536580563, 0.7481056715548039, 0.7274645104706288, 0.7511067868769169, 0.7526578099429607, 0.7544187552034854, 0.7587365226745606, 0.7553361043334007, 0.7578833720088005, 0.7530133697390556, 0.7546444187760353, 0.7501327090263367, 0.7505363914966583, 0.7546211395263672, 0.7538118029236793, 0.7556851161122322, 0.7571718983650207, 0.7616352689862251, 0.7572485668361187, 0.7550260179936886, 0.7595631908476352, 0.763034375667572, 0.7599161682724953], "val_tre_std": [0.027511252715802494, 0.07593431201113533, 0.1262049245646053, 0.14781622666731317, 0.19731004650994827, 0.1538107448595129, 0.15554684902972157, 0.14695279508565937, 0.17908736485187154, 0.18875130684038025, 0.1769762953390025, 0.16766753936312187, 0.17139012632148498, 0.17359321816174428, 0.20311660648997165, 0.17161296839295997, 0.16578064885892926, 0.17290399315845784, 0.16886487327310604, 0.14806971197739538, 0.16859308914903975, 0.13572207758448554, 0.16111478218014064, 0.1637356360401312, 0.14536515798208788, 0.18004281884501797, 0.1591164335824455, 0.1388791819160653, 0.1534606178792528, 0.15306302122473106, 0.14378495034669006, 0.1367711261502094, 0.12677138741279045, 0.1316132040423189, 0.15562621157817208, 0.14923601261638875, 0.1393943019165844, 0.13612622115365428, 0.1349727064430412, 0.11432319623566965, 0.12727456126632894, 0.1289669480511882, 0.11890626436770418, 0.1310952292714821, 0.12584475311577023, 0.1257270419068187, 0.12828189650193889, 0.13452211085484508, 0.11821392361223679, 0.1296130167067139], "test_acc": [0.496, 0.50625, 0.548, 0.57, 0.58325, 0.59225, 0.59775, 0.60475, 0.612, 0.61975, 0.62475, 0.6275, 0.6395, 0.6455, 0.637, 0.646, 
0.65125, 0.65, 0.65675, 0.6535, 0.65925, 0.655, 0.65825, 0.65775, 0.66075, 0.6675, 0.6625, 0.6605, 0.6605, 0.6565, 0.6645, 0.666, 0.663, 0.664, 0.6585, 0.666, 0.6655, 0.66575, 0.664, 0.66325, 0.6605, 0.6695, 0.663, 0.664, 0.669, 0.66825, 0.6675, 0.6695, 0.667, 0.67075], "test_same_acc": [0.49675, 0.50525, 0.5515, 0.57875, 0.59125, 0.59175, 0.60225, 0.60675, 0.613, 0.6225, 0.64475, 0.64375, 0.6445, 0.64975, 0.64275, 0.647, 0.66125, 0.6615, 0.65925, 0.67175, 0.666, 0.67575, 0.67625, 0.67825, 0.6805, 0.67325, 0.67875, 0.67725, 0.68, 0.66025, 0.67175, 0.68125, 0.6785, 0.68075, 0.67325, 0.6795, 0.67525, 0.6745, 0.6735, 0.6725, 0.66925, 0.677, 0.67775, 0.67475, 0.6765, 0.67325, 0.6735, 0.67625, 0.67475, 0.673], "test_acc_ci": [0.010956445129323426, 0.010956008551817583, 0.010902360982259758, 0.010834837616313176, 0.010788625838701608, 0.010769660496041646, 0.010735362127101257, 0.010708869846417036, 0.010675788905275336, 0.010630374322989525, 0.010551338232068006, 0.010545953426403656, 0.010505581716402, 0.010468282278954603, 0.010519262307637119, 0.010475868820770904, 0.010407995349129437, 0.010411593460537154, 0.010395301207757281, 0.010360994311448829, 0.01036099431144883, 0.010340070730989945, 0.010325578699884088, 0.010319730384074962, 0.010299028591176694, 0.010301015822329125, 0.010299028591176696, 0.010312870053815038, 0.010302008201680875, 0.010392560607082116, 0.01031875279754656, 0.010274924008082734, 0.010298033738898899, 0.010285025516345352, 0.010336224146025226, 0.010282003768113489, 0.010301015822329125, 0.01030299975720057, 0.0103138525628157, 0.010320707151008355, 0.01034390427966515, 0.010277963168230366, 0.010301015822329125, 0.010308931803144057, 0.01028200376811349, 0.010298033738898899, 0.010300022618907202, 0.010280994861727876, 0.01029703806183482, 0.010289042924241059], "best_epoch": 49, "best_val_acc": 0.668, "best_val_same_acc": 0.688, "best_val_tre": 0.763034375667572, "best_val_tre_std": 0.11821392361223679, "best_test_acc": 0.667, "best_test_same_acc": 0.67475, "best_test_acc_ci": 0.01029703806183482, "lowest_val_tre": 0.27602067959308624, "lowest_val_tre_std": 0.027511252715802494, "has_same": true} 2 | -------------------------------------------------------------------------------- /shapeworld/exp/meta/args.json: -------------------------------------------------------------------------------- 1 | {"exp_dir": "exp/meta", "predict_concept_hyp": false, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 29493, "language_filter": null, "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 10.0, "save_checkpoint": false, "cuda": true, "predict_hyp": false, "use_hyp": false, "encode_hyp": false, "decode_hyp": false} 2 | -------------------------------------------------------------------------------- /shapeworld/exp/meta/metrics.json: -------------------------------------------------------------------------------- 1 | {"train_acc": [0.587, 0.7234444444444444, 0.7327777777777778, 0.7338888888888889, 0.7413333333333333, 
0.7468888888888889, 0.7552222222222222, 0.7471111111111111, 0.7431111111111111, 0.7497777777777778, 0.7494444444444445, 0.7463333333333333, 0.7562222222222222, 0.7497777777777778, 0.7491111111111111, 0.7525555555555555, 0.751, 0.7422222222222222, 0.7553333333333333, 0.7503333333333333, 0.7524444444444445, 0.7457777777777778, 0.7596666666666667, 0.7496666666666667, 0.7512222222222222, 0.7546666666666667, 0.7551111111111111, 0.7454444444444445, 0.7604444444444445, 0.7524444444444445, 0.7546666666666667, 0.7572222222222222, 0.7493333333333333, 0.7554444444444445, 0.7573333333333333, 0.7665555555555555, 0.7581111111111111, 0.7556666666666667, 0.7588888888888888, 0.7476666666666667, 0.7614444444444445, 0.7563333333333333, 0.762, 0.764, 0.7618888888888888, 0.7621111111111111, 0.7532222222222222, 0.7506666666666667, 0.7588888888888888, 0.7618888888888888], "val_acc": [0.534, 0.568, 0.54, 0.558, 0.56, 0.592, 0.598, 0.606, 0.59, 0.608, 0.606, 0.614, 0.62, 0.596, 0.6, 0.6, 0.596, 0.618, 0.6, 0.57, 0.568, 0.572, 0.582, 0.586, 0.568, 0.594, 0.592, 0.59, 0.576, 0.584, 0.6, 0.586, 0.582, 0.594, 0.602, 0.606, 0.594, 0.568, 0.596, 0.594, 0.602, 0.612, 0.618, 0.604, 0.602, 0.58, 0.61, 0.594, 0.614, 0.614], "val_same_acc": [0.526, 0.524, 0.564, 0.586, 0.59, 0.596, 0.596, 0.588, 0.598, 0.604, 0.584, 0.606, 0.596, 0.598, 0.596, 0.608, 0.606, 0.602, 0.602, 0.602, 0.612, 0.606, 0.608, 0.6, 0.596, 0.604, 0.6, 0.584, 0.576, 0.58, 0.592, 0.58, 0.612, 0.614, 0.606, 0.604, 0.608, 0.604, 0.62, 0.594, 0.606, 0.628, 0.61, 0.59, 0.614, 0.62, 0.614, 0.602, 0.61, 0.606], "val_tre": [0.5916588901281357, 0.8042377699017524, 0.8100116618275642, 0.8228306672573089, 0.8091322653889657, 0.8141303354799747, 0.8227340990900993, 0.8060084843039512, 0.7999683973491192, 0.8017535305321216, 0.8079626688361168, 0.8147974437475205, 0.8114711083471775, 0.8146977169215679, 0.801983368396759, 0.8079127200245857, 0.7854678380787372, 0.8079566982984543, 0.8099119906425476, 0.7922585190832615, 0.8039774939119816, 0.7961225943863391, 0.807337353438139, 0.8064042346477509, 0.8044357794523239, 0.8018924917876721, 0.8004497699439526, 0.8031283156871796, 0.8019032092988491, 0.8100885750055313, 0.8107224105894566, 0.8069514398574829, 0.8026845241487026, 0.8086875425577164, 0.8073043225705624, 0.8142727278470993, 0.8092543596625328, 0.8049768908321857, 0.8047672483623027, 0.8046509600281715, 0.8093694306910038, 0.81609526014328, 0.8113720493614673, 0.8151067448556423, 0.8141982303857803, 0.8105539740622043, 0.8153765279650688, 0.8105866264998913, 0.812785368680954, 0.8175703119635582], "val_tre_std": [0.20776802624853982, 0.22506850861303512, 0.3550965570239999, 0.25240233315544186, 0.30489556283502167, 0.25192126585263014, 0.22163285096208185, 0.22398180211467097, 0.22136366888011189, 0.21504156556522225, 0.21378987075781997, 0.1961432996474801, 0.18039981992759827, 0.20156448897379214, 0.19800972594737676, 0.186366711354458, 0.2288897358079121, 0.19004879167233124, 0.17280907742916762, 0.22070847274461852, 0.17811587832274378, 0.18698121777982987, 0.17709791856239163, 0.18664346244800895, 0.16421413078859573, 0.163887529613401, 0.18290036932823, 0.18182176754039484, 0.18873458891343609, 0.18661134863337522, 0.17477719968532573, 0.17169934653941996, 0.16874678702261794, 0.16338142096155675, 0.15784144370700903, 0.1539296292378941, 0.15841589817970592, 0.16560336048870408, 0.15963311797982846, 0.1611074459988127, 0.15862874099906135, 0.14242522707771324, 0.1613770586649978, 0.15857943745800235, 0.14755391734663065, 0.1512508923237696, 
0.15404908131241488, 0.15800266933740134, 0.15596154811266771, 0.14921550713938078], "test_acc": [0.5335, 0.55, 0.556, 0.56, 0.563, 0.5795, 0.58525, 0.583, 0.594, 0.59175, 0.58925, 0.59925, 0.5935, 0.5975, 0.601, 0.5965, 0.6025, 0.59175, 0.59625, 0.603, 0.605, 0.60125, 0.5995, 0.60775, 0.5975, 0.6005, 0.6085, 0.5945, 0.604, 0.59925, 0.60025, 0.6045, 0.60925, 0.61075, 0.60525, 0.59725, 0.601, 0.60675, 0.597, 0.59525, 0.60625, 0.60125, 0.59825, 0.6015, 0.597, 0.60075, 0.6015, 0.593, 0.60825, 0.59775], "test_same_acc": [0.53775, 0.55625, 0.56175, 0.56925, 0.57175, 0.579, 0.58425, 0.5825, 0.5915, 0.59425, 0.5975, 0.603, 0.59875, 0.60125, 0.59475, 0.60825, 0.6025, 0.60275, 0.607, 0.60675, 0.61175, 0.60225, 0.6155, 0.613, 0.60425, 0.617, 0.6065, 0.60775, 0.61275, 0.61275, 0.61275, 0.6105, 0.603, 0.6125, 0.61125, 0.61125, 0.60775, 0.61625, 0.6075, 0.60625, 0.6035, 0.61425, 0.6135, 0.61475, 0.61175, 0.608, 0.614, 0.60625, 0.616, 0.6175], "test_acc_ci": [0.010928886433295707, 0.010894711930421795, 0.010880510063727481, 0.010864828448800976, 0.010856803609805003, 0.010818228777738989, 0.010798191213694077, 0.010805637440128184, 0.01076657092520641, 0.01076553529556241, 0.010763976436330348, 0.010730300426683076, 0.010752345999681882, 0.010738148696208066, 0.010744762286662045, 0.010724607049765274, 0.010724033697727736, 0.010747487077801024, 0.010728031831229574, 0.01071300110832044, 0.010696260256364136, 0.010727462858826408, 0.010700499462641917, 0.010686435138617322, 0.01073143034953286, 0.010694432415397274, 0.010700499462641917, 0.010730300426683076, 0.010696260256364137, 0.01070768288660063, 0.010705300161602193, 0.010700499462641917, 0.010707088306205145, 0.010680198522119098, 0.010696868064414928, 0.010715930029050209, 0.010715345709396173, 0.010680825508826552, 0.01072517967157194, 0.010731994217641937, 0.010713001108320441, 0.010699291938605098, 0.01070827673329724, 0.010697475136539231, 0.010715345709396173, 0.010715345709396173, 0.010699291938605098, 0.010737036251772414, 0.010677683180206978, 0.010699896068274446], "best_epoch": 43, "best_val_acc": 0.618, "best_val_same_acc": 0.61, "best_val_tre": 0.8113720493614673, "best_val_tre_std": 0.1613770586649978, "best_test_acc": 0.59825, "best_test_same_acc": 0.6135, "best_test_acc_ci": 0.01070827673329724, "lowest_val_tre": 0.5916588901281357, "lowest_val_tre_std": 0.20776802624853982, "has_same": true} -------------------------------------------------------------------------------- /shapeworld/lsl/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset utilities 3 | """ 4 | 5 | import os 6 | import json 7 | import logging 8 | 9 | import torch 10 | import numpy as np 11 | import torch.utils.data as data 12 | from torchvision import transforms 13 | 14 | from utils import next_random, OrderedCounter 15 | 16 | # Set your data directory here! 
17 | DATA_DIR = '/u/scr/muj/shapeworld_4k' 18 | SPLIT_OPTIONS = ['train', 'val', 'test', 'val_same', 'test_same'] 19 | 20 | logging.getLogger(__name__).setLevel(logging.INFO) 21 | 22 | SOS_TOKEN = '<sos>' 23 | EOS_TOKEN = '<eos>' 24 | PAD_TOKEN = '<pad>' 25 | UNK_TOKEN = '<unk>' 26 | N_EX = 4 # number of examples per task 27 | 28 | random = next_random() 29 | COLORS = { 30 | 'black', 'red', 'green', 'blue', 'yellow', 'magenta', 'cyan', 'white' 31 | } 32 | SHAPES = { 33 | 'square', 'rectangle', 'triangle', 'pentagon', 'cross', 'circle', 34 | 'semicircle', 'ellipse' 35 | } 36 | 37 | 38 | def get_max_hint_length(data_dir=None): 39 | """ 40 | Get the maximum number of words in a sentence across all splits 41 | """ 42 | if data_dir is None: 43 | data_dir = DATA_DIR 44 | max_len = 0 45 | for split in ['train', 'val', 'test', 'val_same', 'test_same']: 46 | for tf in ['hints.json', 'test_hints.json']: 47 | hints_file = os.path.join(data_dir, 'shapeworld', split, tf) 48 | if os.path.exists(hints_file): 49 | with open(hints_file) as fp: 50 | hints = json.load(fp) 51 | split_max_len = max([len(hint.split()) for hint in hints]) 52 | if split_max_len > max_len: 53 | max_len = split_max_len 54 | if max_len == 0: 55 | raise RuntimeError("Can't find any splits in {}".format(data_dir)) 56 | return max_len 57 | 58 | 59 | def get_black_mask(imgs): 60 | if len(imgs.shape) == 4: 61 | # Then color is 1st dim 62 | col_dim = 1 63 | else: 64 | col_dim = 0 65 | total = imgs.sum(dim=col_dim) 66 | 67 | # Put dim back in 68 | is_black = total == 0.0 69 | is_black = is_black.unsqueeze(col_dim).expand_as(imgs) 70 | 71 | return is_black 72 | 73 | 74 | class ShapeWorld(data.Dataset): 75 | r"""Loader for ShapeWorld data as in L3. 76 | 77 | @param split: string [default: train] 78 | train|val|test|val_same|test_same 79 | @param vocab: ?Object [default: None] 80 | initialize with a vocabulary 81 | important to do this for validation/test set. 82 | @param augment: boolean [default: False] 83 | negatively sample data from other concepts. 
84 | @param max_size: limit size to this many training examples 85 | @param precomputed_features: load precomputed VGG features rather than raw image data 86 | @param noise: amount of uniform noise to add to examples 87 | @param class_noise_weight: how much of the noise added to examples should 88 | be the same across (pos/neg classes) (between 89 | 0.0 and 1.0) 90 | 91 | NOTE: for now noise/class_noise_weight has no impact on val/test datasets 92 | """ 93 | 94 | def __init__(self, 95 | split='train', 96 | vocab=None, 97 | augment=False, 98 | max_size=None, 99 | precomputed_features=True, 100 | preprocess=False, 101 | noise=0.0, 102 | class_noise_weight=0.5, 103 | fixed_noise_colors=None, 104 | fixed_noise_colors_max_rgb=0.2, 105 | noise_type='gaussian', 106 | data_dir=None, 107 | language_filter=None, 108 | shuffle_words=False, 109 | shuffle_captions=False): 110 | super(ShapeWorld, self).__init__() 111 | self.split = split 112 | assert self.split in SPLIT_OPTIONS 113 | self.vocab = vocab 114 | self.augment = augment 115 | self.max_size = max_size 116 | 117 | assert noise_type in ('gaussian', 'normal') 118 | self.noise_type = noise_type 119 | 120 | # Positive class noise 121 | if precomputed_features: 122 | self.image_dim = (4608, ) 123 | else: 124 | self.image_dim = (3, 64, 64) 125 | 126 | self.noise = noise 127 | self.fixed_noise_colors = fixed_noise_colors 128 | self.fixed_noise_colors_max_rgb = fixed_noise_colors_max_rgb 129 | if not class_noise_weight >= 0.0 and class_noise_weight <= 1.0: 130 | raise ValueError( 131 | "Class noise weight must be between 0 and 1, got {}".format( 132 | class_noise_weight)) 133 | self.class_noise_weight = class_noise_weight 134 | 135 | if data_dir is None: 136 | data_dir = DATA_DIR 137 | self.data_dir = data_dir 138 | split_dir = os.path.join(data_dir, 'shapeworld', split) 139 | if not os.path.exists(split_dir): 140 | raise RuntimeError("Can't find {}".format(split_dir)) 141 | 142 | self.precomputed_features = precomputed_features 143 | if self.precomputed_features: 144 | in_features_name = 'inputs.feats.npz' 145 | ex_features_name = 'examples.feats.npz' 146 | else: 147 | in_features_name = 'inputs.npz' 148 | ex_features_name = 'examples.npz' 149 | 150 | self.preprocess = None 151 | if preprocess: 152 | self.preprocess = transforms.Compose([ 153 | transforms.ToPILImage(), 154 | transforms.Resize((224, 224)), 155 | transforms.ToTensor(), 156 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 157 | std=[0.229, 0.224, 0.225]) 158 | ]) 159 | # hints = language 160 | # examples = images with positive labels (pre-training) 161 | # input = test time input 162 | # label = test time label 163 | labels = np.load(os.path.join(split_dir, 'labels.npz'))['arr_0'] 164 | in_features = np.load(os.path.join(split_dir, in_features_name))['arr_0'] 165 | ex_features = np.load(os.path.join(split_dir, ex_features_name))['arr_0'] 166 | with open(os.path.join(split_dir, 'hints.json')) as fp: 167 | hints = json.load(fp) 168 | 169 | test_hints = os.path.join(split_dir, 'test_hints.json') 170 | if self.fixed_noise_colors is not None: 171 | assert os.path.exists(test_hints) 172 | if os.path.exists(test_hints): 173 | with open(test_hints, 'r') as fp: 174 | test_hints = json.load(fp) 175 | self.test_hints = test_hints 176 | else: 177 | self.test_hints = None 178 | 179 | if self.test_hints is not None: 180 | for a, b, label in zip(hints, test_hints, labels): 181 | if label: 182 | assert a == b, (a, b, label) 183 | # else: # XXX: What?/ 184 | # assert a != b, (a, b, label) 185 | 186 | 
if not self.precomputed_features: 187 | # Bring channel to first dim 188 | in_features = np.transpose(in_features, (0, 3, 1, 2)) 189 | ex_features = np.transpose(ex_features, (0, 1, 4, 2, 3)) 190 | 191 | if self.max_size is not None: 192 | labels = labels[:self.max_size] 193 | in_features = in_features[:self.max_size] 194 | ex_features = ex_features[:self.max_size] 195 | hints = hints[:self.max_size] 196 | 197 | n_data = len(hints) 198 | 199 | self.in_features = in_features 200 | self.ex_features = ex_features 201 | self.hints = hints 202 | 203 | if self.vocab is None: 204 | self.create_vocab(hints, test_hints) 205 | 206 | self.w2i, self.i2w = self.vocab['w2i'], self.vocab['i2w'] 207 | self.vocab_size = len(self.w2i) 208 | 209 | # Language processing 210 | self.language_filter = language_filter 211 | if self.language_filter is not None: 212 | assert self.language_filter in ['color', 'nocolor'] 213 | self.shuffle_words = shuffle_words 214 | self.shuffle_captions = shuffle_captions 215 | 216 | # this is the maximum number of tokens in a sentence 217 | max_length = get_max_hint_length(data_dir) 218 | 219 | hints, hint_lengths = [], [] 220 | for hint in self.hints: 221 | hint_tokens = hint.split() 222 | # Hint processing 223 | if self.language_filter == 'color': 224 | hint_tokens = [t for t in hint_tokens if t in COLORS] 225 | elif self.language_filter == 'nocolor': 226 | hint_tokens = [t for t in hint_tokens if t not in COLORS] 227 | if self.shuffle_words: 228 | random.shuffle(hint_tokens) 229 | 230 | hint = [SOS_TOKEN, *hint_tokens, EOS_TOKEN] 231 | hint_length = len(hint) 232 | 233 | hint.extend([PAD_TOKEN] * (max_length + 2 - hint_length)) 234 | hint = [self.w2i.get(w, self.w2i[UNK_TOKEN]) for w in hint] 235 | 236 | hints.append(hint) 237 | hint_lengths.append(hint_length) 238 | 239 | hints = np.array(hints) 240 | hint_lengths = np.array(hint_lengths) 241 | 242 | if self.test_hints is not None: 243 | test_hints, test_hint_lengths = [], [] 244 | for test_hint in self.test_hints: 245 | test_hint_tokens = test_hint.split() 246 | 247 | if self.language_filter == 'color': 248 | test_hint_tokens = [ 249 | t for t in test_hint_tokens if t in COLORS 250 | ] 251 | elif self.language_filter == 'nocolor': 252 | test_hint_tokens = [ 253 | t for t in test_hint_tokens if t not in COLORS 254 | ] 255 | if self.shuffle_words: 256 | random.shuffle(test_hint_tokens) 257 | 258 | test_hint = [SOS_TOKEN, *test_hint_tokens, EOS_TOKEN] 259 | test_hint_length = len(test_hint) 260 | 261 | test_hint.extend([PAD_TOKEN] * (max_length + 2 - test_hint_length)) 262 | 263 | test_hint = [ 264 | self.w2i.get(w, self.w2i[UNK_TOKEN]) for w in test_hint 265 | ] 266 | 267 | test_hints.append(test_hint) 268 | test_hint_lengths.append(test_hint_length) 269 | 270 | test_hints = np.array(test_hints) 271 | test_hint_lengths = np.array(test_hint_lengths) 272 | 273 | data = [] 274 | for i in range(n_data): 275 | if self.shuffle_captions: 276 | hint_i = random.randint(len(hints)) 277 | test_hint_i = random.randint(len(test_hints)) 278 | else: 279 | hint_i = i 280 | test_hint_i = i 281 | if self.test_hints is not None: 282 | th = test_hints[test_hint_i] 283 | thl = test_hint_lengths[test_hint_i] 284 | else: 285 | th = hints[test_hint_i] 286 | thl = hint_lengths[test_hint_i] 287 | data_i = (ex_features[i], in_features[i], labels[i], hints[hint_i], 288 | hint_lengths[hint_i], th, thl) 289 | data.append(data_i) 290 | 291 | self.data = data 292 | self.max_length = max_length 293 | 294 | def create_vocab(self, hints, test_hints): 295 | w2i = 
dict() 296 | i2w = dict() 297 | w2c = OrderedCounter() 298 | 299 | special_tokens = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN] 300 | for st in special_tokens: 301 | i2w[len(w2i)] = st 302 | w2i[st] = len(w2i) 303 | 304 | for hint in hints: 305 | hint_tokens = hint.split() 306 | w2c.update(hint_tokens) 307 | 308 | if test_hints is not None: 309 | for hint in test_hints: 310 | hint_tokens = hint.split() 311 | w2c.update(hint_tokens) 312 | 313 | for w, c in list(w2c.items()): 314 | i2w[len(w2i)] = w 315 | w2i[w] = len(w2i) 316 | 317 | assert len(w2i) == len(i2w) 318 | vocab = dict(w2i=w2i, i2w=i2w) 319 | self.vocab = vocab 320 | 321 | logging.info('Created vocab with %d words.' % len(w2c)) 322 | 323 | def __len__(self): 324 | return len(self.data) 325 | 326 | def sample_train(self, n_batch): 327 | assert self.split == 'train' 328 | n_train = len(self.data) 329 | batch_examples = [] 330 | batch_image = [] 331 | batch_label = [] 332 | batch_hint = [] 333 | batch_hint_length = [] 334 | if self.test_hints is not None: 335 | batch_test_hint = [] 336 | batch_test_hint_length = [] 337 | 338 | for _ in range(n_batch): 339 | index = random.randint(n_train) 340 | examples, image, label, hint, hint_length, test_hint, test_hint_length = \ 341 | self.__getitem__(index) 342 | 343 | batch_examples.append(examples) 344 | batch_image.append(image) 345 | batch_label.append(label) 346 | batch_hint.append(hint) 347 | batch_hint_length.append(hint_length) 348 | if self.test_hints is not None: 349 | batch_test_hint.append(test_hint) 350 | batch_test_hint_length.append(test_hint_length) 351 | 352 | batch_examples = torch.stack(batch_examples) 353 | batch_image = torch.stack(batch_image) 354 | batch_label = torch.from_numpy(np.array(batch_label)).long() 355 | batch_hint = torch.stack(batch_hint) 356 | batch_hint_length = torch.from_numpy( 357 | np.array(batch_hint_length)).long() 358 | if self.test_hints is not None: 359 | batch_test_hint = torch.stack(batch_test_hint) 360 | batch_test_hint_length = torch.from_numpy( 361 | np.array(batch_test_hint_length)).long() 362 | else: 363 | batch_test_hint = None 364 | batch_test_hint_length = None 365 | 366 | return ( 367 | batch_examples, batch_image, batch_label, batch_hint, 368 | batch_hint_length, batch_test_hint, batch_test_hint_length 369 | ) 370 | 371 | def __getitem__(self, index): 372 | if self.split == 'train' and self.augment: 373 | examples, image, label, hint, hint_length, test_hint, test_hint_length = self.data[ 374 | index] 375 | 376 | # tie a language to a concept; convert to pytorch. 377 | hint = torch.from_numpy(hint).long() 378 | test_hint = torch.from_numpy(test_hint).long() 379 | 380 | # in training, pick whether to show positive or negative example. 381 | sample_label = random.randint(2) 382 | n_train = len(self.data) 383 | 384 | if sample_label == 0: 385 | # if we are training, we need to negatively sample data and 386 | # return a tuple (example_z, hint_z, 1) or... 387 | # return a tuple (example_z, hint_other_z, 0). 388 | # Sample a new test hint as well. 389 | examples2, image2, _, support_hint2, support_hint_length2, query_hint2, query_hint_length2 = self.data[ 390 | random.randint(n_train)] 391 | 392 | # pick either an example or an image. 393 | swap = random.randint(N_EX + 1) 394 | if swap == N_EX: 395 | feats = image2 396 | # Use the QUERY hint of the new example 397 | test_hint = query_hint2 398 | test_hint_length = query_hint_length2 399 | else: 400 | feats = examples2[swap, ...] 
401 | # Use the SUPPORT hint of the new example 402 | test_hint = support_hint2 403 | test_hint_length = support_hint_length2 404 | 405 | test_hint = torch.from_numpy(test_hint).long() 406 | 407 | feats = torch.from_numpy(feats).float() 408 | examples = torch.from_numpy(examples).float() 409 | 410 | if self.preprocess is not None: 411 | feats = self.preprocess(feats) 412 | examples = torch.stack( 413 | [self.preprocess(e) for e in examples]) 414 | return examples, feats, 0, hint, hint_length, test_hint, test_hint_length 415 | else: # sample_label == 1 416 | swap = random.randint((N_EX + 1 if label == 1 else N_EX)) 417 | # pick either an example or an image. 418 | if swap == N_EX: 419 | feats = image 420 | else: 421 | feats = examples[swap, ...] 422 | if label == 1: 423 | examples[swap, ...] = image 424 | else: 425 | examples[swap, ...] = examples[random.randint(N_EX 426 | ), ...] 427 | 428 | # This is a positive example, so whatever example we've chosen, 429 | # assume the query hint matches the support hint. 430 | test_hint = hint 431 | test_hint_length = hint_length 432 | 433 | feats = torch.from_numpy(feats).float() 434 | examples = torch.from_numpy(examples).float() 435 | 436 | if self.preprocess is not None: 437 | feats = self.preprocess(feats) 438 | examples = torch.stack( 439 | [self.preprocess(e) for e in examples]) 440 | return examples, feats, 1, hint, hint_length, test_hint, test_hint_length 441 | 442 | else: # val, val_same, test, test_same 443 | examples, image, label, hint, hint_length, test_hint, test_hint_length = self.data[ 444 | index] 445 | 446 | # no fancy stuff. just return image. 447 | image = torch.from_numpy(image).float() 448 | 449 | # NOTE: we provide the oracle text. 450 | hint = torch.from_numpy(hint).long() 451 | test_hint = torch.from_numpy(test_hint).long() 452 | examples = torch.from_numpy(examples).float() 453 | 454 | if self.preprocess is not None: 455 | image = self.preprocess(image) 456 | examples = torch.stack([self.preprocess(e) for e in examples]) 457 | return examples, image, label, hint, hint_length, test_hint, test_hint_length 458 | 459 | def to_text(self, hints): 460 | texts = [] 461 | for hint in hints: 462 | text = [] 463 | for tok in hint: 464 | i = tok.item() 465 | w = self.vocab['i2w'].get(i, UNK_TOKEN) 466 | if w == PAD_TOKEN: 467 | break 468 | text.append(w) 469 | texts.append(text) 470 | 471 | return texts 472 | 473 | 474 | def extract_features(hints): 475 | """ 476 | Extract features from hints 477 | """ 478 | all_feats = [] 479 | for hint in hints: 480 | feats = [] 481 | for maybe_rel in ['above', 'below', 'left', 'right']: 482 | if maybe_rel in hint: 483 | rel = maybe_rel 484 | rel_idx = hint.index(rel) 485 | break 486 | else: 487 | raise RuntimeError("Didn't find relation: {}".format(hint)) 488 | # Add relation 489 | feats.append('rel:{}'.format(rel)) 490 | fst, snd = hint[:rel_idx], hint[rel_idx:] 491 | # fst: [, a, ..., is] 492 | fst_shape = fst[2:fst.index('is')] 493 | # snd: [..., a, ..., ., ] 494 | try: 495 | snd_shape = snd[snd.index('a') + 1:-2] 496 | except ValueError: 497 | # Use "an" 498 | snd_shape = snd[snd.index('an') + 1:-2] 499 | 500 | for name, fragment in [('fst', fst_shape), ('snd', snd_shape)]: 501 | for feat in fragment: 502 | if feat != 'shape': 503 | if feat in COLORS: 504 | feats.append('{}:color:{}'.format(name, feat)) 505 | else: 506 | assert feat in SHAPES, hint 507 | feats.append('{}:shape:{}'.format(name, feat)) 508 | all_feats.append(feats) 509 | return all_feats 510 | 
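if __name__ == '__main__':
    # Minimal sanity-check sketch of extract_features. The caption below is
    # hypothetical, but follows the tokenization assumed by the parsing code
    # above (SOS token, article, optional color, shape, 'is', relation, ...).
    example_hint = ['<sos>', 'a', 'red', 'square', 'is', 'above', 'a', 'blue',
                    'circle', '.', '<eos>']
    print(extract_features([example_hint]))
    # Expected output:
    # [['rel:above', 'fst:color:red', 'fst:shape:square',
    #   'snd:color:blue', 'snd:shape:circle']]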
-------------------------------------------------------------------------------- /shapeworld/lsl/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models 3 | """ 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | import torch.nn.utils.rnn as rnn_utils 11 | 12 | 13 | class ExWrapper(nn.Module): 14 | """ 15 | Wrap around a model and allow training on examples 16 | i.e. tensor inputs of shape 17 | (batch_size, n_ex, *img_dims) 18 | """ 19 | 20 | def __init__(self, model): 21 | super(ExWrapper, self).__init__() 22 | self.model = model 23 | 24 | def forward(self, x): 25 | batch_size = x.shape[0] 26 | if len(x.shape) == 5: 27 | n_ex = x.shape[1] 28 | img_dim = x.shape[2:] 29 | # Flatten out examples first 30 | x_flat = x.view(batch_size * n_ex, *img_dim) 31 | else: 32 | x_flat = x 33 | 34 | x_enc = self.model(x_flat) 35 | 36 | if len(x.shape) == 5: 37 | x_enc = x_enc.view(batch_size, n_ex, -1) 38 | 39 | return x_enc 40 | 41 | 42 | class Identity(nn.Module): 43 | def forward(self, x): 44 | return x 45 | 46 | 47 | class ImageRep(nn.Module): 48 | r"""Two fully-connected layers to form a final image 49 | representation. 50 | 51 | VGG-16 -> FC -> ReLU -> FC 52 | 53 | Paper uses 512 hidden dimension. 54 | """ 55 | 56 | def __init__(self, backbone=None, hidden_size=512): 57 | super(ImageRep, self).__init__() 58 | if backbone is None: 59 | self.backbone = Identity() 60 | self.backbone.final_feat_dim = 4608 61 | else: 62 | self.backbone = backbone 63 | self.model = nn.Sequential( 64 | nn.Linear(self.backbone.final_feat_dim, hidden_size), nn.ReLU(), 65 | nn.Linear(hidden_size, hidden_size)) 66 | 67 | def forward(self, x): 68 | x_enc = self.backbone(x) 69 | return self.model(x_enc) 70 | 71 | 72 | class TextRep(nn.Module): 73 | r"""Deterministic Bowman et. al. model to form 74 | text representation. 75 | 76 | Again, this uses 512 hidden dimensions. 77 | """ 78 | 79 | def __init__(self, embedding_module): 80 | super(TextRep, self).__init__() 81 | self.embedding = embedding_module 82 | self.embedding_dim = embedding_module.embedding_dim 83 | self.gru = nn.GRU(self.embedding_dim, 512) 84 | 85 | def forward(self, seq, length): 86 | batch_size = seq.size(0) 87 | 88 | if batch_size > 1: 89 | sorted_lengths, sorted_idx = torch.sort(length, descending=True) 90 | seq = seq[sorted_idx] 91 | 92 | # reorder from (B,L,D) to (L,B,D) 93 | seq = seq.transpose(0, 1) 94 | 95 | # embed your sequences 96 | embed_seq = self.embedding(seq) 97 | 98 | packed = rnn_utils.pack_padded_sequence( 99 | embed_seq, 100 | sorted_lengths.data.cpu().tolist() 101 | if batch_size > 1 else length.data.tolist()) 102 | 103 | _, hidden = self.gru(packed) 104 | hidden = hidden[-1, ...] 
105 | 106 | if batch_size > 1: 107 | _, reversed_idx = torch.sort(sorted_idx) 108 | hidden = hidden[reversed_idx] 109 | 110 | return hidden 111 | 112 | 113 | class MultimodalDeepRep(nn.Module): 114 | def __init__(self): 115 | super(MultimodalDeepRep, self).__init__() 116 | self.model = nn.Sequential(nn.Linear(512 * 2, 512 * 2), nn.ReLU(), 117 | nn.Linear(512 * 2, 512), nn.ReLU(), 118 | nn.Linear(512, 512)) 119 | 120 | def forward(self, x, y): 121 | xy = torch.cat([x, y], dim=1) 122 | return self.model(xy) 123 | 124 | 125 | class MultimodalRep(nn.Module): 126 | r"""Concat Image and Text representations.""" 127 | 128 | def __init__(self): 129 | super(MultimodalRep, self).__init__() 130 | self.model = nn.Sequential(nn.Linear(512 * 2, 512), nn.ReLU(), 131 | nn.Linear(512, 512)) 132 | 133 | def forward(self, x, y): 134 | xy = torch.cat([x, y], dim=1) 135 | return self.model(xy) 136 | 137 | 138 | class MultimodalSumExp(nn.Module): 139 | def forward(self, x, y): 140 | return x + y 141 | 142 | 143 | class MultimodalLinearRep(nn.Module): 144 | def __init__(self): 145 | super(MultimodalLinearRep, self).__init__() 146 | self.model = nn.Linear(512 * 2, 512) 147 | 148 | def forward(self, x, y): 149 | xy = torch.cat([x, y], dim=1) 150 | return self.model(xy) 151 | 152 | 153 | class MultimodalWeightedRep(nn.Module): 154 | def __init__(self): 155 | super(MultimodalWeightedRep, self).__init__() 156 | self.model = nn.Sequential(nn.Linear(512 * 2, 512), nn.ReLU(), 157 | nn.Linear(512, 1), nn.Sigmoid()) 158 | 159 | def forward(self, x, y): 160 | xy = torch.cat([x, y], dim=1) 161 | w = self.model(xy) 162 | out = w * x + (1. - w) * y 163 | return out 164 | 165 | 166 | class MultimodalSingleWeightRep(nn.Module): 167 | def __init__(self): 168 | super(MultimodalSingleWeightRep, self).__init__() 169 | self.w = nn.Parameter(torch.normal(torch.zeros(1), 1)) 170 | 171 | def forward(self, x, y): 172 | w = torch.sigmoid(self.w) 173 | out = w * x + (1. - w) * y 174 | return out 175 | 176 | 177 | class TextProposal(nn.Module): 178 | r"""Reverse proposal model, estimating: 179 | 180 | argmax_lambda log q(w_i|x_1, y_1, ..., x_n, y_n; lambda) 181 | 182 | approximation to the distribution of descriptions. 183 | 184 | Because they use only positive labels, it actually simplifies to 185 | 186 | argmax_lambda log q(w_i|x_1, ..., x_4; lambda) 187 | 188 | https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/model.py 189 | """ 190 | 191 | def __init__(self, embedding_module): 192 | super(TextProposal, self).__init__() 193 | self.embedding = embedding_module 194 | self.embedding_dim = embedding_module.embedding_dim 195 | self.vocab_size = embedding_module.num_embeddings 196 | self.gru = nn.GRU(self.embedding_dim, 512) 197 | self.outputs2vocab = nn.Linear(512, self.vocab_size) 198 | 199 | def forward(self, feats, seq, length): 200 | # feats is from example images 201 | batch_size = seq.size(0) 202 | 203 | if batch_size > 1: 204 | # BUGFIX? dont we need to sort feats too? 
205 | sorted_lengths, sorted_idx = torch.sort(length, descending=True) 206 | seq = seq[sorted_idx] 207 | feats = feats[sorted_idx] 208 | 209 | feats = feats.unsqueeze(0) 210 | # reorder from (B,L,D) to (L,B,D) 211 | seq = seq.transpose(0, 1) 212 | 213 | # embed your sequences 214 | embed_seq = self.embedding(seq) 215 | 216 | packed_input = rnn_utils.pack_padded_sequence(embed_seq, 217 | sorted_lengths.cpu()) 218 | 219 | # shape = (seq_len, batch, hidden_dim) 220 | packed_output, _ = self.gru(packed_input, feats) 221 | output = rnn_utils.pad_packed_sequence(packed_output) 222 | output = output[0].contiguous() 223 | 224 | # reorder from (L,B,D) to (B,L,D) 225 | output = output.transpose(0, 1) 226 | 227 | if batch_size > 1: 228 | _, reversed_idx = torch.sort(sorted_idx) 229 | output = output[reversed_idx] 230 | 231 | max_length = output.size(1) 232 | output_2d = output.view(batch_size * max_length, 512) 233 | outputs_2d = self.outputs2vocab(output_2d) 234 | outputs = outputs_2d.view(batch_size, max_length, self.vocab_size) 235 | 236 | return outputs 237 | 238 | def sample(self, feats, sos_index, eos_index, pad_index, greedy=False): 239 | """Generate from image features using greedy search.""" 240 | with torch.no_grad(): 241 | batch_size = feats.size(0) 242 | 243 | # initialize hidden states using image features 244 | states = feats.unsqueeze(0) 245 | 246 | # first input is SOS token 247 | inputs = np.array([sos_index for _ in range(batch_size)]) 248 | inputs = torch.from_numpy(inputs) 249 | inputs = inputs.unsqueeze(1) 250 | inputs = inputs.to(feats.device) 251 | 252 | # save SOS as first generated token 253 | inputs_npy = inputs.squeeze(1).cpu().numpy() 254 | sampled_ids = [[w] for w in inputs_npy] 255 | 256 | # (B,L,D) to (L,B,D) 257 | inputs = inputs.transpose(0, 1) 258 | 259 | # compute embeddings 260 | inputs = self.embedding(inputs) 261 | 262 | for i in range(20): # like in jacobs repo 263 | outputs, states = self.gru(inputs, 264 | states) # outputs: (L=1,B,H) 265 | outputs = outputs.squeeze(0) # outputs: (B,H) 266 | outputs = self.outputs2vocab(outputs) # outputs: (B,V) 267 | 268 | if greedy: 269 | predicted = outputs.max(1)[1] 270 | predicted = predicted.unsqueeze(1) 271 | else: 272 | outputs = F.softmax(outputs, dim=1) 273 | predicted = torch.multinomial(outputs, 1) 274 | 275 | predicted_npy = predicted.squeeze(1).cpu().numpy() 276 | predicted_lst = predicted_npy.tolist() 277 | 278 | for w, so_far in zip(predicted_lst, sampled_ids): 279 | if so_far[-1] != eos_index: 280 | so_far.append(w) 281 | 282 | inputs = predicted.transpose(0, 1) # inputs: (L=1,B) 283 | inputs = self.embedding(inputs) # inputs: (L=1,B,E) 284 | 285 | sampled_lengths = [len(text) for text in sampled_ids] 286 | sampled_lengths = np.array(sampled_lengths) 287 | 288 | max_length = max(sampled_lengths) 289 | padded_ids = np.ones((batch_size, max_length)) * pad_index 290 | 291 | for i in range(batch_size): 292 | padded_ids[i, :sampled_lengths[i]] = sampled_ids[i] 293 | 294 | sampled_lengths = torch.from_numpy(sampled_lengths).long() 295 | sampled_ids = torch.from_numpy(padded_ids).long() 296 | 297 | return sampled_ids, sampled_lengths 298 | 299 | 300 | class EmbedImageRep(nn.Module): 301 | def __init__(self, z_dim): 302 | super(EmbedImageRep, self).__init__() 303 | self.z_dim = z_dim 304 | self.model = nn.Sequential(nn.Linear(self.z_dim, 512), nn.ReLU(), 305 | nn.Linear(512, 512)) 306 | 307 | def forward(self, x): 308 | return self.model(x) 309 | 310 | 311 | class EmbedTextRep(nn.Module): 312 | def __init__(self, 
z_dim): 313 | super(EmbedTextRep, self).__init__() 314 | self.z_dim = z_dim 315 | self.model = nn.Sequential(nn.Linear(self.z_dim, 512), nn.ReLU(), 316 | nn.Linear(512, 512)) 317 | 318 | def forward(self, x): 319 | return self.model(x) 320 | 321 | 322 | class Scorer(nn.Module): 323 | def __init__(self): 324 | super(Scorer, self).__init__() 325 | 326 | def forward(self, x, y): 327 | raise NotImplementedError 328 | 329 | def score(self, x, y): 330 | raise NotImplementedError 331 | 332 | def batchwise_score(self, x, y): 333 | raise NotImplementedError 334 | 335 | 336 | class DotPScorer(Scorer): 337 | def __init__(self): 338 | super(DotPScorer, self).__init__() 339 | 340 | def score(self, x, y): 341 | return torch.sum(x * y, dim=1) 342 | 343 | def batchwise_score(self, y, x): 344 | # REVERSED 345 | bw_scores = torch.einsum('ijk,ik->ij', (x, y)) 346 | return torch.sum(bw_scores, dim=1) 347 | 348 | 349 | class BilinearScorer(DotPScorer): 350 | def __init__(self, hidden_size, dropout=0.0, identity_debug=False): 351 | super(BilinearScorer, self).__init__() 352 | self.bilinear = nn.Linear(hidden_size, hidden_size, bias=False) 353 | self.dropout_p = dropout 354 | if self.dropout_p > 0.0: 355 | self.dropout = nn.Dropout(p=self.dropout_p) 356 | else: 357 | self.dropout = lambda x: x 358 | if identity_debug: 359 | # Set this as identity matrix to make sure we get the same output 360 | # as DotPScorer 361 | self.bilinear.weight = nn.Parameter( 362 | torch.eye(hidden_size, dtype=torch.float32)) 363 | self.bilinear.weight.requires_grad = False 364 | 365 | def score(self, x, y): 366 | wy = self.bilinear(y) 367 | wy = self.dropout(wy) 368 | return super(BilinearScorer, self).score(x, wy) 369 | 370 | def batchwise_score(self, x, y): 371 | """ 372 | x: (batch_size, h) 373 | y: (batch_size, n_examples, h) 374 | """ 375 | batch_size, n_examples, h = y.shape 376 | wy = self.bilinear(y.view(batch_size * n_examples, 377 | -1)).unsqueeze(1).view_as(y) 378 | wy = self.dropout(wy) 379 | # wy: (batch_size, n_examples, h) 380 | return super(BilinearScorer, self).batchwise_score(x, wy) 381 | -------------------------------------------------------------------------------- /shapeworld/lsl/tre.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import optim 4 | from tqdm import trange 5 | from torch.nn.modules.distance import CosineSimilarity 6 | 7 | 8 | def flatten(l): 9 | if not isinstance(l, tuple): 10 | return (l, ) 11 | 12 | out = () 13 | for ll in l: 14 | out = out + flatten(ll) 15 | return out 16 | 17 | 18 | class L1Dist(nn.Module): 19 | def forward(self, pred, target): 20 | return torch.norm(pred - target, p=1, dim=1) 21 | 22 | 23 | class L2Dist(nn.Module): 24 | def forward(self, pred, target): 25 | return torch.norm(pred - target, p=2, dim=1) 26 | 27 | 28 | class CosDist(nn.Module): 29 | def __init__(self): 30 | super().__init__() 31 | self.cossim = CosineSimilarity() 32 | 33 | def forward(self, x, y): 34 | return 1 - self.cossim(x, y) 35 | 36 | 37 | class AddComp(nn.Module): 38 | def forward(self, embs, embs_mask): 39 | """ 40 | embs: (batch_size, max_feats, h) 41 | embs_mask: (batch_size, max_feats) 42 | """ 43 | embs_mask_exp = embs_mask.float().unsqueeze(2).expand_as(embs) 44 | embs_zeroed = embs * embs_mask_exp 45 | composed = embs_zeroed.sum(1) 46 | return composed 47 | 48 | 49 | class MulComp(nn.Module): 50 | def forward(self, embs, embs_mask): 51 | """ 52 | embs: (batch_size, max_feats, h) 53 | embs_mask: (batch_size, 
max_feats) 54 | """ 55 | raise NotImplementedError 56 | 57 | 58 | class Objective(nn.Module): 59 | def __init__(self, vocab, repr_size, comp_fn, err_fn, zero_init): 60 | super().__init__() 61 | self.emb = nn.Embedding(len(vocab), repr_size) 62 | if zero_init: 63 | self.emb.weight.data.zero_() 64 | self.comp = comp_fn 65 | self.err = err_fn 66 | 67 | def compose(self, feats, feats_mask): 68 | """ 69 | Input: 70 | batch_size, max_feats 71 | Output: 72 | batch_size, h 73 | """ 74 | embs = self.emb(feats) 75 | # Compose embeddings 76 | composed = self.comp(embs, feats_mask) 77 | return composed 78 | 79 | def forward(self, rep, feats, feats_mask): 80 | return self.err(self.compose(feats, feats_mask), rep) 81 | 82 | 83 | def tre(reps, 84 | feats, 85 | feats_mask, 86 | vocab, 87 | comp_fn, 88 | err_fn, 89 | quiet=False, 90 | steps=400, 91 | include_pred=False, 92 | zero_init=True): 93 | 94 | obj = Objective(vocab, reps.shape[1], comp_fn, err_fn, zero_init) 95 | obj = obj.to(reps.device) 96 | opt = optim.Adam(obj.parameters(), lr=0.001) 97 | 98 | if not quiet: 99 | ranger = trange(steps, desc='TRE') 100 | else: 101 | ranger = range(steps) 102 | for t in ranger: 103 | opt.zero_grad() 104 | loss = obj(reps, feats, feats_mask) 105 | total_loss = loss.sum() 106 | total_loss.backward() 107 | if not quiet and t % 100 == 0: 108 | print(total_loss.item()) 109 | opt.step() 110 | 111 | final_losses = [l.item() for l in loss] 112 | if include_pred: 113 | lexicon = { 114 | k: obj.emb(torch.LongTensor([v])).data.cpu().numpy() 115 | for k, v in vocab.items() 116 | } 117 | composed = [obj.compose(f, fm) for f, fm in zip(feats, feats_mask)] 118 | return final_losses, lexicon, composed 119 | else: 120 | return final_losses 121 | -------------------------------------------------------------------------------- /shapeworld/lsl/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities 3 | """ 4 | 5 | from collections import Counter, OrderedDict 6 | import json 7 | import os 8 | import shutil 9 | 10 | import numpy as np 11 | import torch 12 | 13 | random_counter = [0] 14 | 15 | 16 | def next_random(): 17 | random = np.random.RandomState(random_counter[0]) 18 | random_counter[0] += 1 19 | return random 20 | 21 | 22 | class OrderedCounter(Counter, OrderedDict): 23 | """Counter that remembers the order elements are first encountered""" 24 | 25 | def __repr__(self): 26 | return '%s(%r)' % (self.__class__.__name__, OrderedDict(self)) 27 | 28 | def __reduce__(self): 29 | return self.__class__, (OrderedDict(self), ) 30 | 31 | 32 | class AverageMeter(object): 33 | """Computes and stores the average and current value""" 34 | 35 | def __init__(self, raw=False): 36 | self.raw = raw 37 | self.reset() 38 | 39 | def reset(self): 40 | self.val = 0 41 | self.avg = 0 42 | self.sum = 0 43 | self.count = 0 44 | if self.raw: 45 | self.raw_scores = [] 46 | 47 | def update(self, val, n=1, raw_scores=None): 48 | self.val = val 49 | self.sum += val * n 50 | self.count += n 51 | self.avg = self.sum / self.count 52 | if self.raw: 53 | self.raw_scores.extend(list(raw_scores)) 54 | 55 | 56 | def save_checkpoint(state, is_best, folder='./', 57 | filename='checkpoint.pth.tar'): 58 | if not os.path.isdir(folder): 59 | os.mkdir(folder) 60 | torch.save(state, os.path.join(folder, filename)) 61 | if is_best: 62 | shutil.copyfile(os.path.join(folder, filename), 63 | os.path.join(folder, 'model_best.pth.tar')) 64 | 65 | 66 | def merge_args_with_dict(args, dic): 67 | for k, v in list(dic.items()): 68 
| setattr(args, k, v) 69 | 70 | 71 | def make_output_and_sample_dir(out_dir): 72 | if not os.path.exists(out_dir): 73 | os.makedirs(out_dir) 74 | 75 | sample_dir = os.path.join(out_dir, 'samples') 76 | if not os.path.exists(sample_dir): 77 | os.makedirs(sample_dir) 78 | 79 | return out_dir, sample_dir 80 | 81 | 82 | def save_defaultdict_to_fs(d, out_path): 83 | d = dict(d) 84 | with open(out_path, 'w') as fp: 85 | d_str = json.dumps(d, ensure_ascii=True) 86 | fp.write(d_str) 87 | 88 | 89 | def idx2word(idx, i2w): 90 | sent_str = [str()] * len(idx) 91 | for i, sent in enumerate(idx): 92 | for word_id in sent: 93 | sent_str[i] += str(i2w[word_id.item()]) + " " 94 | sent_str[i] = sent_str[i].strip() 95 | 96 | return sent_str 97 | -------------------------------------------------------------------------------- /shapeworld/run_l3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python lsl/train.py --cuda \ 4 | --infer_hyp \ 5 | --hypo_lambda 1.0 \ 6 | --batch_size 100 \ 7 | --seed $RANDOM \ 8 | exp/l3 9 | -------------------------------------------------------------------------------- /shapeworld/run_lang_ablation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | python lsl/train.py --cuda \ 6 | --predict_concept_hyp \ 7 | --hypo_lambda 20.0 \ 8 | --seed "$RANDOM" \ 9 | --batch_size 100 \ 10 | --language_filter color \ 11 | exp/lsl_color 12 | 13 | python lsl/train.py --cuda \ 14 | --predict_concept_hyp \ 15 | --hypo_lambda 20.0 \ 16 | --seed "$RANDOM" \ 17 | --batch_size 100 \ 18 | --language_filter nocolor \ 19 | exp/lsl_nocolor 20 | 21 | python lsl/train.py --cuda \ 22 | --predict_concept_hyp \ 23 | --hypo_lambda 20.0 \ 24 | --seed "$RANDOM" \ 25 | --batch_size 100 \ 26 | --shuffle_words \ 27 | exp/lsl_shuffle_words 28 | 29 | python lsl/train.py --cuda \ 30 | --predict_concept_hyp \ 31 | --hypo_lambda 20.0 \ 32 | --seed "$RANDOM" \ 33 | --batch_size 100 \ 34 | --shuffle_captions \ 35 | exp/lsl_shuffle_captions 36 | -------------------------------------------------------------------------------- /shapeworld/run_lsl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HYPO_LAMBDA=20 4 | 5 | python lsl/train.py --cuda \ 6 | --predict_concept_hyp \ 7 | --hypo_lambda $HYPO_LAMBDA \ 8 | --batch_size 100 \ 9 | --seed $RANDOM \ 10 | exp/lsl 11 | -------------------------------------------------------------------------------- /shapeworld/run_lsl_img.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HYPO_LAMBDA=20 4 | 5 | python lsl/train.py --cuda \ 6 | --predict_concept_hyp \ 7 | --hypo_lambda $HYPO_LAMBDA \ 8 | --batch_size 100 \ 9 | --seed $RANDOM \ 10 | exp/lsl_img 11 | -------------------------------------------------------------------------------- /shapeworld/run_meta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python lsl/train.py --cuda \ 4 | --batch_size 100 \ 5 | --seed $RANDOM \ 6 | exp/meta 7 | --------------------------------------------------------------------------------
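The run scripts above all invoke `lsl/train.py`, which is not reproduced in this listing. A minimal sketch of how the pieces shown here (`ShapeWorld` from `lsl/datasets.py`; `ExWrapper`, `ImageRep`, and `DotPScorer` from `lsl/models.py`) fit together for one meta-learning batch might look like the following. The 100-example batch size, the 4 support images per concept, and the 4608-dimensional precomputed VGG features come from the files above; the variable names and the exact loss wiring are illustrative assumptions, not the actual contents of `train.py`.

```python
# Sketch only: run from shapeworld/lsl/ with the ShapeWorld data available.
import torch.nn as nn
import torch.nn.functional as F

from datasets import ShapeWorld
from models import ExWrapper, ImageRep, DotPScorer

# Validation/test splits must reuse the training vocabulary
# (see the ShapeWorld docstring).
train_set = ShapeWorld(split='train', augment=True, precomputed_features=True)
val_set = ShapeWorld(split='val', vocab=train_set.vocab,
                     precomputed_features=True)

# Precomputed VGG features are 4608-d; ImageRep maps them to 512-d.
image_rep = ExWrapper(ImageRep(backbone=None, hidden_size=512))
scorer = DotPScorer()

# One training batch: 4 support examples plus 1 query image per concept.
examples, image, label, *_ = train_set.sample_train(n_batch=100)

proto = image_rep(examples).mean(dim=1)   # (100, 512) prototype over supports
query = image_rep(image)                  # (100, 512) query embedding
score = scorer.score(proto, query)        # (100,) dot-product match score

# Hypothetical loss wiring: treat the score as a logit for the binary label.
loss = F.binary_cross_entropy_with_logits(score, label.float())
loss.backward()
```

The LSL runs (`run_lsl.sh`) presumably add an auxiliary term, weighted by `hypo_lambda`, that decodes the hint from the support representation using `TextProposal`, while `run_l3.sh` instead infers hypotheses at test time via `--infer_hyp`.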