├── .gitignore
├── README.md
├── birds
│   ├── .gitignore
│   ├── LICENSE.txt
│   ├── README.md
│   ├── custom_filelists
│   │   └── CUB
│   │       ├── base.json
│   │       ├── novel.json
│   │       └── val.json
│   ├── exp
│   │   └── README.md
│   ├── fewshot
│   │   ├── backbone.py
│   │   ├── constants.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── additional_transforms.py
│   │   │   ├── datamgr.py
│   │   │   ├── dataset.py
│   │   │   └── lang_utils.py
│   │   ├── io_utils.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── language.py
│   │   │   └── protonet.py
│   │   ├── run_cl.py
│   │   ├── test.py
│   │   └── train.py
│   ├── filelists
│   │   └── CUB
│   │       ├── download_CUB.sh
│   │       ├── save_np.py
│   │       └── write_CUB_filelist.py
│   ├── run_l3.sh
│   ├── run_lang_ablation.sh
│   ├── run_lang_amount.sh
│   ├── run_lsl.sh
│   └── run_meta.sh
└── shapeworld
    ├── .gitignore
    ├── LICENSE
    ├── README.md
    ├── analysis
    │   ├── analysis.Rproj
    │   └── metrics.Rmd
    ├── exp
    │   ├── README.md
    │   ├── l3
    │   │   ├── args.json
    │   │   └── metrics.json
    │   ├── lsl
    │   │   ├── args.json
    │   │   └── metrics.json
    │   ├── lsl_color
    │   │   ├── args.json
    │   │   └── metrics.json
    │   ├── lsl_nocolor
    │   │   ├── args.json
    │   │   └── metrics.json
    │   ├── lsl_shuffle_captions
    │   │   ├── args.json
    │   │   └── metrics.json
    │   ├── lsl_shuffle_words
    │   │   ├── args.json
    │   │   └── metrics.json
    │   └── meta
    │       ├── args.json
    │       └── metrics.json
    ├── lsl
    │   ├── datasets.py
    │   ├── models.py
    │   ├── train.py
    │   ├── tre.py
    │   ├── utils.py
    │   └── vision.py
    ├── run_l3.sh
    ├── run_lang_ablation.sh
    ├── run_lsl.sh
    ├── run_lsl_img.sh
    └── run_meta.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .Rproj.user
2 | .Rhistory
3 | .RData
4 | .Ruserdata
5 |
6 | sync_results.sh
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Shaping Visual Representations with Language for Few-shot Classification
2 |
3 | Code and data for
4 |
5 | > Jesse Mu, Percy Liang, and Noah Goodman. Shaping Visual Representations with Language for Few-shot Classification. ACL 2020. https://arxiv.org/abs/1911.02683
6 |
7 | In addition, a CodaLab executable paper (Docker containers with code, data, and experiment runs) is available [here](https://bit.ly/lsl_acl20). There are some minor fixes for CodaLab compatibility on the `codalab` branch.
8 |
9 | The codebase is split into two repositories, `shapeworld` and `birds`, for the
10 | different tasks explored in this paper. Each has its own README,
11 | instructions, and license, since they were extended from different existing
12 | codebases.
13 |
14 | If you find this code useful, please cite:
15 |
16 | ```
17 | @inproceedings{mu2020shaping,
18 | author = {Jesse Mu and Percy Liang and Noah Goodman},
19 | title = {Shaping Visual Representations with Language for Few-Shot Classification},
20 | booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
21 | year = {2020}
22 | }
23 | ```
24 |
--------------------------------------------------------------------------------
/birds/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__
3 |
4 | saves/*
5 | !saves/README.md
6 |
7 | filelists/CUB/*
8 | !filelists/CUB/download_CUB.sh
9 | !filelists/CUB/*.py
10 |
11 | filelists/scenes/*
12 | !filelists/scenes/download_scenes.sh
13 | !filelists/scenes/*.py
14 |
15 | # Ignore codalab
16 | /checkpoints/
17 | /features/
18 | /args.json
19 | /results.json
20 | /reed-birds
21 | .Rproj.user
22 |
23 | .Rhistory
24 |
25 | *.out
26 |
27 | exp/*
28 | exp/*/*/checkpoints/*.tar
29 | exp/*/*/features/*.hdf5
30 | !exp/README.md
31 |
32 | test/*
33 |
34 | *.RData
35 |
36 | analysis/*.html
37 |
--------------------------------------------------------------------------------
/birds/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
--------------------------------------------------------------------------------
/birds/README.md:
--------------------------------------------------------------------------------
1 | # LSL - Birds
2 |
3 | This codebase is built off of [wyharveychen/CloserLookFewShot](https://github.com/wyharveychen/CloserLookFewShot) ([paper](https://openreview.net/pdf?id=HkxLXnAcFQ)) - thanks to them!
4 |
5 | ## Dependencies
6 |
7 | Tested with Python 3.7.4, torch 1.4.0, torchvision 0.4.1, numpy 1.16.2, PIL
8 | 5.4.1, torchfile 0.1.0, sklearn 0.20.3, pandas 0.25.2
9 |
10 | GloVe initialization depends on spacy 2.2.2 and the spacy `en_vectors_web_lg`
11 | model:
12 |
13 | ```
14 | python -m spacy download en_vectors_web_lg
15 | ```
16 |
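As a rough illustration (this helper is not part of the codebase), the GloVe vectors can be pulled out of the spacy model to build a pretrained embedding matrix along these lines:

```python
import spacy
import torch
import torch.nn as nn

nlp = spacy.load("en_vectors_web_lg")  # 300-d GloVe vectors

def glove_embedding(words, dim=300):
    # Hypothetical helper: words without a pretrained vector keep a zero row.
    weights = torch.zeros(len(words), dim)
    for i, w in enumerate(words):
        lex = nlp.vocab[w]
        if lex.has_vector:
            weights[i] = torch.tensor(lex.vector)
    return nn.Embedding.from_pretrained(weights, freeze=False)
```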
17 | ## Data
18 |
19 | To download data, cd to `filelists/CUB` and run `source download_CUB.sh`. This
20 | downloads the CUB 200-2011 dataset and also runs `python write_CUB_filelist.py`.
21 |
22 | `write_CUB_filelist.py` saves the train/val/test filelist splits
23 | to `./custom_filelists/CUB/{base,val,novel}.json`.
24 |
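Each filelist is plain JSON with the keys consumed by `fewshot/data/dataset.py`; a quick sanity check:

```python
import json

with open("custom_filelists/CUB/base.json") as f:
    meta = json.load(f)

# "label_names": class names, "image_names": image paths,
# "image_labels": integer class ids aligned with "image_names".
print(len(meta["label_names"]), len(meta["image_names"]), len(meta["image_labels"]))
```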
25 | Then run `python save_np.py`, which serializes the images as NumPy arrays
26 | (for faster loading).
27 |
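The training code (`fewshot/data/dataset.py`) expects a per-class `img.npz` archive keyed by image path; for example (the class directory below is just an illustration):

```python
import numpy as np

# Illustrative path; any class directory listed in the filelist works the same way.
archive = dict(np.load("CUB_200_2011/images/001.Black_footed_Albatross/img.npz"))
path, img = next(iter(archive.items()))
print(path, img.shape)
```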
28 | The language data is available from
29 | [reedscot/cvpr2016](https://github.com/reedscot/cvpr2016) ([GDrive link](https://drive.google.com/open?id=0B0ywwgffWnLLZW9uVHNjb2JmNlE)). Download it and unzip it to a `reed-birds` directory here (e.g. the path to the vocab file should be `./reed-birds/vocab_c10.t7`).
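These are Torch7 files; from Python they can be read with the `torchfile` package (the same loader `fewshot/data/dataset.py` uses for the per-class caption tensors), e.g.:

```python
import torchfile

# Illustrative only: the vocab file is expected to load as a token table.
vocab = torchfile.load("./reed-birds/vocab_c10.t7")
print(type(vocab))
```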
30 |
31 | ## Running
32 |
33 | To train and evaluate a model, run `fewshot/train.py` and `fewshot/test.py`,
34 | respectively. Alternatively, for CodaLab, the `fewshot/run_cl.py` script does
35 | both training and testing, with slightly friendlier argument names
36 | (see `fewshot/run_cl.py --help` for details).
37 |
38 | The shell scripts contain commands for running the various models:
39 |
40 | - `run_meta.sh`: Non-linguistic protonet baseline
41 | - `run_l3.sh`: Learning with latent language (L3; Andreas et al., 2018)
42 | - `run_lsl.sh`: Ours (LSL)
43 | - `run_lang_ablation.sh`: Language ablation studies
44 | - `run_lang_amount.sh`: Varying the amount of language
45 |
46 | ## References
47 |
48 | (from the original CloserLookFewShot repo)
49 |
50 | Our testbed builds upon several existing publicly available codebases. Specifically, we have modified and integrated the following code into this project:
51 |
52 | * Framework, Backbone, Method: Matching Network
53 | https://github.com/facebookresearch/low-shot-shrink-hallucinate
54 | * Omniglot dataset, Method: Prototypical Network
55 | https://github.com/jakesnell/prototypical-networks
56 | * Method: Relational Network
57 | https://github.com/floodsung/LearningToCompare_FSL
58 | * Method: MAML
59 | https://github.com/cbfinn/maml
60 | https://github.com/dragen1860/MAML-Pytorch
61 | https://github.com/katerakelly/pytorch-maml
62 |
--------------------------------------------------------------------------------
/birds/exp/README.md:
--------------------------------------------------------------------------------
1 | # Exp
2 |
3 | Placeholder for model experiments.
4 |
5 | Use `fewshot/run_cl.py` with `--log_dir` pointing to a directory within this
6 | folder.
7 |
--------------------------------------------------------------------------------
/birds/fewshot/backbone.py:
--------------------------------------------------------------------------------
1 | """
2 | Backbone vision models.
3 |
4 | This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate
5 | """
6 |
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torchvision.models as models
13 |
14 | CONV4_HIDDEN_SIZES = [112896, 28224, 6400, 1600]  # flattened per-block feature sizes of Conv4 for 84x84 inputs
15 |
16 |
17 | class Identity(nn.Module):
18 | def __init__(self):
19 | super(Identity, self).__init__()
20 |
21 | def forward(self, x):
22 | return x
23 |
24 |
25 | # Basic ResNet model
26 | def init_layer(L):
27 | # Initialization using fan-in
28 | if isinstance(L, nn.Conv2d):
29 | n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
30 | L.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
31 | elif isinstance(L, nn.BatchNorm2d):
32 | L.weight.data.fill_(1)
33 | L.bias.data.fill_(0)
34 |
35 |
36 | class Flatten(nn.Module):
37 | def __init__(self):
38 | super(Flatten, self).__init__()
39 |
40 | def forward(self, x):
41 | return x.view(x.size(0), -1)
42 |
43 |
44 | class Linear_fw(nn.Linear): # used in MAML to forward input with fast weight
45 | def __init__(self, in_features, out_features):
46 | super(Linear_fw, self).__init__(in_features, out_features)
47 | self.weight.fast = None # Lazy hack to add fast weight link
48 | self.bias.fast = None
49 |
50 | def forward(self, x):
51 | if self.weight.fast is not None and self.bias.fast is not None:
52 | out = F.linear(x, self.weight.fast, self.bias.fast)
53 | else:
54 | out = super(Linear_fw, self).forward(x)
55 | return out
56 |
57 |
58 | class Conv2d_fw(nn.Conv2d): # used in MAML to forward input with fast weight
59 | def __init__(
60 | self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True
61 | ):
62 | super(Conv2d_fw, self).__init__(
63 | in_channels,
64 | out_channels,
65 | kernel_size,
66 | stride=stride,
67 | padding=padding,
68 | bias=bias,
69 | )
70 | self.weight.fast = None
71 | if self.bias is not None:
72 | self.bias.fast = None
73 |
74 | def forward(self, x):
75 | if self.bias is None:
76 | if self.weight.fast is not None:
77 | out = F.conv2d(
78 | x, self.weight.fast, None, stride=self.stride, padding=self.padding
79 | )
80 | else:
81 | out = super(Conv2d_fw, self).forward(x)
82 | else:
83 | if self.weight.fast is not None and self.bias.fast is not None:
84 | out = F.conv2d(
85 | x,
86 | self.weight.fast,
87 | self.bias.fast,
88 | stride=self.stride,
89 | padding=self.padding,
90 | )
91 | else:
92 | out = super(Conv2d_fw, self).forward(x)
93 |
94 | return out
95 |
96 |
97 | class BatchNorm2d_fw(nn.BatchNorm2d): # used in MAML to forward input with fast weight
98 | def __init__(self, num_features):
99 | super(BatchNorm2d_fw, self).__init__(num_features)
100 | self.weight.fast = None
101 | self.bias.fast = None
102 |
103 | def forward(self, x):
104 | running_mean = torch.zeros(x.data.size()[1]).cuda()
105 | running_var = torch.ones(x.data.size()[1]).cuda()
106 | if self.weight.fast is not None and self.bias.fast is not None:
107 | out = F.batch_norm(
108 | x,
109 | running_mean,
110 | running_var,
111 | self.weight.fast,
112 | self.bias.fast,
113 | training=True,
114 | momentum=1,
115 | )
116 | # batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py
117 | else:
118 | out = F.batch_norm(
119 | x,
120 | running_mean,
121 | running_var,
122 | self.weight,
123 | self.bias,
124 | training=True,
125 | momentum=1,
126 | )
127 | return out
128 |
129 |
130 | # Simple Conv Block
131 | class ConvBlock(nn.Module):
132 | maml = False # Default
133 |
134 | def __init__(self, indim, outdim, pool=True, padding=1):
135 | super(ConvBlock, self).__init__()
136 | self.indim = indim
137 | self.outdim = outdim
138 | if self.maml:
139 | self.C = Conv2d_fw(indim, outdim, 3, padding=padding)
140 | self.BN = BatchNorm2d_fw(outdim)
141 | else:
142 | self.C = nn.Conv2d(indim, outdim, 3, padding=padding)
143 | self.BN = nn.BatchNorm2d(outdim)
144 | self.relu = nn.ReLU(inplace=True)
145 |
146 | self.parametrized_layers = [self.C, self.BN, self.relu]
147 | if pool:
148 | self.pool = nn.MaxPool2d(2)
149 | self.parametrized_layers.append(self.pool)
150 |
151 | for layer in self.parametrized_layers:
152 | init_layer(layer)
153 |
154 | self.trunk = nn.Sequential(*self.parametrized_layers)
155 |
156 | def forward(self, x):
157 | out = self.trunk(x)
158 | return out
159 |
160 |
161 | # Simple ResNet Block
162 | class SimpleBlock(nn.Module):
163 | maml = False # Default
164 |
165 | def __init__(self, indim, outdim, half_res):
166 | super(SimpleBlock, self).__init__()
167 | self.indim = indim
168 | self.outdim = outdim
169 | if self.maml:
170 | self.C1 = Conv2d_fw(
171 | indim,
172 | outdim,
173 | kernel_size=3,
174 | stride=2 if half_res else 1,
175 | padding=1,
176 | bias=False,
177 | )
178 | self.BN1 = BatchNorm2d_fw(outdim)
179 | self.C2 = Conv2d_fw(outdim, outdim, kernel_size=3, padding=1, bias=False)
180 | self.BN2 = BatchNorm2d_fw(outdim)
181 | else:
182 | self.C1 = nn.Conv2d(
183 | indim,
184 | outdim,
185 | kernel_size=3,
186 | stride=2 if half_res else 1,
187 | padding=1,
188 | bias=False,
189 | )
190 | self.BN1 = nn.BatchNorm2d(outdim)
191 | self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False)
192 | self.BN2 = nn.BatchNorm2d(outdim)
193 | self.relu1 = nn.ReLU(inplace=True)
194 | self.relu2 = nn.ReLU(inplace=True)
195 |
196 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]
197 |
198 | self.half_res = half_res
199 |
200 | # if the input number of channels is not equal to the output, then need a 1x1 convolution
201 | if indim != outdim:
202 | if self.maml:
203 | self.shortcut = Conv2d_fw(
204 | indim, outdim, 1, 2 if half_res else 1, bias=False
205 | )
206 | self.BNshortcut = BatchNorm2d_fw(outdim)
207 | else:
208 | self.shortcut = nn.Conv2d(
209 | indim, outdim, 1, 2 if half_res else 1, bias=False
210 | )
211 | self.BNshortcut = nn.BatchNorm2d(outdim)
212 |
213 | self.parametrized_layers.append(self.shortcut)
214 | self.parametrized_layers.append(self.BNshortcut)
215 | self.shortcut_type = "1x1"
216 | else:
217 | self.shortcut_type = "identity"
218 |
219 | for layer in self.parametrized_layers:
220 | init_layer(layer)
221 |
222 | def forward(self, x):
223 | out = self.C1(x)
224 | out = self.BN1(out)
225 | out = self.relu1(out)
226 | out = self.C2(out)
227 | out = self.BN2(out)
228 | short_out = (
229 | x if self.shortcut_type == "identity" else self.BNshortcut(self.shortcut(x))
230 | )
231 | out = out + short_out
232 | out = self.relu2(out)
233 | return out
234 |
235 |
236 | # Bottleneck block
237 | class BottleneckBlock(nn.Module):
238 | maml = False # Default
239 |
240 | def __init__(self, indim, outdim, half_res):
241 | super(BottleneckBlock, self).__init__()
242 | bottleneckdim = int(outdim / 4)
243 | self.indim = indim
244 | self.outdim = outdim
245 | if self.maml:
246 | self.C1 = Conv2d_fw(indim, bottleneckdim, kernel_size=1, bias=False)
247 | self.BN1 = BatchNorm2d_fw(bottleneckdim)
248 | self.C2 = Conv2d_fw(
249 | bottleneckdim,
250 | bottleneckdim,
251 | kernel_size=3,
252 | stride=2 if half_res else 1,
253 | padding=1,
254 | )
255 | self.BN2 = BatchNorm2d_fw(bottleneckdim)
256 | self.C3 = Conv2d_fw(bottleneckdim, outdim, kernel_size=1, bias=False)
257 | self.BN3 = BatchNorm2d_fw(outdim)
258 | else:
259 | self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False)
260 | self.BN1 = nn.BatchNorm2d(bottleneckdim)
261 | self.C2 = nn.Conv2d(
262 | bottleneckdim,
263 | bottleneckdim,
264 | kernel_size=3,
265 | stride=2 if half_res else 1,
266 | padding=1,
267 | )
268 | self.BN2 = nn.BatchNorm2d(bottleneckdim)
269 | self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False)
270 | self.BN3 = nn.BatchNorm2d(outdim)
271 |
272 | self.relu = nn.ReLU()
273 | self.parametrized_layers = [
274 | self.C1,
275 | self.BN1,
276 | self.C2,
277 | self.BN2,
278 | self.C3,
279 | self.BN3,
280 | ]
281 | self.half_res = half_res
282 |
283 | # if the input number of channels is not equal to the output, then need a 1x1 convolution
284 | if indim != outdim:
285 | if self.maml:
286 | self.shortcut = Conv2d_fw(
287 | indim, outdim, 1, stride=2 if half_res else 1, bias=False
288 | )
289 | else:
290 | self.shortcut = nn.Conv2d(
291 | indim, outdim, 1, stride=2 if half_res else 1, bias=False
292 | )
293 |
294 | self.parametrized_layers.append(self.shortcut)
295 | self.shortcut_type = "1x1"
296 | else:
297 | self.shortcut_type = "identity"
298 |
299 | for layer in self.parametrized_layers:
300 | init_layer(layer)
301 |
302 | def forward(self, x):
303 |
304 | short_out = x if self.shortcut_type == "identity" else self.shortcut(x)
305 | out = self.C1(x)
306 | out = self.BN1(out)
307 | out = self.relu(out)
308 | out = self.C2(out)
309 | out = self.BN2(out)
310 | out = self.relu(out)
311 | out = self.C3(out)
312 | out = self.BN3(out)
313 | out = out + short_out
314 |
315 | out = self.relu(out)
316 | return out
317 |
318 |
319 | class ConvNet(nn.Module):
320 | def __init__(self, depth, flatten=True):
321 | super(ConvNet, self).__init__()
322 | trunk = []
323 | for i in range(depth):
324 | indim = 3 if i == 0 else 64
325 | outdim = 64
326 |             B = ConvBlock(indim, outdim, pool=(i < 4)) # only pooling for first 4 layers
327 | trunk.append(B)
328 |
329 | self.flatten = flatten
330 | if self.flatten:
331 | trunk.append(Flatten())
332 |
333 | self.trunk = nn.Sequential(*trunk)
334 | self.final_feat_dim = 1600
335 |
336 | def forward(self, x):
337 | out = self.trunk(x)
338 | return out
339 |
340 | def forward_seq(self, x):
341 | hiddens = []
342 | if self.flatten:
343 |             seq = self.trunk[:-1]  # skip the final Flatten layer
344 |         else:
345 |             seq = self.trunk
346 | for layer in seq:
347 | x = layer(x)
348 | hiddens.append(x)
349 | return hiddens
350 |
351 |
352 | class ConvNetNopool(
353 | nn.Module
354 | ):  # Relation net uses a 4-layer conv with pooling in only the first two layers, else no pooling
355 | def __init__(self, depth):
356 | super(ConvNetNopool, self).__init__()
357 | trunk = []
358 | for i in range(depth):
359 | indim = 3 if i == 0 else 64
360 | outdim = 64
361 | B = ConvBlock(
362 | indim, outdim, pool=(i in [0, 1]), padding=0 if i in [0, 1] else 1
363 |             )  # only the first two layers have pooling and no padding
364 | trunk.append(B)
365 |
366 | self.trunk = nn.Sequential(*trunk)
367 | self.final_feat_dim = [64, 19, 19]
368 |
369 | def forward(self, x):
370 | out = self.trunk(x)
371 | return out
372 |
373 |
374 | class ConvNetS(nn.Module): # For omniglot, only 1 input channel, output dim is 64
375 | def __init__(self, depth, flatten=True):
376 | super(ConvNetS, self).__init__()
377 | trunk = []
378 | for i in range(depth):
379 | indim = 1 if i == 0 else 64
380 | outdim = 64
381 |             B = ConvBlock(indim, outdim, pool=(i < 4)) # only pooling for first 4 layers
382 | trunk.append(B)
383 |
384 | if flatten:
385 | trunk.append(Flatten())
386 |
387 | self.trunk = nn.Sequential(*trunk)
388 | self.final_feat_dim = 64
389 |
390 | def forward(self, x):
391 |         out = x[:, 0:1, :, :]  # only use the first channel
392 | out = self.trunk(out)
393 | return out
394 |
395 |
396 | class ConvNetSNopool(nn.Module):
397 | def __init__(self, depth):
398 | super(ConvNetSNopool, self).__init__()
399 | trunk = []
400 | for i in range(depth):
401 | indim = 1 if i == 0 else 64
402 | outdim = 64
403 | B = ConvBlock(
404 | indim, outdim, pool=(i in [0, 1]), padding=0 if i in [0, 1] else 1
405 |             )  # only the first two layers have pooling and no padding
406 | trunk.append(B)
407 |
408 | self.trunk = nn.Sequential(*trunk)
409 | self.final_feat_dim = [64, 5, 5]
410 |
411 | def forward(self, x):
412 |         out = x[:, 0:1, :, :]  # only use the first channel
413 | out = self.trunk(out)
414 | return out
415 |
416 |
417 | class ResNet(nn.Module):
418 | maml = False # Default
419 |
420 | def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True):
421 | # list_of_num_layers specifies number of layers in each stage
422 | # list_of_out_dims specifies number of output channel for each stage
423 | super(ResNet, self).__init__()
424 | assert len(list_of_num_layers) == 4, "Can have only four stages"
425 | if self.maml:
426 | conv1 = Conv2d_fw(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
427 | bn1 = BatchNorm2d_fw(64)
428 | else:
429 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
430 | bn1 = nn.BatchNorm2d(64)
431 |
432 | relu = nn.ReLU()
433 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
434 |
435 | init_layer(conv1)
436 | init_layer(bn1)
437 |
438 | trunk = [conv1, bn1, relu, pool1]
439 |
440 | indim = 64
441 | for i in range(4):
442 |
443 | for j in range(list_of_num_layers[i]):
444 | half_res = (i >= 1) and (j == 0)
445 | B = block(indim, list_of_out_dims[i], half_res)
446 | trunk.append(B)
447 | indim = list_of_out_dims[i]
448 |
449 | if flatten:
450 | avgpool = nn.AvgPool2d(7)
451 | trunk.append(avgpool)
452 | trunk.append(Flatten())
453 | self.final_feat_dim = indim
454 | else:
455 | self.final_feat_dim = [indim, 7, 7]
456 |
457 | self.trunk = nn.Sequential(*trunk)
458 |
459 | def forward(self, x):
460 | out = self.trunk(x)
461 | return out
462 |
463 |
464 | def Conv4():
465 | return ConvNet(4)
466 |
467 |
468 | def Conv6():
469 | return ConvNet(6)
470 |
471 |
472 | def Conv4NP():
473 | return ConvNetNopool(4)
474 |
475 |
476 | def Conv6NP():
477 | return ConvNetNopool(6)
478 |
479 |
480 | def Conv4S():
481 | return ConvNetS(4)
482 |
483 |
484 | def Conv4SNP():
485 | return ConvNetSNopool(4)
486 |
487 |
488 | def ResNet10(flatten=True):
489 | return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten)
490 |
491 |
492 | def ResNet18(flatten=True):
493 | return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten)
494 |
495 |
496 | def PretrainedResNet18():
497 | rn18 = models.resnet18(pretrained=True)
498 | rn18.final_feat_dim = 512
499 | rn18.fc = Identity() # We don't use final fc
500 | return rn18
501 |
502 |
503 | def ResNet34(flatten=True):
504 | return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], flatten)
505 |
506 |
507 | def ResNet50(flatten=True):
508 | return ResNet(BottleneckBlock, [3, 4, 6, 3], [256, 512, 1024, 2048], flatten)
509 |
510 |
511 | def ResNet101(flatten=True):
512 | return ResNet(BottleneckBlock, [3, 4, 23, 3], [256, 512, 1024, 2048], flatten)
513 |
--------------------------------------------------------------------------------
/birds/fewshot/constants.py:
--------------------------------------------------------------------------------
1 | DATA_DIR = "./custom_filelists/CUB/"
2 | LANG_DIR = "./reed-birds/"
3 |
--------------------------------------------------------------------------------
/birds/fewshot/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import additional_transforms, datamgr, dataset
2 |
--------------------------------------------------------------------------------
/birds/fewshot/data/additional_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from PIL import ImageEnhance
9 |
10 | transformtypedict = dict(
11 | Brightness=ImageEnhance.Brightness,
12 | Contrast=ImageEnhance.Contrast,
13 | Sharpness=ImageEnhance.Sharpness,
14 | Color=ImageEnhance.Color,
15 | )
16 |
17 |
18 | class ImageJitter(object):
19 | def __init__(self, transformdict):
20 | self.transforms = [
21 | (transformtypedict[k], transformdict[k]) for k in transformdict
22 | ]
23 |
24 | def __call__(self, img):
25 | out = img
26 | randtensor = torch.rand(len(self.transforms))
27 |
28 | for i, (transformer, alpha) in enumerate(self.transforms):
29 | r = alpha * (randtensor[i] * 2.0 - 1.0) + 1
30 | out = transformer(out).enhance(r).convert("RGB")
31 |
32 | return out
33 |
--------------------------------------------------------------------------------
/birds/fewshot/data/datamgr.py:
--------------------------------------------------------------------------------
1 | # This code is modified from
2 | # https://github.com/facebookresearch/low-shot-shrink-hallucinate
3 |
4 | from abc import abstractmethod
5 |
6 | import torch
7 | import torchvision.transforms as transforms
8 |
9 | import data.additional_transforms as add_transforms
10 | from data.dataset import EpisodicBatchSampler, SetDataset, SimpleDataset
11 |
12 |
13 | class TransformLoader:
14 | def __init__(
15 | self,
16 | image_size,
17 | normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18 | jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4),
19 | ):
20 | self.image_size = image_size
21 | self.normalize_param = normalize_param
22 | self.jitter_param = jitter_param
23 |
24 | def parse_transform(self, transform_type):
25 | if transform_type == "ImageJitter":
26 | method = add_transforms.ImageJitter(self.jitter_param)
27 | return method
28 | method = getattr(transforms, transform_type)
29 | if transform_type == "RandomResizedCrop":
30 | return method(self.image_size)
31 | elif transform_type == "CenterCrop":
32 | return method(self.image_size)
33 | elif transform_type == "Resize":
34 | return method([int(self.image_size * 1.15), int(self.image_size * 1.15)])
35 | elif transform_type == "Normalize":
36 | return method(**self.normalize_param)
37 | else:
38 | return method()
39 |
40 | def get_composed_transform(
41 | self,
42 | aug=False,
43 | normalize=True,
44 | to_pil=True,
45 | confound_noise=0.0,
46 | confound_noise_class_weight=0.0,
47 | ):
48 | if aug:
49 | transform_list = [
50 | "RandomResizedCrop",
51 | "ImageJitter",
52 | "RandomHorizontalFlip",
53 | "ToTensor",
54 | ]
55 | else:
56 | transform_list = ["Resize", "CenterCrop", "ToTensor"]
57 |
58 | if confound_noise != 0.0:
59 | transform_list.append(
60 | ("Noise", confound_noise, confound_noise_class_weight)
61 | )
62 |
63 | if normalize:
64 | transform_list.append("Normalize")
65 |
66 | if to_pil:
67 | transform_list = ["ToPILImage"] + transform_list
68 |
69 | transform_funcs = [self.parse_transform(x) for x in transform_list]
70 | transform = transforms.Compose(transform_funcs)
71 | return transform
72 |
73 | def get_normalize(self):
74 | return self.parse_transform("Normalize")
75 |
76 |
77 | class DataManager:
78 | @abstractmethod
79 | def get_data_loader(self, data_file, aug):
80 | pass
81 |
82 |
83 | class SimpleDataManager(DataManager):
84 | def __init__(self, image_size, batch_size, num_workers=12):
85 | super(SimpleDataManager, self).__init__()
86 | self.batch_size = batch_size
87 | self.trans_loader = TransformLoader(image_size)
88 | self.num_workers = num_workers
89 |
90 | def get_data_loader(
91 | self, data_file, aug, lang_dir=None, normalize=True, to_pil=False
92 | ): # parameters that would change on train/val set
93 | if lang_dir is not None:
94 | raise NotImplementedError
95 | transform = self.trans_loader.get_composed_transform(
96 | aug, normalize=normalize, to_pil=to_pil
97 | )
98 | dataset = SimpleDataset(data_file, transform)
99 | data_loader_params = dict(
100 | batch_size=self.batch_size,
101 | shuffle=True,
102 | num_workers=self.num_workers,
103 | pin_memory=True,
104 | )
105 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
106 |
107 | return data_loader
108 |
109 |
110 | class SetDataManager(DataManager):
111 | def __init__(
112 | self, name, image_size, n_way, n_support, n_query, n_episode=100, args=None
113 | ):
114 | super(SetDataManager, self).__init__()
115 | self.name = name
116 | self.image_size = image_size
117 | self.n_way = n_way
118 | self.batch_size = n_support + n_query
119 | self.n_episode = n_episode
120 | self.args = args
121 |
122 | self.trans_loader = TransformLoader(image_size)
123 |
124 | def get_data_loader(
125 | self,
126 | data_file,
127 | aug,
128 | lang_dir=None,
129 | normalize=True,
130 | vocab=None,
131 | max_class=None,
132 | max_img_per_class=None,
133 | max_lang_per_class=None,
134 | ):
135 | transform = self.trans_loader.get_composed_transform(aug, normalize=normalize)
136 |
137 | dataset = SetDataset(
138 | self.name,
139 | data_file,
140 | self.batch_size,
141 | transform,
142 | args=self.args,
143 | lang_dir=lang_dir,
144 | vocab=vocab,
145 | max_class=max_class,
146 | max_img_per_class=max_img_per_class,
147 | max_lang_per_class=max_lang_per_class,
148 | )
149 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_episode)
150 | data_loader_params = dict(
151 | batch_sampler=sampler, num_workers=self.args.n_workers, pin_memory=True,
152 | )
153 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
154 | return data_loader
155 |
--------------------------------------------------------------------------------
/birds/fewshot/data/dataset.py:
--------------------------------------------------------------------------------
1 | # This code is modified from
2 | # https://github.com/facebookresearch/low-shot-shrink-hallucinate
3 |
4 | import glob
5 | import json
6 | import os
7 |
8 | import numpy as np
9 | import torch
10 | import torchvision.transforms as transforms
11 | from numpy import random
12 | from PIL import Image
13 | import torchfile
14 |
15 | from . import lang_utils
16 |
17 |
18 | CUB_IMAGES_PATH = "CUB_200_2011/images"
19 |
20 |
21 | def identity(x):
22 | return x
23 |
24 |
25 | def load_image(image_path):
26 | img = Image.open(image_path).convert("RGB")
27 | return img
28 |
29 |
30 | class SimpleDataset:
31 | def __init__(self, data_file, transform, target_transform=identity):
32 | with open(data_file, "r") as f:
33 | self.meta = json.load(f)
34 | self.transform = transform
35 | self.target_transform = target_transform
36 |
37 | def __getitem__(self, i):
38 |         image_path = self.meta["image_names"][i]
39 | img = load_image(image_path)
40 | img = self.transform(img)
41 | target = self.target_transform(self.meta["image_labels"][i])
42 | return img, target
43 |
44 | def __len__(self):
45 | return len(self.meta["image_names"])
46 |
47 |
48 | class SetDataset:
49 | def __init__(
50 | self,
51 | name,
52 | data_file,
53 | batch_size,
54 | transform,
55 | args=None,
56 | lang_dir=None,
57 | vocab=None,
58 | max_class=None,
59 | max_img_per_class=None,
60 | max_lang_per_class=None,
61 | ):
62 | self.name = name
63 | with open(data_file, "r") as f:
64 | self.meta = json.load(f)
65 |
66 | self.args = args
67 | self.max_class = max_class
68 | self.max_img_per_class = max_img_per_class
69 | self.max_lang_per_class = max_lang_per_class
70 |
71 | if not (1 <= args.n_caption <= 10):
72 | raise ValueError("Invalid # captions {}".format(args.n_caption))
73 |
74 | self.cl_list = np.unique(self.meta["image_labels"]).tolist()
75 |
76 | if self.max_class is not None:
77 | if self.max_class > len(self.cl_list):
78 | raise ValueError(
79 | "max_class set to {} but only {} classes in {}".format(
80 | self.max_class, len(self.cl_list), data_file
81 | )
82 | )
83 | self.cl_list = self.cl_list[: self.max_class]
84 |
85 | if args.language_filter not in ["all", "color", "nocolor"]:
86 | raise NotImplementedError(
87 | "language_filter = {}".format(args.language_filter)
88 | )
89 |
90 | self.sub_meta_lang = {}
91 | self.sub_meta_lang_length = {}
92 | self.sub_meta_lang_mask = {}
93 | self.sub_meta = {}
94 |
95 | for cl in self.cl_list:
96 | self.sub_meta[cl] = []
97 | self.sub_meta_lang[cl] = []
98 | self.sub_meta_lang_length[cl] = []
99 | self.sub_meta_lang_mask[cl] = []
100 |
101 | # Load language and mapping from image names -> lang idx
102 | self.lang = {}
103 | self.lang_lengths = {}
104 | self.lang_masks = {}
105 | self.image_name_idx = {}
106 | for cln, label_name in enumerate(self.meta["label_names"]):
107 | # Use the numeric class id instead of label name due to
108 | # inconsistencies
109 | digits = label_name.split(".")[0]
110 | matching_names = [
111 | x
112 | for x in os.listdir(os.path.join(lang_dir, "word_c10"))
113 | if x.startswith(digits)
114 | ]
115 | assert len(matching_names) == 1, matching_names
116 | label_file = os.path.join(lang_dir, "word_c10", matching_names[0])
117 | lang_tensor = torch.from_numpy(torchfile.load(label_file)).long()
118 | # Make words last dim
119 | lang_tensor = lang_tensor.transpose(2, 1)
120 | lang_tensor = lang_tensor - 1 # XXX: Decrement language by 1 upon load
121 |
122 | if (
123 | self.args.language_filter == "color"
124 | or self.args.language_filter == "nocolor"
125 | ):
126 | lang_tensor = lang_utils.filter_language(
127 | lang_tensor, self.args.language_filter, vocab
128 | )
129 |
130 | if self.args.shuffle_lang:
131 | lang_tensor = lang_utils.shuffle_language(lang_tensor)
132 |
133 | lang_lengths = lang_utils.get_lang_lengths(lang_tensor)
134 |
135 | # Add start and end of sentence tokens to language
136 | lang_tensor, lang_lengths = lang_utils.add_sos_eos(
137 | lang_tensor, lang_lengths, vocab
138 | )
139 | lang_masks = lang_utils.get_lang_masks(
140 | lang_lengths, max_len=lang_tensor.shape[2]
141 | )
142 |
143 | self.lang[label_name] = lang_tensor
144 | self.lang_lengths[label_name] = lang_lengths
145 | self.lang_masks[label_name] = lang_masks
146 |
147 | # Give images their numeric ids according to alphabetical order
148 | if self.name == "CUB":
149 | img_dir = os.path.join(lang_dir, "text_c10", label_name, "*.txt")
150 | sorted_imgs = sorted(
151 | [
152 | os.path.splitext(os.path.basename(i))[0]
153 | for i in glob.glob(img_dir)
154 | ]
155 | )
156 | for i, img_fname in enumerate(sorted_imgs):
157 | self.image_name_idx[img_fname] = i
158 |
159 | for x, y in zip(self.meta["image_names"], self.meta["image_labels"]):
160 | if y in self.sub_meta:
161 | self.sub_meta[y].append(x)
162 | label_name = self.meta["label_names"][y]
163 |
164 | image_basename = os.path.splitext(os.path.basename(x))[0]
165 | if self.name == "CUB":
166 | image_lang_idx = self.image_name_idx[image_basename]
167 | else:
168 | image_lang_idx = int(image_basename[-1])
169 |
170 | captions = self.lang[label_name][image_lang_idx]
171 | lengths = self.lang_lengths[label_name][image_lang_idx]
172 | masks = self.lang_masks[label_name][image_lang_idx]
173 |
174 | self.sub_meta_lang[y].append(captions)
175 | self.sub_meta_lang_length[y].append(lengths)
176 |                 self.sub_meta_lang_mask[y].append(masks)
177 |             else:
178 |                 assert self.max_class is not None
179 |
180 |         if self.args.scramble_lang:
181 |             # For each class, shuffle captions for each image
182 |             (
183 |                 self.sub_meta_lang,
184 |                 self.sub_meta_lang_length,
185 |                 self.sub_meta_lang_mask,
186 |             ) = lang_utils.shuffle_lang_class(
187 |                 self.sub_meta_lang, self.sub_meta_lang_length, self.sub_meta_lang_mask
188 |             )
189 |
190 |         if self.args.scramble_lang_class:
191 |             raise NotImplementedError
192 |
193 |         if self.args.scramble_all:
194 |             # Shuffle captions completely randomly
195 |             (
196 |                 self.sub_meta_lang,
197 |                 self.sub_meta_lang_length,
198 |                 self.sub_meta_lang_mask,
199 |             ) = lang_utils.shuffle_all_class(
200 |                 self.sub_meta_lang, self.sub_meta_lang_length, self.sub_meta_lang_mask
201 |             )
202 |
203 |         if self.max_img_per_class is not None:
204 |             # Trim number of images available per class
205 |             for cl in self.sub_meta.keys():
206 |                 self.sub_meta[cl] = self.sub_meta[cl][: self.max_img_per_class]
207 |                 self.sub_meta_lang[cl] = self.sub_meta_lang[cl][
208 |                     : self.max_img_per_class
209 |                 ]
210 |                 self.sub_meta_lang_length[cl] = self.sub_meta_lang_length[cl][
211 |                     : self.max_img_per_class
212 |                 ]
213 |                 self.sub_meta_lang_mask[cl] = self.sub_meta_lang_mask[cl][
214 |                     : self.max_img_per_class
215 |                 ]
216 |
217 |         if self.max_lang_per_class is not None:
218 |             # Trim language available for each class; recycle language if not enough
219 |             for cl in self.sub_meta.keys():
220 |                 self.sub_meta_lang[cl] = lang_utils.recycle_lang(
221 |                     self.sub_meta_lang[cl], self.max_lang_per_class
222 |                 )
223 |                 self.sub_meta_lang_length[cl] = lang_utils.recycle_lang(
224 |                     self.sub_meta_lang_length[cl], self.max_lang_per_class
225 |                 )
226 |                 self.sub_meta_lang_mask[cl] = lang_utils.recycle_lang(
227 |                     self.sub_meta_lang_mask[cl], self.max_lang_per_class
228 |                 )
229 |
230 |         self.sub_dataloader = []
231 |         sub_data_loader_params = dict(
232 |             batch_size=batch_size,
233 |             shuffle=True,
234 |             num_workers=0, # use main thread only or may receive multiple batches
235 |             pin_memory=False,
236 |         )
237 |         for i, cl in enumerate(self.cl_list):
238 |             sub_dataset = SubDataset(
239 |                 self.name,
240 |                 self.sub_meta[cl],
241 |                 cl,
242 |                 sub_meta_lang=self.sub_meta_lang[cl],
243 |                 sub_meta_lang_length=self.sub_meta_lang_length[cl],
244 |                 sub_meta_lang_mask=self.sub_meta_lang_mask[cl],
245 |                 transform=transform,
246 |                 n_caption=self.args.n_caption,
247 |                 args=self.args,
248 |                 max_lang_per_class=self.max_lang_per_class,
249 |             )
250 |             self.sub_dataloader.append(
251 |                 torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params)
252 |             )
253 |
254 |     def __getitem__(self, i):
255 |         return next(iter(self.sub_dataloader[i]))
256 |
257 |     def __len__(self):
258 |         return len(self.sub_dataloader)
259 |
260 |
261 | class SubDataset:
262 |     def __init__(
263 |         self,
264 |         name,
265 |         sub_meta,
266 |         cl,
267 |         sub_meta_lang=None,
268 |         sub_meta_lang_length=None,
269 |         sub_meta_lang_mask=None,
270 |         transform=transforms.ToTensor(),
271 |         target_transform=identity,
272 |         n_caption=10,
273 |         args=None,
274 |         max_lang_per_class=None,
275 |     ):
276 |         self.name = name
277 |         self.sub_meta = sub_meta
278 |         self.sub_meta_lang = sub_meta_lang
279 |         self.sub_meta_lang_length = sub_meta_lang_length
280 |         self.sub_meta_lang_mask = sub_meta_lang_mask
281 |         self.cl = cl
282 |         self.transform = transform
283 |         self.target_transform = target_transform
284 |         if not (1 <= n_caption <= 10):
285 |             raise ValueError("Invalid # captions {}".format(n_caption))
286 |         self.n_caption = n_caption
287 |         cl_path = os.path.split(self.sub_meta[0])[0]
288 |         self.img = dict(np.load(os.path.join(cl_path, "img.npz")))
289 |
290 |         # Used if sampling from class
291 |         self.args = args
292 |         self.max_lang_per_class = max_lang_per_class
293 |
294 |     def __getitem__(self, i):
295 |         image_path = self.sub_meta[i]
296 |         img = self.img[image_path]
297 |         img = self.transform(img)
298 |         target = self.target_transform(self.cl)
299 |
300 |         if self.n_caption == 1:
301 |             lang_idx = 0
302 |         else:
303 |             lang_idx = random.randint(min(self.n_caption, len(self.sub_meta_lang[i])))
304 |
305 |         if self.args.sample_class_lang:
306 |             # Sample from all language, rather than the ith image
307 |             if self.max_lang_per_class is None:
308 |                 max_i = len(self.sub_meta_lang)
309 |             else:
310 |                 max_i = min(self.max_lang_per_class, len(self.sub_meta_lang))
311 |             which_img_lang_i = random.randint(0, max_i)
312 |         else:
313 |             which_img_lang_i = i
314 |
315 |         lang = self.sub_meta_lang[which_img_lang_i][lang_idx]
316 |         lang_length = self.sub_meta_lang_length[which_img_lang_i][lang_idx]
317 |         lang_mask = self.sub_meta_lang_mask[which_img_lang_i][lang_idx]
318 |
319 |         return img, target, (lang, lang_length, lang_mask)
320 |
321 |     def __len__(self):
322 |         return len(self.sub_meta)
323 |
324 |
325 | class EpisodicBatchSampler(object):
326 |     def __init__(self, n_classes, n_way, n_episodes):
327 |         self.n_classes = n_classes
328 |         self.n_way = n_way
329 |         self.n_episodes = n_episodes
330 |
331 |     def __len__(self):
332 |         return self.n_episodes
333 |
334 |     def __iter__(self):
335 |         for i in range(self.n_episodes):
336 |             yield torch.randperm(self.n_classes)[: self.n_way]
337 |
--------------------------------------------------------------------------------
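Note: a minimal, hedged sketch of how the episodic pieces in `dataset.py` above fit together. `EpisodicBatchSampler` yields `n_way` random class indices per episode, and the class-indexed dataset's `__getitem__` returns one full support+query batch for that class. `ToyClassDataset` below is invented purely for illustration (it stands in for `SetDataset`, which also returns captions), and the import assumes `fewshot/` is on the path as in `train.py`.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from data.dataset import EpisodicBatchSampler


class ToyClassDataset(Dataset):
    """Toy stand-in for SetDataset: item cl is a (n_support + n_query) batch for class cl."""

    def __init__(self, n_classes=20, n_support=1, n_query=15):
        self.n_classes = n_classes
        self.per_class = n_support + n_query

    def __len__(self):
        return self.n_classes

    def __getitem__(self, cl):
        imgs = torch.randn(self.per_class, 3, 84, 84)      # fake images
        targets = torch.full((self.per_class,), int(cl))   # class label, repeated
        return imgs, targets


sampler = EpisodicBatchSampler(n_classes=20, n_way=5, n_episodes=3)
loader = DataLoader(ToyClassDataset(), batch_sampler=sampler, num_workers=0)
for imgs, targets in loader:
    # imgs: [n_way, n_support + n_query, 3, 84, 84], the episode shape noted in train.py
    print(imgs.shape, targets.shape)
```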
/birds/fewshot/data/lang_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for processing language datasets
3 | """
4 |
5 | import os
6 | from collections import defaultdict
7 |
8 | import numpy as np
9 | import torch
10 | from numpy import random
11 | import torchfile
12 |
13 | SOS_TOKEN = "<sos>"
14 | EOS_TOKEN = "<eos>"
15 | PAD_TOKEN = "<pad>"
16 |
17 | COLOR_WORDS = set(
18 | [
19 | "amaranth",
20 | "charcoal",
21 | "amber",
22 | "amethyst",
23 | "apricot",
24 | "aquamarine",
25 | "azure",
26 | "baby blue",
27 | "beige",
28 | "black",
29 | "blue",
30 | "blush",
31 | "bronze",
32 | "brown",
33 | "burgundy",
34 | "byzantium",
35 | "carmine",
36 | "cerise",
37 | "cerulean",
38 | "champagne",
39 | "chartreuse",
40 | "chocolate",
41 | "cobalt",
42 | "coffee",
43 | "copper",
44 | "coral",
45 | "crimson",
46 | "cyan",
47 | "desert",
48 | "electric",
49 | "emerald",
50 | "erin",
51 | "gold",
52 | "gray",
53 | "grey",
54 | "green",
55 | "harlequin",
56 | "indigo",
57 | "ivory",
58 | "jade",
59 | "jungle",
60 | "lavender",
61 | "lemon",
62 | "lilac",
63 | "lime",
64 | "magenta",
65 | "magenta",
66 | "maroon",
67 | "mauve",
68 | "navy",
69 | "ochre",
70 | "olive",
71 | "orange",
72 | "orange",
73 | "orchid",
74 | "peach",
75 | "pear",
76 | "periwinkle",
77 | "persian",
78 | "pink",
79 | "plum",
80 | "prussian",
81 | "puce",
82 | "purple",
83 | "raspberry",
84 | "red",
85 | "red",
86 | "rose",
87 | "ruby",
88 | "salmon",
89 | "sangria",
90 | "sapphire",
91 | "scarlet",
92 | "silver",
93 | "slate",
94 | "spring",
95 | "spring",
96 | "tan",
97 | "taupe",
98 | "teal",
99 | "turquoise",
100 | "ultramarine",
101 | "violet",
102 | "viridian",
103 | "white",
104 | "yellow",
105 | "reddish",
106 | "yellowish",
107 | "greenish",
108 | "orangeish",
109 | "orangish",
110 | "blackish",
111 | "pinkish",
112 | "dark",
113 | "light",
114 | "bright",
115 | "greyish",
116 | "grayish",
117 | "brownish",
118 | "beigish",
119 | "aqua",
120 | ]
121 | )
122 |
123 |
124 | def filter_language(lang_tensor, language_filter, vocab):
125 | """
126 | Filter language, keeping or discarding color words
127 |
128 | :param lang_tensor: torch.Tensor of shape (n_imgs, lang_per_img,
129 | max_lang_len); language to be filtered
130 | :param language_filter: either 'color' or 'nocolor'; what language to
131 | filter out
132 | :param vocab: the vocabulary (so we know what indexes to remove)
133 |
134 | :returns: torch.Tensor of same shape as `lang_tensor` with either color or
135 | non-color words removed
136 | """
137 | assert language_filter in ["color", "nocolor"]
138 |
139 | cw = set(vocab[cw] for cw in COLOR_WORDS if cw in vocab)
140 |
141 | new_lang_tensor = torch.ones_like(lang_tensor)
142 | for bird_caps_i in range(lang_tensor.shape[0]):
143 | bird_caps = lang_tensor[bird_caps_i]
144 | new_bird_caps = torch.ones_like(bird_caps)
145 | for bird_cap_i in range(bird_caps.shape[0]):
146 | bird_cap = bird_caps[bird_cap_i]
147 | new_bird_cap = torch.ones_like(bird_cap)
148 | new_w_i = 0
149 | for w in bird_cap:
150 | is_cw = w.item() in cw
151 | if (language_filter == "color" and is_cw) or (
152 | language_filter == "nocolor" and not is_cw
153 | ):
154 | new_bird_cap[new_w_i] = w
155 | new_w_i += 1
156 | if new_bird_cap[0].item() == 1:
157 |                 # FIXME: Here we're just choosing an arbitrary, randomly
158 |                 # misspelled token; make a proper UNK token.
159 | new_bird_cap[0] = 5724
160 | new_bird_caps[bird_cap_i] = new_bird_cap
161 | new_lang_tensor[bird_caps_i] = new_bird_caps
162 | return new_lang_tensor
163 |
164 |
165 | def shuffle_language(lang_tensor):
166 | """
167 | Scramble words in language
168 |
169 | :param lang_tensor: torch.Tensor of shape (n_img, lang_per_img, max_lang_len)
170 |
171 | :returns: torch.Tensor of same shape, but with words randomly scrambled
172 | """
173 | new_lang_tensor = torch.ones_like(lang_tensor)
174 | for bird_caps_i in range(lang_tensor.shape[0]):
175 | bird_caps = lang_tensor[bird_caps_i]
176 | new_bird_caps = torch.ones_like(bird_caps)
177 | for bird_cap_i in range(bird_caps.shape[0]):
178 | bird_cap = bird_caps[bird_cap_i]
179 | new_bird_cap = torch.ones_like(bird_cap)
180 | bird_cap_list = []
181 | for w in bird_cap.numpy():
182 | if w != 1:
183 | bird_cap_list.append(w)
184 | else:
185 | break
186 | random.shuffle(bird_cap_list)
187 | bird_cap_shuf = torch.tensor(
188 | bird_cap_list, dtype=new_bird_cap.dtype, requires_grad=False
189 | )
190 | new_bird_cap[: len(bird_cap_list)] = bird_cap_shuf
191 | new_bird_caps[bird_cap_i] = new_bird_cap
192 | new_lang_tensor[bird_caps_i] = new_bird_caps
193 | return new_lang_tensor
194 |
195 |
196 | def get_lang_lengths(lang_tensor):
197 | """
198 | Get lengths of each caption
199 |
200 | :param lang_tensor: torch.Tensor of shape (n_img, lang_per_img, max_len)
201 | :returns: torch.Tensor of shape (n_img, lang_per_img)
202 | """
203 | max_lang_len = lang_tensor.shape[2]
204 | n_pad = torch.sum(lang_tensor == 0, dim=2)
205 | lang_lengths = max_lang_len - n_pad
206 | return lang_lengths
207 |
208 |
209 | def get_lang_masks(lang_lengths, max_len=32):
210 | """
211 | Given lang lengths, convert to masks
212 |
213 | :param lang_lengths: torch.tensor of shape (n_imgs, lang_per_img)
214 |
215 |     :returns: torch.BoolTensor of shape (n_imgs, lang_per_img, max_len), binary
216 | mask with 0s in token spots and 1s in padding spots
217 | """
218 | mask = torch.ones(lang_lengths.shape + (max_len,), dtype=torch.bool)
219 | for i in range(lang_lengths.shape[0]):
220 | for j in range(lang_lengths.shape[1]):
221 | this_ll = lang_lengths[i, j]
222 | mask[i, j, :this_ll] = 0
223 | return mask
224 |
225 |
226 | def add_sos_eos(lang_tensor, lang_lengths, vocab):
227 | """
228 | Pad language tensors
229 |
230 |     :param lang_tensor: torch.Tensor of shape (n_imgs, n_lang_per_img, max_len)
231 | :param lang_lengths: torch.Tensor of shape (n_imgs, n_lang_per_img)
232 | :param vocab: dictionary from words -> idxs
233 |
234 | :returns: (lang, lang_lengths) where lang has SOS and EOS tokens added, and
235 | lang_lengths have all been increased by 2 (to account for SOS/EOS)
236 | """
237 | sos_idx = vocab[SOS_TOKEN]
238 | eos_idx = vocab[EOS_TOKEN]
239 | lang_tensor_padded = torch.zeros(
240 | lang_tensor.shape[0],
241 | lang_tensor.shape[1],
242 | lang_tensor.shape[2] + 2,
243 | dtype=torch.int64,
244 | )
245 | lang_tensor_padded[:, :, 0] = sos_idx
246 | lang_tensor_padded[:, :, 1:-1] = lang_tensor
247 | for i in range(lang_tensor_padded.shape[0]):
248 | for j in range(lang_tensor_padded.shape[1]):
249 | ll = lang_lengths[i, j]
250 | lang_tensor_padded[
251 | i, j, ll + 1
252 | ] = eos_idx # + 1 accounts for sos token already there
253 | return lang_tensor_padded, lang_lengths + 2
254 |
255 |
256 | def shuffle_lang_class(lang, lang_length, lang_mask):
257 | """
258 | For each class, shuffle captions across images
259 |
260 | :param lang: dict from class -> list of languages for that class
261 | :param lang_length: dict from class -> list of language lengths for that class
262 |     :param lang_mask: dict from class -> list of language masks for that class
263 |
264 | :returns: (new_lang, new_lang_length, new_lang_mask): tuple of new language
265 | dictionaries representing the modified language
266 | """
267 | new_lang = {}
268 | new_lang_length = {}
269 | new_lang_mask = {}
270 | for y in lang:
271 | # FIXME: Make this seedable
272 | img_range = np.arange(len(lang[y]))
273 | random.shuffle(img_range)
274 | nlang = []
275 | nlang_length = []
276 | nlang_mask = []
277 | for lang_i in img_range:
278 | nlang.append(lang[y][lang_i])
279 | nlang_length.append(lang_length[y][lang_i])
280 | nlang_mask.append(lang_mask[y][lang_i])
281 | new_lang[y] = nlang
282 | new_lang_length[y] = nlang_length
283 | new_lang_mask[y] = nlang_mask
284 | return new_lang, new_lang_length, new_lang_mask
285 |
286 |
287 | def shuffle_all_class(lang, lang_length, lang_mask):
288 | """
289 | Shuffle captions completely randomly across all images and classes
290 |
291 | :param lang: dict from class -> list of languages for that class
292 | :param lang_length: dict from class -> list of language lengths for that class
293 |     :param lang_mask: dict from class -> list of language masks for that class
294 |
295 | :returns: (new_lang, new_lang_length, new_lang_mask): tuple of new language
296 | dictionaries representing the modified language
297 | """
298 | lens = [[(m, j) for j in range(len(lang[m]))] for m in lang.keys()]
299 | lens = [item for sublist in lens for item in sublist]
300 | shuffled_lens = lens[:]
301 | random.shuffle(shuffled_lens)
302 | new_lang = defaultdict(list)
303 | new_lang_length = defaultdict(list)
304 | new_lang_mask = defaultdict(list)
305 | for (m, _), (new_m, new_i) in zip(lens, shuffled_lens):
306 | new_lang[m].append(lang[new_m][new_i])
307 | new_lang_length[m].append(lang_length[new_m][new_i])
308 | new_lang_mask[m].append(lang_mask[new_m][new_i])
309 | assert all(len(new_lang[m]) == len(lang[m]) for m in lang.keys())
310 | return dict(new_lang), dict(new_lang_length), dict(new_lang_mask)
311 |
312 |
313 | def load_vocab(lang_dir):
314 | """
315 | Load torch-serialized vocabulary from the lang dir
316 |
317 | :param: lang_dir: str, path to language directory
318 | :returns: dictionary from words -> idxs
319 | """
320 | vocab = torchfile.load(os.path.join(lang_dir, "vocab_c10.t7"))
321 |     vocab = {k: v - 1 for k, v in vocab.items()}  # Decrement vocab (Lua/Torch indices start at 1)
322 | vocab = {k.decode("utf-8"): v for k, v in vocab.items()} # Unicode
323 | # Add SOS/EOS tokens
324 | sos_idx = len(vocab)
325 | vocab[SOS_TOKEN] = sos_idx
326 | eos_idx = len(vocab)
327 | vocab[EOS_TOKEN] = eos_idx
328 | return vocab
329 |
330 |
331 | def glove_init(vocab, emb_size=300):
332 | """
333 | Initialize vocab with glove vectors. Requires spacy and en_vectors_web_lg
334 | spacy model
335 |
336 | :param vocab: dict from words -> idxs
337 | :param emb_size: int, size of embeddings (should be 300 for spacy glove
338 | vectors)
339 |
340 | :returns: torch.FloatTensor of size (len(vocab), emb_size), with glove
341 | embedding if exists, else zeros
342 | """
343 | import spacy
344 |
345 | try:
346 | nlp = spacy.load("en_vectors_web_lg", disable=["tagger", "parser", "ner"])
347 | except OSError:
348 | # Try loading for current directory (codalab)
349 | nlp = spacy.load(
350 | "./en_vectors_web_lg/en_vectors_web_lg-2.1.0/",
351 | disable=["tagger", "parser", "ner"],
352 | )
353 |
354 | vecs = np.zeros((len(vocab), emb_size), dtype=np.float32)
355 | vec_ids_sort = sorted(vocab.items(), key=lambda x: x[1])
356 | sos_idx = vocab[SOS_TOKEN]
357 | eos_idx = vocab[EOS_TOKEN]
358 | pad_idx = vocab[PAD_TOKEN]
359 | for vec, vecid in vec_ids_sort:
360 | if vecid in (pad_idx, sos_idx, eos_idx):
361 | v = np.zeros(emb_size, dtype=np.float32)
362 | else:
363 | v = nlp(vec)[0].vector
364 | vecs[vecid] = v
365 | vecs = torch.as_tensor(vecs)
366 | return vecs
367 |
368 |
369 | def get_special_indices(vocab):
370 | """
371 | Get indices of special items from vocab.
372 | :param vocab: dictionary from words -> idxs
373 |     :returns: dictionary from {sos_index, eos_index, pad_index} -> vocab indices
374 | """
375 | return {
376 | name: vocab[token]
377 | for name, token in [
378 | ("sos_index", SOS_TOKEN),
379 | ("eos_index", EOS_TOKEN),
380 | ("pad_index", PAD_TOKEN),
381 | ]
382 | }
383 |
384 |
385 | def recycle_lang(langs, max_lang):
386 | """
387 |     Given a limited amount of language, keep only the first `max_lang` captions
388 |     :param langs: list of languages
389 |     :param max_lang: number of distinct captions to keep (reused cyclically)
390 |
391 |     :returns: new_langs, a list of the same length as `langs`, built by cycling
392 |     through its first `max_lang` elements
393 | """
394 | new_langs = []
395 | for i in range(len(langs)):
396 | new_langs.append(langs[i % max_lang])
397 | return new_langs
398 |
--------------------------------------------------------------------------------
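Note: a small sketch of how the caption-tensor helpers in `lang_utils.py` compose. The vocabulary and caption indices below are invented for illustration (pad is index 0, which is what `get_lang_lengths` counts as padding), the import assumes `fewshot/` is on the path as in `train.py`, and the values in the comments follow directly from the functions above.

```python
import torch

from data.lang_utils import (EOS_TOKEN, PAD_TOKEN, SOS_TOKEN, add_sos_eos,
                             get_lang_lengths, get_lang_masks)

# Toy vocab: pad = 0, two real words, then SOS/EOS
vocab = {PAD_TOKEN: 0, "red": 1, "bird": 2, SOS_TOKEN: 3, EOS_TOKEN: 4}

# 1 image, 2 captions per image, padded to length 4: "red bird" and "bird"
lang = torch.tensor([[[1, 2, 0, 0],
                      [2, 0, 0, 0]]])

lengths = get_lang_lengths(lang)                   # tensor([[2, 1]])
lang, lengths = add_sos_eos(lang, lengths, vocab)  # lang[0, 0] == [3, 1, 2, 4, 0, 0], lengths == [[4, 3]]
masks = get_lang_masks(lengths, max_len=lang.shape[2])
# masks[0, 0] == [0, 0, 0, 0, 1, 1]: 0 on real tokens, 1 on padding
```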
/birds/fewshot/io_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains argument parsers and utilities for saving and loading metrics and
3 | models.
4 | """
5 |
6 | import argparse
7 | import glob
8 | import os
9 |
10 | import numpy as np
11 |
12 | import backbone
13 |
14 |
15 | model_dict = dict(
16 | Conv4=backbone.Conv4,
17 | Conv4NP=backbone.Conv4NP,
18 | Conv4S=backbone.Conv4S,
19 | Conv6=backbone.Conv6,
20 | ResNet10=backbone.ResNet10,
21 | ResNet18=backbone.ResNet18,
22 | PretrainedResNet18=backbone.PretrainedResNet18,
23 | ResNet34=backbone.ResNet34,
24 | ResNet50=backbone.ResNet50,
25 | ResNet101=backbone.ResNet101,
26 | )
27 |
28 |
29 | def parse_args(script):
30 | parser = argparse.ArgumentParser(
31 | description="few-shot script %s" % (script),
32 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
33 | )
34 | parser.add_argument(
35 | "--checkpoint_dir",
36 | required=True,
37 |         help="Directory to save and load model checkpoints",
38 | )
39 | parser.add_argument("--model", default="Conv4", help="Choice of backbone")
40 | parser.add_argument("--lsl", action="store_true")
41 | parser.add_argument(
42 |         "--l3", action="store_true", help="Use L3 (does not also require --lsl)"
43 | )
44 |     parser.add_argument("--l3_n_infer", type=int, default=10, help="Number of language hypotheses to sample for L3 inference")
45 | parser.add_argument(
46 | "--rnn_type", choices=["gru", "lstm"], default="gru", help="Language RNN type"
47 | )
48 | parser.add_argument(
49 | "--rnn_num_layers", default=1, type=int, help="Language RNN num layers"
50 | )
51 | parser.add_argument(
52 | "--rnn_dropout", default=0.0, type=float, help="Language RNN dropout"
53 | )
54 | parser.add_argument(
55 | "--lang_supervision",
56 | default="class",
57 | choices=["instance", "class"],
58 | help="At what level to supervise with language?",
59 | )
60 | parser.add_argument("--glove_init", action="store_true")
61 | parser.add_argument(
62 | "--freeze_emb", action="store_true", help="Freeze LM word embedding layer"
63 | )
64 |
65 | langparser = parser.add_argument_group("language settings")
66 | langparser.add_argument(
67 | "--shuffle_lang", action="store_true", help="Shuffle words in caption"
68 | )
69 | langparser.add_argument(
70 | "--scramble_lang",
71 | action="store_true",
72 | help="Scramble captions -> images mapping in a class",
73 | )
74 | langparser.add_argument(
75 | "--sample_class_lang",
76 | action="store_true",
77 | help="Sample language randomly from class, rather than getting lang assoc. w/ img",
78 | )
79 | langparser.add_argument(
80 | "--scramble_all",
81 | action="store_true",
82 | help="Scramble captions -> images mapping across all classes",
83 | )
84 | langparser.add_argument(
85 | "--scramble_lang_class",
86 | action="store_true",
87 | help="Scramble captions -> images mapping across all classes, but keep classes consistent",
88 | )
89 | langparser.add_argument(
90 | "--language_filter",
91 | default="all",
92 | choices=["all", "color", "nocolor"],
93 | help="What language to use",
94 | )
95 |
96 | parser.add_argument(
97 | "--lang_hidden_size", type=int, default=200, help="Language decoder hidden size"
98 | )
99 | parser.add_argument(
100 | "--lang_emb_size", type=int, default=300, help="Language embedding hidden size"
101 | )
102 | parser.add_argument(
103 | "--lang_lambda", type=float, default=5, help="Weight on language loss"
104 | )
105 |
106 | parser.add_argument(
107 | "--n_caption",
108 | type=int,
109 | default=1,
110 | choices=list(range(1, 11)),
111 | help="How many captions to use for pretraining",
112 | )
113 | parser.add_argument(
114 | "--max_class", type=int, default=None, help="Max number of training classes"
115 | )
116 | parser.add_argument(
117 | "--max_img_per_class",
118 | type=int,
119 | default=None,
120 | help="Max number of images per training class",
121 | )
122 | parser.add_argument(
123 | "--max_lang_per_class",
124 | type=int,
125 | default=None,
126 |         help="Max number of captions per training class (recycled among images)",
127 | )
128 | parser.add_argument(
129 |         "--train_n_way", default=5, type=int, help="Number of classes per training episode"
130 | )
131 | parser.add_argument(
132 | "--test_n_way",
133 | default=5,
134 | type=int,
135 |         help="Number of classes per test (validation) episode",
136 | )
137 | parser.add_argument(
138 | "--n_shot",
139 | default=1,
140 | type=int,
141 |         help="Number of labeled examples per class (i.e. n_support)",
142 | )
143 | parser.add_argument(
144 | "--n_workers",
145 | default=4,
146 | type=int,
147 | help="Use this many workers for loading data",
148 | )
149 | parser.add_argument(
150 | "--debug", action="store_true", help="Inspect generated language"
151 | )
152 | parser.add_argument(
153 | "--seed", type=int, default=None, help="random seed (torch only; not numpy)"
154 | )
155 |
156 | if script == "train":
157 | parser.add_argument(
158 | "--n", default=1, type=int, help="Train run number (used for metrics)"
159 | )
160 | parser.add_argument(
161 | "--optimizer",
162 | default="adam",
163 | choices=["adam", "amsgrad", "rmsprop"],
164 | help="Optimizer",
165 | )
166 | parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate")
167 | parser.add_argument(
168 | "--rnn_lr_scale",
169 | default=1.0,
170 | type=float,
171 | help="Scale the RNN lr by this amount of the original lr",
172 | )
173 | parser.add_argument("--save_freq", default=50, type=int, help="Save frequency")
174 | parser.add_argument("--start_epoch", default=0, type=int, help="Starting epoch")
175 | parser.add_argument(
176 | "--stop_epoch", default=600, type=int, help="Stopping epoch"
177 | ) # for meta-learning methods, each epoch contains 100 episodes
178 | parser.add_argument(
179 | "--resume",
180 | action="store_true",
181 |             help="Continue from the previously trained checkpoint with the largest epoch",
182 | )
183 | elif script == "test":
184 | parser.add_argument(
185 | "--split",
186 | default="novel",
187 | choices=["base", "val", "novel"],
188 | help="which split to evaluate on",
189 | )
190 | parser.add_argument(
191 | "--save_iter",
192 | default=-1,
193 | type=int,
194 |             help="Load the checkpoint saved at epoch x; use the best model if x is -1",
195 | )
196 | parser.add_argument(
197 | "--save_embeddings",
198 | action="store_true",
199 | help="Save embeddings from language model, then exit (requires --lsl)",
200 | )
201 | parser.add_argument(
202 | "--embeddings_file",
203 | default="./embeddings.txt",
204 | help="File to save embeddings to",
205 | )
206 | parser.add_argument(
207 | "--embeddings_metadata",
208 | default="./embeddings_metadata.txt",
209 | help="File to save embedding metadata to (currently just words)",
210 | )
211 | parser.add_argument(
212 | "--record_file",
213 | default="./record/results.txt",
214 | help="Where to write results to",
215 | )
216 | else:
217 | raise ValueError("Unknown script")
218 |
219 | args = parser.parse_args()
220 |
221 | if "save_embeddings" in args and (args.save_embeddings and not args.lsl):
222 | parser.error("Must set --lsl to save embeddings")
223 |
224 | if args.glove_init and not (args.lsl or args.l3):
225 |         parser.error("Must set --lsl or --l3 to init with glove")
226 |
227 | return args
228 |
229 |
230 | def get_assigned_file(checkpoint_dir, num):
231 | assign_file = os.path.join(checkpoint_dir, "{:d}.tar".format(num))
232 | return assign_file
233 |
234 |
235 | def get_resume_file(checkpoint_dir):
236 | filelist = glob.glob(os.path.join(checkpoint_dir, "*.tar"))
237 | if len(filelist) == 0:
238 | return None
239 |
240 | filelist = [x for x in filelist if os.path.basename(x) != "best_model.tar"]
241 | epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist])
242 | max_epoch = np.max(epochs)
243 | resume_file = os.path.join(checkpoint_dir, "{:d}.tar".format(max_epoch))
244 | return resume_file
245 |
246 |
247 | def get_best_file(checkpoint_dir):
248 | best_file = os.path.join(checkpoint_dir, "best_model.tar")
249 | if os.path.isfile(best_file):
250 | return best_file
251 | else:
252 | return get_resume_file(checkpoint_dir)
253 |
--------------------------------------------------------------------------------
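Note: the checkpoint helpers at the end of `io_utils.py` assume a simple layout inside `checkpoint_dir`: numbered `"{epoch}.tar"` files plus a `"best_model.tar"`, which is what `train.py` writes. A small sketch of that convention using a throwaway temporary directory; the import assumes `fewshot/` is on the path.

```python
import os
import tempfile

import torch

from io_utils import get_assigned_file, get_best_file, get_resume_file

ckpt_dir = tempfile.mkdtemp()
for epoch in (0, 50, 100):
    torch.save({"epoch": epoch, "state": {}}, os.path.join(ckpt_dir, "{:d}.tar".format(epoch)))
torch.save({"epoch": 100, "state": {}}, os.path.join(ckpt_dir, "best_model.tar"))

print(get_resume_file(ckpt_dir))        # .../100.tar (largest numbered epoch; best_model.tar is excluded)
print(get_best_file(ckpt_dir))          # .../best_model.tar (falls back to the resume file if missing)
print(get_assigned_file(ckpt_dir, 50))  # .../50.tar (path only; existence is not checked)
```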
/birds/fewshot/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import language, protonet
2 |
--------------------------------------------------------------------------------
/birds/fewshot/models/language.py:
--------------------------------------------------------------------------------
1 | """
2 | Language encoders/decoders.
3 | """
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.nn.utils.rnn as rnn_utils
10 |
11 |
12 | class TextProposal(nn.Module):
13 | r"""Reverse proposal model, estimating:
14 | argmax_lambda log q(w_i|x_1, y_1, ..., x_n, y_n; lambda)
15 | approximation to the distribution of descriptions.
16 | Because they use only positive labels, it actually simplifies to
17 | argmax_lambda log q(w_i|x_1, ..., x_4; lambda)
18 | https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/model.py
19 | """
20 |
21 | def __init__(
22 | self,
23 | embedding_module,
24 | input_size=1600,
25 | hidden_size=512,
26 | project_input=False,
27 | rnn="gru",
28 | num_layers=1,
29 | dropout=0.2,
30 | vocab=None,
31 | sos_index=0,
32 | eos_index=0,
33 | pad_index=0,
34 | ):
35 | super(TextProposal, self).__init__()
36 | self.embedding = embedding_module
37 | self.embedding_dim = embedding_module.embedding_dim
38 | self.vocab_size = embedding_module.num_embeddings
39 | self.input_size = input_size
40 | self.hidden_size = hidden_size
41 | self.project_input = project_input
42 | self.num_layers = num_layers
43 | self.rnn_type = rnn
44 | if self.project_input:
45 | self.proj_h = nn.Linear(self.input_size, self.hidden_size)
46 | if self.rnn_type == "lstm":
47 | self.proj_c = nn.Linear(self.input_size, self.hidden_size)
48 |
49 | if rnn == "gru":
50 | RNN = nn.GRU
51 | elif rnn == "lstm":
52 | RNN = nn.LSTM
53 | else:
54 | raise ValueError("Unknown RNN model {}".format(rnn))
55 |
56 | # Init the RNN
57 | self.rnn = None
58 | self.rnn = RNN(
59 | self.embedding_dim,
60 | hidden_size,
61 | num_layers=num_layers,
62 | dropout=dropout if num_layers > 1 else 0.0,
63 | batch_first=True,
64 | )
65 | self.dropout = nn.Dropout(p=dropout)
66 |
67 | # Projection from RNN hidden size to output vocab
68 | self.outputs2vocab = nn.Linear(hidden_size, self.vocab_size)
69 | self.vocab = vocab
70 | # Get sos/eos/pad indices
71 | self.sos_index = sos_index
72 | self.eos_index = eos_index
73 | self.pad_index = pad_index
74 | self.rev_vocab = {v: k for k, v in vocab.items()}
75 |
76 | def forward(self, feats, seq, length):
77 | # feats is from example images
78 | batch_size = seq.size(0)
79 |
80 | if self.project_input:
81 | feats_h = self.proj_h(feats)
82 | if self.rnn_type == "lstm":
83 | feats_c = self.proj_c(feats)
84 | else:
85 | feats_h = feats
86 | feats_c = feats
87 |
88 | if batch_size > 1:
89 | sorted_lengths, sorted_idx = torch.sort(length, descending=True)
90 | seq = seq[sorted_idx]
91 | feats_h = feats_h[sorted_idx]
92 | if self.rnn_type == "lstm":
93 | feats_c = feats_c[sorted_idx]
94 |
95 | # Construct hidden states by expanding to number of layers
96 | feats_h = feats_h.unsqueeze(0).expand(self.num_layers, -1, -1).contiguous()
97 | if self.rnn_type == "lstm":
98 | feats_c = feats_c.unsqueeze(0).expand(self.num_layers, -1, -1).contiguous()
99 | hidden = (feats_h, feats_c)
100 | else:
101 | hidden = feats_h
102 |
103 | # embed your sequences
104 | embed_seq = self.embedding(seq)
105 |
106 |         # embed_seq: (batch, seq_len, embedding_dim), packed batch-first below
107 | packed_input = rnn_utils.pack_padded_sequence(
108 | embed_seq, sorted_lengths, batch_first=True
109 | )
110 | packed_output, _ = self.rnn(packed_input, hidden)
111 | output = rnn_utils.pad_packed_sequence(packed_output, batch_first=True)
112 | output = output[0].contiguous()
113 |
114 | if batch_size > 1:
115 | _, reversed_idx = torch.sort(sorted_idx)
116 | output = output[reversed_idx]
117 |
118 | max_length = output.size(1)
119 | output_2d = output.view(batch_size * max_length, self.hidden_size)
120 | output_2d_dropout = self.dropout(output_2d)
121 | outputs_2d = self.outputs2vocab(output_2d_dropout)
122 | outputs = outputs_2d.view(batch_size, max_length, self.vocab_size)
123 |
124 | return outputs
125 |
126 | def sample(self, feats, greedy=False, to_text=False):
127 |         """Generate captions from image features via greedy or multinomial sampling."""
128 | with torch.no_grad():
129 | if self.project_input:
130 | feats_h = self.proj_h(feats)
131 | states = feats_h
132 | if self.rnn_type == "lstm":
133 | feats_c = self.proj_c(feats)
134 | states = (feats_h, feats_c)
135 | else:
136 | states = feats
137 |
138 | batch_size = states.size(0)
139 |
140 | # initialize hidden states using image features
141 | states = states.unsqueeze(0)
142 |
143 | # first input is SOS token
144 | inputs = np.array([self.sos_index for _ in range(batch_size)])
145 | inputs = torch.from_numpy(inputs)
146 | inputs = inputs.unsqueeze(1)
147 | inputs = inputs.to(feats.device)
148 |
149 | # save SOS as first generated token
150 | inputs_npy = inputs.squeeze(1).cpu().numpy()
151 | sampled_ids = [[w] for w in inputs_npy]
152 |
153 | # compute embeddings
154 | inputs = self.embedding(inputs)
155 |
156 |             # Cap generation at the max caption length (32)
157 | for i in range(32): # like in jacobs repo
158 |                 outputs, states = self.rnn(inputs, states)  # outputs: (B,L=1,H)
159 | outputs = outputs.squeeze(1) # outputs: (B,H)
160 | outputs = self.outputs2vocab(outputs) # outputs: (B,V)
161 |
162 | if greedy:
163 | predicted = outputs.max(1)[1]
164 | predicted = predicted.unsqueeze(1)
165 | else:
166 | outputs = F.softmax(outputs, dim=1)
167 | predicted = torch.multinomial(outputs, 1)
168 |
169 | predicted_npy = predicted.squeeze(1).cpu().numpy()
170 | predicted_lst = predicted_npy.tolist()
171 |
172 | for w, so_far in zip(predicted_lst, sampled_ids):
173 | if so_far[-1] != self.eos_index:
174 | so_far.append(w)
175 |
176 | inputs = predicted
177 | inputs = self.embedding(inputs) # inputs: (L=1,B,E)
178 |
179 | sampled_lengths = [len(text) for text in sampled_ids]
180 | sampled_lengths = np.array(sampled_lengths)
181 |
182 | max_length = max(sampled_lengths)
183 | padded_ids = np.ones((batch_size, max_length)) * self.pad_index
184 |
185 | for i in range(batch_size):
186 | padded_ids[i, : sampled_lengths[i]] = sampled_ids[i]
187 |
188 | sampled_lengths = torch.from_numpy(sampled_lengths).long()
189 | sampled_ids = torch.from_numpy(padded_ids).long()
190 |
191 | if to_text:
192 | sampled_text = self.to_text(sampled_ids)
193 | return sampled_text, sampled_lengths
194 | return sampled_ids, sampled_lengths
195 |
196 | def to_text(self, sampled_ids):
197 | texts = []
198 | for sample in sampled_ids.numpy():
199 | texts.append(" ".join([self.rev_vocab[v] for v in sample if v != 0]))
200 | return np.array(texts, dtype=np.unicode_)
201 |
202 |
203 | class TextRep(nn.Module):
204 |     r"""Deterministic Bowman et al. model to form
205 | text representation.
206 |
207 | Again, this uses 512 hidden dimensions.
208 | """
209 |
210 | def __init__(
211 | self, embedding_module, hidden_size=512, rnn="gru", num_layers=1, dropout=0.2
212 | ):
213 | super(TextRep, self).__init__()
214 | self.embedding = embedding_module
215 | self.embedding_dim = embedding_module.embedding_dim
216 | if rnn == "gru":
217 | RNN = nn.GRU
218 | elif rnn == "lstm":
219 | RNN = nn.LSTM
220 | else:
221 | raise ValueError("Unknown RNN model {}".format(rnn))
222 | self.rnn = RNN(
223 | self.embedding_dim,
224 | hidden_size,
225 | num_layers=num_layers,
226 | dropout=dropout if num_layers > 1 else 0.0,
227 | )
228 | self.hidden_size = hidden_size
229 |
230 | def forward(self, seq, length):
231 | batch_size = seq.size(0)
232 |
233 | if batch_size > 1:
234 | sorted_lengths, sorted_idx = torch.sort(length, descending=True)
235 | seq = seq[sorted_idx]
236 |
237 |         # reorder from (B, L) to (L, B) before embedding
238 | seq = seq.transpose(0, 1)
239 |
240 | # embed your sequences
241 | embed_seq = self.embedding(seq)
242 |
243 | packed = rnn_utils.pack_padded_sequence(
244 | embed_seq,
245 | sorted_lengths.data.tolist() if batch_size > 1 else length.data.tolist(),
246 | )
247 |
248 | _, hidden = self.rnn(packed)
249 | hidden = hidden[-1, ...]
250 |
251 | if batch_size > 1:
252 | _, reversed_idx = torch.sort(sorted_idx)
253 | hidden = hidden[reversed_idx]
254 |
255 | return hidden
256 |
--------------------------------------------------------------------------------
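Note: a minimal sketch of wiring the two modules in `models/language.py` together. The toy vocabulary, captions, and feature tensor are invented; the sizes (1600-d image features, 300-d embeddings, 200-d hidden, `project_input=True`) mirror the defaults wired up in `train.py`/`test.py`, and the import assumes `fewshot/` is on the path.

```python
import torch
import torch.nn as nn

from models.language import TextProposal, TextRep

# Toy vocab (invented): 0 = pad, 3 = SOS, 4 = EOS
vocab = {"<pad>": 0, "red": 1, "bird": 2, "<sos>": 3, "<eos>": 4}
emb = nn.Embedding(len(vocab), 300)

encoder = TextRep(emb, hidden_size=200, rnn="gru")
decoder = TextProposal(
    emb,
    input_size=1600,
    hidden_size=200,
    project_input=True,  # 1600 != 200, as in train.py
    rnn="gru",
    vocab=vocab,
    sos_index=3,
    eos_index=4,
    pad_index=0,
)

seq = torch.tensor([[3, 1, 2, 4, 0],   # <sos> red bird <eos> <pad>
                    [3, 2, 4, 0, 0]])  # <sos> bird <eos> <pad> <pad>
length = torch.tensor([4, 3])
feats = torch.randn(2, 1600)           # image features from the backbone

text_rep = encoder(seq, length)        # (2, 200) caption representations
logits = decoder(feats, seq, length)   # (2, 4, 5): per-step scores over the vocab
ids, lens = decoder.sample(feats, greedy=True)  # greedy-decoded caption index tensors
```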
/birds/fewshot/run_cl.py:
--------------------------------------------------------------------------------
1 | """
2 | Run everything on codalab.
3 | """
4 |
5 | import json
6 | import os
7 | from subprocess import check_call
8 |
9 |
10 | if __name__ == "__main__":
11 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
12 |
13 | parser = ArgumentParser(
14 | description="Run everything on codalab",
15 | formatter_class=ArgumentDefaultsHelpFormatter,
16 | )
17 |
18 | cl_parser = parser.add_argument_group(
19 | "Codalab args", "args to control high level codalab eval"
20 | )
21 | cl_parser.add_argument(
22 | "--no_train", action="store_true", help="Don't run the train command"
23 | )
24 | cl_parser.add_argument(
25 | "--log_dir", default="./test/", help="Where to save metrics/models"
26 | )
27 | cl_parser.add_argument("--n", default=1, type=int, help="Number of runs")
28 |
29 | fparser = parser.add_argument_group(
30 | "Few shot args", "args to pass to few shot scripts"
31 | )
32 | fparser.add_argument("--model", default="Conv4")
33 | fparser.add_argument("--lsl", action="store_true")
34 | fparser.add_argument("--l3", action="store_true")
35 | fparser.add_argument("--l3_n_infer", type=int, default=10)
36 | fparser.add_argument("--rnn_type", choices=["gru", "lstm"], default="gru")
37 | fparser.add_argument("--rnn_num_layers", default=1, type=int)
38 | fparser.add_argument("--rnn_dropout", default=0.0, type=float)
39 | fparser.add_argument(
40 | "--language_filter", default="all", choices=["all", "color", "nocolor"]
41 | )
42 | fparser.add_argument(
43 | "--lang_supervision", default="instance", choices=["instance", "class"]
44 | )
45 | fparser.add_argument("--glove_init", action="store_true")
46 | fparser.add_argument("--freeze_emb", action="store_true")
47 | fparser.add_argument("--scramble_lang", action="store_true")
48 | fparser.add_argument("--sample_class_lang", action="store_true")
49 | fparser.add_argument("--scramble_all", action="store_true")
50 | fparser.add_argument("--shuffle_lang", action="store_true")
51 | fparser.add_argument("--scramble_lang_class", action="store_true")
52 | fparser.add_argument("--n_caption", choices=list(range(1, 11)), type=int, default=1)
53 | fparser.add_argument("--max_class", type=int, default=None)
54 | fparser.add_argument("--max_img_per_class", type=int, default=None)
55 | fparser.add_argument("--max_lang_per_class", type=int, default=None)
56 | fparser.add_argument("--lang_lambda", type=float, default=0.25)
57 | fparser.add_argument(
58 | "--save_freq", type=int, default=10000
59 | ) # In CL script, by default, never save, just keep best model
60 | fparser.add_argument("--lang_emb_size", type=int, default=300)
61 | fparser.add_argument("--lang_hidden_size", type=int, default=200)
62 | fparser.add_argument("--lr", type=float, default=1e-3)
63 | fparser.add_argument("--rnn_lr_scale", default=1.0, type=float)
64 | fparser.add_argument(
65 | "--optimizer", default="adam", choices=["adam", "amsgrad", "rmsprop"]
66 | )
67 | fparser.add_argument("--n_way", type=int, default=5)
68 | fparser.add_argument(
69 | "--test_n_way",
70 | type=int,
71 | default=None,
72 | help="Specify to change n_way eval at test",
73 | )
74 | fparser.add_argument("--n_shot", type=int, default=1)
75 | fparser.add_argument("--epochs", type=int, default=600)
76 | fparser.add_argument("--n_workers", type=int, default=4)
77 | fparser.add_argument("--resume", action="store_true")
78 | fparser.add_argument("--debug", action="store_true")
79 | fparser.add_argument("--seed", default=None, type=int)
80 |
81 | args = parser.parse_args()
82 |
83 | if args.test_n_way is None:
84 | args.test_n_way = args.n_way
85 |
86 | args.cl_dir = os.path.join(args.log_dir, "checkpoints")
87 | args.cl_record_file = os.path.join(args.log_dir, "results_novel.json")
88 | args.cl_args_file = os.path.join(args.log_dir, "args.json")
89 |
90 | os.makedirs(args.log_dir, exist_ok=True)
91 | if os.path.exists(args.cl_record_file):
92 | os.remove(args.cl_record_file)
93 |
94 | # Save arg metadata to root directory
95 | # Only save if training a model
96 | print("==== RUN_CL: PARAMS ====")
97 | argsv = vars(args)
98 | print(argsv)
99 | if not args.no_train:
100 | with open(args.cl_args_file, "w") as fout:
101 | json.dump(argsv, fout, sort_keys=True, indent=4, separators=(",", ": "))
102 |
103 | # Train
104 | for i in range(1, args.n + 1):
105 | if not args.no_train:
106 | print("==== RUN_CL ({}/{}): TRAIN ====".format(i, args.n))
107 | train_cmd = [
108 | "python3.7",
109 | "fewshot/train.py",
110 | "--model",
111 | args.model,
112 | "--n_shot",
113 | args.n_shot,
114 | "--train_n_way",
115 | args.n_way,
116 | "--test_n_way",
117 | args.test_n_way,
118 | "--stop_epoch",
119 | args.epochs,
120 | "--rnn_type",
121 | args.rnn_type,
122 | "--rnn_num_layers",
123 | args.rnn_num_layers,
124 | "--rnn_dropout",
125 | args.rnn_dropout,
126 | "--language_filter",
127 | args.language_filter,
128 | "--lang_lambda",
129 | args.lang_lambda,
130 | "--lang_hidden_size",
131 | args.lang_hidden_size,
132 | "--lang_supervision",
133 | args.lang_supervision,
134 | "--lang_emb_size",
135 | args.lang_emb_size,
136 | "--n_caption",
137 | args.n_caption,
138 | "--stop_epoch",
139 | args.epochs,
140 | "--checkpoint_dir",
141 | args.cl_dir,
142 | "--save_freq",
143 | args.save_freq,
144 | "--n",
145 | i,
146 | "--lr",
147 | args.lr,
148 | "--rnn_lr_scale",
149 | args.rnn_lr_scale,
150 | "--optimizer",
151 | args.optimizer,
152 | "--n_workers",
153 | args.n_workers,
154 | "--l3_n_infer",
155 | args.l3_n_infer,
156 | ]
157 | if args.seed is not None:
158 | train_cmd.extend(["--seed", args.seed])
159 | if args.max_class is not None:
160 | train_cmd.extend(["--max_class", args.max_class])
161 | if args.max_img_per_class is not None:
162 | train_cmd.extend(["--max_img_per_class", args.max_img_per_class])
163 | if args.max_lang_per_class is not None:
164 | train_cmd.extend(["--max_lang_per_class", args.max_lang_per_class])
165 | if args.lsl:
166 | train_cmd.append("--lsl")
167 | if args.l3:
168 | train_cmd.append("--l3")
169 | if args.glove_init:
170 | train_cmd.append("--glove_init")
171 | if args.freeze_emb:
172 | train_cmd.append("--freeze_emb")
173 | if args.shuffle_lang:
174 | train_cmd.append("--shuffle_lang")
175 | if args.scramble_lang:
176 | train_cmd.append("--scramble_lang")
177 | if args.sample_class_lang:
178 | train_cmd.append("--sample_class_lang")
179 | if args.scramble_all:
180 | train_cmd.append("--scramble_all")
181 | if args.scramble_lang_class:
182 | train_cmd.append("--scramble_lang_class")
183 | if args.resume:
184 | train_cmd.append("--resume")
185 | if args.debug:
186 | train_cmd.append("--debug")
187 | train_cmd = [str(x) for x in train_cmd]
188 | check_call(train_cmd)
189 |
190 | print("==== RUN_CL ({}/{}): TEST NOVEL ====".format(i, args.n))
191 | test_cmd = [
192 | "python3.7",
193 | "fewshot/test.py",
194 | "--model",
195 | args.model,
196 | "--n_shot",
197 | args.n_shot,
198 | "--test_n_way",
199 | args.test_n_way,
200 | "--rnn_type",
201 | args.rnn_type,
202 | "--rnn_num_layers",
203 | args.rnn_num_layers,
204 | "--rnn_dropout",
205 | args.rnn_dropout,
206 | "--language_filter",
207 | args.language_filter,
208 | "--lang_lambda",
209 | args.lang_lambda,
210 | "--lang_hidden_size",
211 | args.lang_hidden_size,
212 | "--lang_supervision",
213 | args.lang_supervision,
214 | "--lang_emb_size",
215 | args.lang_emb_size,
216 | "--checkpoint_dir",
217 | args.cl_dir,
218 | "--split",
219 | "novel",
220 | "--n_workers",
221 | args.n_workers,
222 | "--record_file",
223 | args.cl_record_file,
224 | "--l3_n_infer",
225 | args.l3_n_infer,
226 | ]
227 | if args.seed is not None:
228 | test_cmd.extend(["--seed", args.seed])
229 | if args.lsl:
230 | test_cmd.append("--lsl")
231 | if args.l3:
232 | test_cmd.append("--l3")
233 | if args.debug:
234 | test_cmd.append("--debug")
235 | test_cmd = [str(x) for x in test_cmd]
236 | check_call(test_cmd)
237 |
--------------------------------------------------------------------------------
/birds/fewshot/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Test script.
3 | """
4 |
5 | import json
6 | import os
7 | import sys
8 | import time
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.optim
14 | import torch.utils.data.sampler
15 |
16 | import constants
17 | from data import lang_utils
18 | from data.datamgr import SetDataManager, TransformLoader
19 | from io_utils import get_assigned_file, get_best_file, model_dict, parse_args
20 | from models.language import TextProposal, TextRep
21 | from models.protonet import ProtoNet
22 |
23 |
24 | if __name__ == "__main__":
25 | args = parse_args("test")
26 |
27 | if args.seed is not None:
28 | torch.manual_seed(args.seed)
29 |
30 | acc_all = []
31 |
32 | vocab = lang_utils.load_vocab(constants.LANG_DIR)
33 |
34 | l3_model = None
35 | lang_model = None
36 | if args.lsl or args.l3:
37 | embedding_model = nn.Embedding(len(vocab), args.lang_emb_size)
38 | lang_model = TextProposal(
39 | embedding_model,
40 | input_size=1600,
41 | hidden_size=args.lang_hidden_size,
42 | project_input=1600 != args.lang_hidden_size,
43 | rnn=args.rnn_type,
44 | num_layers=args.rnn_num_layers,
45 | dropout=args.rnn_dropout,
46 | vocab=vocab,
47 | **lang_utils.get_special_indices(vocab),
48 | )
49 |
50 | if args.l3:
51 | l3_model = TextRep(
52 | embedding_model,
53 | hidden_size=args.lang_hidden_size,
54 | rnn=args.rnn_type,
55 | num_layers=args.rnn_num_layers,
56 | dropout=args.rnn_dropout,
57 | )
58 | l3_model = l3_model.cuda()
59 |
60 | embedding_model = embedding_model.cuda()
61 | lang_model = lang_model.cuda()
62 |
63 | model = ProtoNet(
64 | model_dict[args.model],
65 | n_way=args.test_n_way,
66 | n_support=args.n_shot,
67 | # Language options
68 | lsl=args.lsl,
69 | language_model=lang_model,
70 | lang_supervision=args.lang_supervision,
71 | l3=args.l3,
72 | l3_model=l3_model,
73 | l3_n_infer=args.l3_n_infer,
74 | )
75 |
76 | model = model.cuda()
77 |
78 | if args.save_iter != -1:
79 | modelfile = get_assigned_file(args.checkpoint_dir, args.save_iter)
80 | else:
81 | modelfile = get_best_file(args.checkpoint_dir)
82 |
83 | if modelfile is not None:
84 | tmp = torch.load(modelfile)
85 | model.load_state_dict(
86 | tmp["state"],
87 |             # If the checkpoint was trained with language but --lsl is not set
88 |             # here, allow its language model weights to be ignored; if --lsl is
89 |             # set, require the full language model to be loaded (strict)
90 | strict=args.lsl,
91 | )
92 |
93 | if args.save_embeddings:
94 | if args.lsl:
95 | weights = model.language_model.embedding.weight.detach().cpu().numpy()
96 | vocab_srt = sorted(list(vocab.items()), key=lambda x: x[1])
97 | vocab_srt = [v[0] for v in vocab_srt]
98 |             np.savetxt(args.embeddings_file, weights, fmt="%f", delimiter="\t")
99 |             with open(args.embeddings_metadata, "w") as fout:
100 |                 fout.write("\n".join(vocab_srt))
101 |                 fout.write("\n")
102 | sys.exit(0)
103 |
104 | # Run the test loop for 600 iterations
105 | ITER_NUM = 600
106 | N_QUERY = 15
107 |
108 | test_datamgr = SetDataManager(
109 | "CUB",
110 | 84,
111 | n_query=N_QUERY,
112 | n_way=args.test_n_way,
113 | n_support=args.n_shot,
114 | n_episode=ITER_NUM,
115 | args=args,
116 | )
117 | test_loader = test_datamgr.get_data_loader(
118 | os.path.join(constants.DATA_DIR, f"{args.split}.json"),
119 | aug=False,
120 | lang_dir=constants.LANG_DIR,
121 | normalize=False,
122 | vocab=vocab,
123 | )
124 | normalizer = TransformLoader(84).get_normalize()
125 |
126 | model.eval()
127 |
128 | acc_all = model.test_loop(
129 | test_loader,
130 | normalizer=normalizer,
131 | verbose=True,
132 | return_all=True,
133 | # Debug on first loop only
134 | debug=args.debug,
135 | debug_dir=os.path.split(args.checkpoint_dir)[0],
136 | )
137 | acc_mean = np.mean(acc_all)
138 | acc_std = np.std(acc_all)
139 | print(
140 | "%d Test Acc = %4.2f%% +- %4.2f%%"
141 | % (ITER_NUM, acc_mean, 1.96 * acc_std / np.sqrt(ITER_NUM))
142 | )
143 |
144 | with open(args.record_file, "a") as f:
145 | timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime())
146 | acc_ci = 1.96 * acc_std / np.sqrt(ITER_NUM)
147 | f.write(
148 | json.dumps(
149 | {
150 | "time": timestamp,
151 | "split": args.split,
152 | "setting": args.checkpoint_dir,
153 | "iter_num": ITER_NUM,
154 | "acc": acc_mean,
155 | "acc_ci": acc_ci,
156 | "acc_all": list(acc_all),
157 | "acc_std": acc_std,
158 | },
159 | sort_keys=True,
160 | )
161 | )
162 | f.write("\n")
163 |
--------------------------------------------------------------------------------
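Note: the accuracy that `test.py` prints and logs is the mean over the 600 test episodes together with a normal-approximation 95% confidence interval, `1.96 * std / sqrt(N)`. A tiny sketch of the same computation on invented per-episode accuracies:

```python
import numpy as np

acc_all = np.array([61.0, 58.5, 63.2, 60.1])  # toy per-episode accuracies (%)
iter_num = len(acc_all)
acc_mean = np.mean(acc_all)
acc_ci = 1.96 * np.std(acc_all) / np.sqrt(iter_num)
print("%d Test Acc = %4.2f%% +- %4.2f%%" % (iter_num, acc_mean, acc_ci))
```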
/birds/fewshot/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train script.
3 | """
4 |
5 | import json
6 | import os
7 | from collections import defaultdict
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim
12 | from tqdm import tqdm
13 |
14 | import constants
15 | from data import lang_utils
16 | from data.datamgr import SetDataManager
17 | from io_utils import get_resume_file, model_dict, parse_args
18 | from models.language import TextProposal, TextRep
19 | from models.protonet import ProtoNet
20 |
21 |
22 | def get_optimizer(model, args):
23 | """
24 | Get the optimizer for the model based on arguments. Specifically, if
25 | needed, we split up training into (1) main parameters, (2) RNN-specific
26 | parameters, with different learning rates if specified.
27 |
28 | :param model: nn.Module to train
29 | :param args: argparse.Namespace - other args passed to the script
30 |
31 | :returns: a torch.optim.Optimizer
32 | """
33 | # Get params
34 | main_params = {"params": []}
35 | rnn_params = {"params": [], "lr": args.rnn_lr_scale * args.lr}
36 | for name, param in model.named_parameters():
37 | if not param.requires_grad:
38 | continue
39 | if name.startswith("language_model."):
40 | # Scale RNN learning rate
41 | rnn_params["params"].append(param)
42 | else:
43 | main_params["params"].append(param)
44 | if args.lsl and not rnn_params["params"]:
45 | print("Warning: --lsl is set but no RNN parameters found")
46 | params_to_optimize = [main_params, rnn_params]
47 |
48 | # Define optimizer
49 | if args.optimizer == "adam":
50 | optimizer = torch.optim.Adam(params_to_optimize, lr=args.lr)
51 | elif args.optimizer == "amsgrad":
52 | optimizer = torch.optim.Adam(params_to_optimize, lr=args.lr, amsgrad=True)
53 | elif args.optimizer == "rmsprop":
54 | optimizer = torch.optim.RMSprop(params_to_optimize, lr=args.lr)
55 | else:
56 | raise NotImplementedError("optimizer = {}".format(args.optimizer))
57 | return optimizer
58 |
59 |
60 | def train(
61 | base_loader,
62 | val_loader,
63 | model,
64 | start_epoch,
65 | stop_epoch,
66 | args,
67 | metrics_fname="metrics.json",
68 | ):
69 | """
70 | Main training script.
71 |
72 | :param base_loader: torch.utils.data.DataLoader for training set, generated
73 | by data.datamgr.SetDataManager
74 | :param val_loader: torch.utils.data.DataLoader for validation set,
75 | generated by data.datamgr.SetDataManager
76 | :param model: nn.Module to train
77 | :param start_epoch: which epoch we started at
78 | :param stop_epoch: which epoch to end at
79 | :param args: other arguments passed to the script
80 |     :param metrics_fname: where to save metrics
81 | """
82 | optimizer = get_optimizer(model, args)
83 |
84 | max_val_acc = 0
85 | best_epoch = 0
86 |
87 | val_accs = []
88 | val_losses = []
89 | all_metrics = defaultdict(list)
90 | for epoch in tqdm(
91 | range(start_epoch, stop_epoch), total=stop_epoch - start_epoch, desc="Train"
92 | ):
93 | model.train()
94 | metric = model.train_loop(epoch, base_loader, optimizer, args)
95 | for m, val in metric.items():
96 | all_metrics[m].append(val)
97 | model.eval()
98 |
99 | os.makedirs(args.checkpoint_dir, exist_ok=True)
100 |
101 | val_acc, val_loss = model.test_loop(val_loader,)
102 | val_accs.append(val_acc)
103 | val_losses.append(val_loss)
104 | if val_acc > max_val_acc:
105 | best_epoch = epoch
106 | tqdm.write("best model! save...")
107 | max_val_acc = val_acc
108 | outfile = os.path.join(args.checkpoint_dir, "best_model.tar")
109 | torch.save({"epoch": epoch, "state": model.state_dict()}, outfile)
110 |
111 | if epoch and (epoch % args.save_freq == 0) or (epoch == stop_epoch - 1):
112 | outfile = os.path.join(args.checkpoint_dir, "{:d}.tar".format(epoch))
113 | torch.save({"epoch": epoch, "state": model.state_dict()}, outfile)
114 | tqdm.write("")
115 |
116 | # Save metrics
117 | metrics = {
118 | "train_acc": all_metrics["train_acc"],
119 | "current_train_acc": all_metrics["train_acc"][-1],
120 | "train_loss": all_metrics["train_loss"],
121 | "current_train_loss": all_metrics["train_loss"][-1],
122 | "cls_loss": all_metrics["cls_loss"],
123 | "current_cls_loss": all_metrics["cls_loss"][-1],
124 | "lang_loss": all_metrics["lang_loss"],
125 | "current_lang_loss": all_metrics["lang_loss"][-1],
126 | "current_epoch": epoch,
127 | "val_acc": val_accs,
128 | "val_loss": val_losses,
129 | "current_val_loss": val_losses[-1],
130 | "current_val_acc": val_acc,
131 | "best_epoch": best_epoch,
132 | "best_val_acc": max_val_acc,
133 | }
134 | with open(os.path.join(args.checkpoint_dir, metrics_fname), "w") as fout:
135 | json.dump(metrics, fout, sort_keys=True, indent=4, separators=(",", ": "))
136 |
137 | # Save a copy to current metrics too
138 | if (
139 | metrics_fname != "metrics.json"
140 | and metrics_fname.startswith("metrics_")
141 | and metrics_fname.endswith(".json")
142 | ):
143 |             metrics["n"] = int(metrics_fname[len("metrics_") : -len(".json")])
144 | with open(os.path.join(args.checkpoint_dir, "metrics.json"), "w") as fout:
145 | json.dump(
146 | metrics, fout, sort_keys=True, indent=4, separators=(",", ": ")
147 | )
148 |
149 | # If didn't train, save model anyways
150 | if stop_epoch == 0:
151 | outfile = os.path.join(args.checkpoint_dir, "best_model.tar")
152 | torch.save({"epoch": stop_epoch, "state": model.state_dict()}, outfile)
153 |
154 |
155 | if __name__ == "__main__":
156 | args = parse_args("train")
157 |
158 | if args.seed is not None:
159 | torch.manual_seed(args.seed)
160 | # I don't seed the np rng since dataset loading uses multiprocessing with
161 | # random choices.
162 | # https://github.com/numpy/numpy/issues/9650
163 |         # Unavoidable non-determinism here for now
164 |
165 | base_file = os.path.join(constants.DATA_DIR, "base.json")
166 | val_file = os.path.join(constants.DATA_DIR, "val.json")
167 |
168 | # Load language
169 | vocab = lang_utils.load_vocab(constants.LANG_DIR)
170 |
171 | l3_model = None
172 | lang_model = None
173 | if args.lsl or args.l3:
174 | if args.glove_init:
175 | vecs = lang_utils.glove_init(vocab, emb_size=args.lang_emb_size)
176 | embedding_model = nn.Embedding(
177 | len(vocab), args.lang_emb_size, _weight=vecs if args.glove_init else None
178 | )
179 | if args.freeze_emb:
180 | embedding_model.weight.requires_grad = False
181 |
182 | lang_input_size = 1600
183 | lang_model = TextProposal(
184 | embedding_model,
185 | input_size=lang_input_size,
186 | hidden_size=args.lang_hidden_size,
187 | project_input=lang_input_size != args.lang_hidden_size,
188 | rnn=args.rnn_type,
189 | num_layers=args.rnn_num_layers,
190 | dropout=args.rnn_dropout,
191 | vocab=vocab,
192 | **lang_utils.get_special_indices(vocab)
193 | )
194 |
195 | if args.l3:
196 | l3_model = TextRep(
197 | embedding_model,
198 | hidden_size=args.lang_hidden_size,
199 | rnn=args.rnn_type,
200 | num_layers=args.rnn_num_layers,
201 | dropout=args.rnn_dropout,
202 | )
203 | l3_model = l3_model.cuda()
204 |
205 | embedding_model = embedding_model.cuda()
206 | lang_model = lang_model.cuda()
207 |
208 | # if test_n_way is smaller than train_n_way, reduce n_query to keep batch
209 | # size small
210 | n_query = max(1, int(16 * args.test_n_way / args.train_n_way))
211 |
212 | train_few_shot_args = dict(n_way=args.train_n_way, n_support=args.n_shot)
213 | base_datamgr = SetDataManager(
214 | "CUB", 84, n_query=n_query, **train_few_shot_args, args=args
215 | )
216 | print("Loading train data")
217 |
218 | base_loader = base_datamgr.get_data_loader(
219 | base_file,
220 | aug=True,
221 | lang_dir=constants.LANG_DIR,
222 | normalize=True,
223 | vocab=vocab,
224 | # Maximum training data restrictions only apply at train time
225 | max_class=args.max_class,
226 | max_img_per_class=args.max_img_per_class,
227 | max_lang_per_class=args.max_lang_per_class,
228 | )
229 |
230 | val_datamgr = SetDataManager(
231 | "CUB",
232 | 84,
233 | n_query=n_query,
234 | n_way=args.test_n_way,
235 | n_support=args.n_shot,
236 | args=args,
237 | )
238 | print("Loading val data\n")
239 | val_loader = val_datamgr.get_data_loader(
240 | val_file, aug=False, lang_dir=constants.LANG_DIR, normalize=True, vocab=vocab,
241 | )
242 | # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor
243 |
244 | model = ProtoNet(
245 | model_dict[args.model],
246 | **train_few_shot_args,
247 | # Language options
248 | lsl=args.lsl,
249 | language_model=lang_model,
250 | lang_supervision=args.lang_supervision,
251 | l3=args.l3,
252 | l3_model=l3_model,
253 | l3_n_infer=args.l3_n_infer
254 | )
255 |
256 | model = model.cuda()
257 |
258 | os.makedirs(args.checkpoint_dir, exist_ok=True)
259 |
260 | start_epoch = args.start_epoch
261 | stop_epoch = args.stop_epoch
262 |
263 | if args.resume:
264 | resume_file = get_resume_file(args.checkpoint_dir)
265 | if resume_file is not None:
266 | tmp = torch.load(resume_file)
267 | start_epoch = tmp["epoch"] + 1
268 | model.load_state_dict(tmp["state"])
269 |
270 | metrics_fname = "metrics_{}.json".format(args.n)
271 |
272 | train(
273 | base_loader,
274 | val_loader,
275 | model,
276 | start_epoch,
277 | stop_epoch,
278 | args,
279 | metrics_fname=metrics_fname,
280 | )
281 |
--------------------------------------------------------------------------------
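Note: `get_optimizer` in `train.py` splits parameters into two groups so that anything under the `language_model.` prefix trains at `rnn_lr_scale * lr` while everything else uses `lr`. A small sketch with a toy model and a hypothetical `Namespace` holding only the fields `get_optimizer` reads; the import assumes `fewshot/` is on the path.

```python
from argparse import Namespace

import torch.nn as nn

from train import get_optimizer


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature = nn.Linear(8, 8)      # goes into the main parameter group
        self.language_model = nn.GRU(8, 8)  # name prefix matches "language_model."


toy_args = Namespace(lr=1e-3, rnn_lr_scale=0.1, lsl=True, optimizer="adam")
optimizer = get_optimizer(ToyModel(), toy_args)
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"])  # main params at 1e-3, RNN params at 1e-4
```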
/birds/filelists/CUB/download_CUB.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz
3 | tar -zxvf CUB_200_2011.tgz
4 | python write_CUB_filelist.py
5 |
--------------------------------------------------------------------------------
/birds/filelists/CUB/save_np.py:
--------------------------------------------------------------------------------
1 | """
2 | For each class, load images and save as numpy arrays.
3 | """
4 |
5 | import os
6 |
7 | import numpy as np
8 | from PIL import Image
9 | from tqdm import tqdm
10 |
11 | if __name__ == "__main__":
12 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
13 |
14 | parser = ArgumentParser(
15 | description="Save numpy", formatter_class=ArgumentDefaultsHelpFormatter
16 | )
17 |
18 | parser.add_argument(
19 | "--cub_dir", default="CUB_200_2011/images", help="Directory to load/cache"
20 | )
21 | parser.add_argument(
22 | "--original_cub_dir",
23 | default="CUB_200_2011/images",
24 | help="Original CUB directory if you want the image keys to be different (in case --cub_dir has changed)",
25 | )
26 | parser.add_argument("--filelist_prefix", default="./filelists/CUB/")
27 |
28 | args = parser.parse_args()
29 |
30 | for bird_class in tqdm(os.listdir(args.cub_dir), desc="Classes"):
31 | bird_imgs_np = {}
32 | class_dir = os.path.join(args.cub_dir, bird_class)
33 | bird_imgs = sorted([x for x in os.listdir(class_dir) if x != "img.npz"])
34 | for bird_img in bird_imgs:
35 | bird_img_fname = os.path.join(class_dir, bird_img)
36 | img = Image.open(bird_img_fname).convert("RGB")
37 | img_np = np.asarray(img)
38 |
39 | full_bird_img_fname = os.path.join(
40 | args.filelist_prefix, args.original_cub_dir, bird_class, bird_img
41 | )
42 |
43 | bird_imgs_np[full_bird_img_fname] = img_np
44 |
45 | np_fname = os.path.join(class_dir, "img.npz")
46 | np.savez_compressed(np_fname, **bird_imgs_np)
47 |
--------------------------------------------------------------------------------
/birds/filelists/CUB/write_CUB_filelist.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from os import listdir
4 | from os.path import isfile, isdir, join
5 |
6 | import numpy as np
7 | import pandas as pd
8 |
9 |
10 | if __name__ == '__main__':
11 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
12 |
13 | parser = ArgumentParser()
14 |
15 | parser.add_argument('--seed', type=int, default=0, help='Random seed')
16 | parser.add_argument('--savedir', type=str, default='../../custom_filelists/CUB/',
17 | help='Directory to save filelists')
18 |
19 | args = parser.parse_args()
20 |
21 | random = np.random.RandomState(args.seed)
22 |
23 | filelist_path = './filelists/CUB/'
24 | data_path = 'CUB_200_2011/images'
25 | dataset_list = ['base', 'val', 'novel']
26 |
27 | folder_list = [f for f in listdir(data_path) if isdir(join(data_path, f))]
28 | folder_list.sort()
29 | label_dict = dict(zip(folder_list, range(0, len(folder_list))))
30 |
31 | classfile_list_all = []
32 |
33 | # Load attributes
34 | attrs = pd.read_csv('./CUB_200_2011/attributes/image_attribute_labels.txt',
35 | sep=' ',
36 | header=None,
37 | names=['image_id', 'attribute_id', 'is_present', 'certainty_id', 'time'])
38 | # Zero out attributes with certainty < 3
39 | attrs['is_present'] = np.where(attrs['certainty_id'] < 3, 0, attrs['is_present'])
40 | # Get image names
41 | image_names = pd.read_csv('./CUB_200_2011/images.txt', sep=' ',
42 | header=None,
43 | names=['image_id', 'image_name'])
44 | attrs = attrs.merge(image_names, on='image_id')
45 | attrs['is_present'] = attrs['is_present'].astype(str)
46 | attrs = attrs.groupby('image_name')['is_present'].apply(lambda col: ''.join(col))
47 | attrs = dict(zip(attrs.index, attrs))
48 | attrs = {os.path.basename(k): v for k, v in attrs.items()}
49 |
50 | for i, folder in enumerate(folder_list):
51 | folder_path = join(data_path, folder)
52 | classfile_list_all.append([
53 | join(filelist_path, folder_path, cf) for cf in listdir(folder_path)
54 | if (isfile(join(folder_path, cf)) and cf[0] != '.' and not cf.endswith('.npz'))
55 | ])
56 | random.shuffle(classfile_list_all[i])
57 |
58 | for dataset in dataset_list:
59 | file_list = []
60 | label_list = []
61 | for i, classfile_list in enumerate(classfile_list_all):
62 | if 'base' in dataset:
63 | if (i % 2 == 0):
64 | file_list.extend(classfile_list)
65 | label_list.extend(np.repeat(
66 | i, len(classfile_list)).tolist())
67 | if 'val' in dataset:
68 | if (i % 4 == 1):
69 | file_list.extend(classfile_list)
70 | label_list.extend(np.repeat(
71 | i, len(classfile_list)).tolist())
72 | if 'novel' in dataset:
73 | if (i % 4 == 3):
74 | file_list.extend(classfile_list)
75 | label_list.extend(np.repeat(
76 | i, len(classfile_list)).tolist())
77 |
78 | # Get attributes
79 | attribute_list = [
80 | attrs[os.path.basename(f)] for f in file_list if not f.endswith('.npz')
81 | ]
82 |
83 | djson = {
84 | 'label_names': folder_list,
85 | 'image_names': file_list,
86 | 'image_labels': label_list,
87 | 'image_attributes': attribute_list,
88 | }
89 |
90 | os.makedirs(args.savedir, exist_ok=True)
91 | with open(os.path.join(args.savedir, dataset + '.json'), 'w') as fout:
92 | json.dump(djson, fout)
93 |
94 | print("%s -OK" % dataset)
95 |
--------------------------------------------------------------------------------
/birds/run_l3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/l3 --l3 --glove_init --lang_lambda 5 --max_lang_per_class 20 --sample_class_lang
4 |
--------------------------------------------------------------------------------
/birds/run_lang_ablation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Color
4 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_ablation/lsl_color --lsl --glove_init --lang_lambda 5 --language_filter color --max_lang_per_class 20 --sample_class_lang
5 |
6 | # Nocolor
7 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_ablation/lsl_nocolor --lsl --glove_init --lang_lambda 5 --language_filter nocolor --max_lang_per_class 20 --sample_class_lang
8 |
9 | # Shuffled words
10 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_ablation/lsl_shuffled_words --lsl --glove_init --lang_lambda 5 --shuffle_lang --max_lang_per_class 20 --sample_class_lang
11 |
12 | # Shuffled captions
13 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_ablation/lsl_shuffled_captions --lsl --glove_init --lang_lambda 5 --scramble_all --max_lang_per_class 20 --sample_class_lang
14 |
--------------------------------------------------------------------------------
/birds/run_lang_amount.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for max_lang_per_class in 1 5 10 20 30 40 50 60; do
4 | # LSL
5 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_amount/lsl_max_lang_$max_lang_per_class --lsl --glove_init --lang_lambda 5 --max_lang_per_class $max_lang_per_class --sample_class_lang
6 |
7 | # L3
8 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/language_amount/l3_max_lang_$max_lang_per_class --l3 --glove_init --max_lang_per_class $max_lang_per_class --sample_class_lang
9 | done
10 |
--------------------------------------------------------------------------------
/birds/run_lsl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Standard LSL
4 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/lsl --lsl --glove_init --lang_lambda 5 --max_lang_per_class 20 --sample_class_lang
5 |
--------------------------------------------------------------------------------
/birds/run_meta.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python fewshot/run_cl.py --n 1 --log_dir exp/acl/meta
4 |
--------------------------------------------------------------------------------
/shapeworld/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | notebooks/
107 | tmp/
108 | .vscode
109 | viz/
110 |
111 | # History files
112 | .Rhistory
113 | .Rapp.history
114 |
115 | # Session Data files
116 | .RData
117 |
118 | # Example code in package build process
119 | *-Ex.R
120 |
121 | # Output files from R CMD build
122 | /*.tar.gz
123 |
124 | # Output files from R CMD check
125 | /*.Rcheck/
126 |
127 | # RStudio files
128 | .Rproj.user/
129 |
130 | # produced vignettes
131 | vignettes/*.html
132 | vignettes/*.pdf
133 |
134 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
135 | .httr-oauth
136 |
137 | # knitr and R markdown default cache directories
138 | /*_cache/
139 | /cache/
140 | # R markdown files directories
141 | /*_files/
142 |
143 | # Temporary files created by R markdown
144 | *.utf8.md
145 | *.knit.md
146 | .Rproj.user
147 |
148 | *.nb.html
149 | notebooks/
150 |
151 | exp/*
152 |
--------------------------------------------------------------------------------
/shapeworld/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Mike Wu, Jesse Mu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/shapeworld/README.md:
--------------------------------------------------------------------------------
1 | # LSL - ShapeWorld experiments
2 |
3 | This code is graciously adapted from code written by [Mike Wu](https://www.mikehwu.com/).
4 |
5 | ## Dependencies
6 |
7 | Tested with Python 3.7.4, torch 1.3.0, torchvision 0.4.1, sklearn 0.21.3, and numpy 1.17.2.
8 |
9 | ## Data
10 |
11 | Download the data [here](http://nlp.stanford.edu/data/muj/shapeworld_4k.tar.gz)
12 | (~850 MB). Untar it, and set `DATA_DIR` in `datasets.py` to point to the
13 | folder *containing* the ShapeWorld folder you just extracted.
14 |
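For example, a minimal sketch of the edit (the path below is a placeholder, not a real location):

```python
# In datasets.py: point DATA_DIR at the directory that *contains* the
# unpacked ShapeWorld folder, not at the ShapeWorld folder itself.
DATA_DIR = "/path/to/data"
```
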
15 | This code also works with Jacob Andreas' [original ShapeWorld data
16 | files](http://people.eecs.berkeley.edu/~jda/data/shapeworld.tar.gz) if you replace
17 | every `.npz` extension with `.npy` in `datasets.py` and remove the `['arr_0']` indexing after each `np.load`.
18 | Results are similar, but with higher variance in test accuracy.
19 |
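Concretely, the two loading conventions differ only in the archive indexing; a sketch (file names here are illustrative, not the exact paths used in `datasets.py`):

```python
import numpy as np

# shapeworld_4k release: arrays are compressed .npz archives, so np.load
# returns an archive and the array sits under the default "arr_0" key.
examples = np.load("shapeworld/train/examples.npz")["arr_0"]

# Original release: the same arrays are plain .npy files, so the
# ["arr_0"] indexing is dropped.
# examples = np.load("shapeworld/train/examples.npy")
```
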
20 | For more details on the dataset (and how to reproduce it), check
21 | [jacobandreas/l3](https://github.com/jacobandreas/l3) and the accompanying
22 | [paper](https://arxiv.org/abs/1711.00482).
23 |
24 | ## Running
25 |
26 | The models can be run with the scripts in this directory:
27 |
28 | - `run_l3.sh` - L3
29 | - `run_lsl.sh` - LSL (ours)
30 | - `run_lsl_img.sh` - LSL, but decoding captions from the image embeddings
31 | instead of the concept (not reported)
32 | - `run_meta.sh` - meta-learning baseline
33 | - `run_lang_ablation.sh` - language ablation studies
34 |
35 | They will output results in the `exp/` directory (the paper runs are already present there).
36 |
37 | To change the backbone, use `--backbone conv4` or `--backbone ResNet18`. ResNet18 may need a reduced batch size (we use a batch size of 32).
38 |
39 | ## Analysis
40 |
41 | `analysis/metrics.Rmd` contains `R` code for reproducing the plots in the
42 | paper.
43 |
--------------------------------------------------------------------------------
/shapeworld/analysis/analysis.Rproj:
--------------------------------------------------------------------------------
1 | Version: 1.0
2 |
3 | RestoreWorkspace: Default
4 | SaveWorkspace: Default
5 | AlwaysSaveHistory: Default
6 |
7 | EnableCodeIndexing: Yes
8 | UseSpacesForTab: Yes
9 | NumSpacesForTab: 2
10 | Encoding: UTF-8
11 |
12 | RnwWeave: Sweave
13 | LaTeX: pdfLaTeX
14 |
--------------------------------------------------------------------------------
/shapeworld/analysis/metrics.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Metrics analysis"
3 | output: html_notebook
4 | ---
5 |
6 | ```{r setup}
7 | library(tidyverse)
8 | library(cowplot)
9 | library(jsonlite)
10 | theme_set(theme_cowplot())
11 | ```
12 |
13 | # Standard eval
14 |
15 | ```{r}
16 | MODELS <- c(
17 | 'l3' = 'L3',
18 | 'lsl' = 'LSL',
19 | 'meta' = 'Meta'
20 | )
21 | ```
22 |
23 | ```{r}
24 | metrics <- sapply(names(MODELS), function(model) {
25 | metrics_file <- paste0('../exp/', model, '/metrics.json')
26 | metrics_df <- as.data.frame(read_json(metrics_file, simplifyVector = TRUE)) %>%
27 | tbl_df %>%
28 | select(train_acc, val_acc, val_same_acc, test_acc, test_same_acc) %>%
29 | mutate(avg_val_acc = (val_acc + val_same_acc) / 2,
30 | avg_test_acc = (test_acc + test_same_acc) / 2)
31 | metrics_df %>%
32 | mutate(epoch = 1:nrow(metrics_df)) %>%
33 | mutate(model = MODELS[model])
34 | }, simplify = FALSE) %>%
35 | do.call(rbind, .) %>%
36 | mutate(model = factor(model))
37 |
38 | metrics_long <- metrics %>%
39 | gather('metric', 'value', -epoch, -model) %>%
40 | mutate(metric = factor(metric, levels = c('train_acc', 'avg_val_acc', 'val_acc', 'val_same_acc', 'avg_test_acc', 'test_acc', 'test_same_acc')))
41 | ```
42 |
43 | ```{r fig.width=3.5, fig.height=2}
44 | metric_names <- c('train_acc' = 'Train', 'avg_val_acc' = 'Val', 'avg_test_acc' = 'Test')
45 | ggplot(metrics_long %>% filter(metric %in% c('train_acc', 'avg_val_acc', 'avg_test_acc')) %>% rename(Model = model), aes(x = epoch, y = value, color = Model)) +
46 | geom_line() +
47 | facet_wrap(~ metric, labeller = as_labeller(metric_names)) +
48 | xlab('Epoch') +
49 | ylab('Accuracy')
50 | ```
--------------------------------------------------------------------------------
/shapeworld/exp/README.md:
--------------------------------------------------------------------------------
1 | # Experiment results folder
2 |
--------------------------------------------------------------------------------
/shapeworld/exp/l3/args.json:
--------------------------------------------------------------------------------
1 | {"exp_dir": "exp/vigil/l3", "predict_concept_hyp": false, "predict_image_hyp": false, "infer_hyp": true, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 523, "language_filter": null, "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 1.0, "save_checkpoint": false, "cuda": true, "predict_hyp": false, "use_hyp": true, "encode_hyp": true, "decode_hyp": true}
2 |
--------------------------------------------------------------------------------
/shapeworld/exp/l3/metrics.json:
--------------------------------------------------------------------------------
1 | {"train_acc": [0.5023333333333333, 0.5304444444444445, 0.545, 0.573, 0.662, 0.6244444444444445, 0.6083333333333333, 0.7182222222222222, 0.5692222222222222, 0.7284444444444444, 0.7607777777777778, 0.7855555555555556, 0.7686666666666667, 0.7953333333333333, 0.8346666666666667, 0.8088888888888889, 0.8265555555555556, 0.8394444444444444, 0.8614444444444445, 0.7967777777777778, 0.8381111111111111, 0.862, 0.86, 0.8982222222222223, 0.878, 0.8752222222222222, 0.9072222222222223, 0.8603333333333333, 0.9007777777777778, 0.8971111111111111, 0.9114444444444444, 0.8907777777777778, 0.9074444444444445, 0.9084444444444445, 0.8948888888888888, 0.8741111111111111, 0.8701111111111111, 0.9282222222222222, 0.891, 0.9207777777777778, 0.932, 0.9013333333333333, 0.9011111111111111, 0.9132222222222223, 0.9084444444444445, 0.9397777777777778, 0.9152222222222223, 0.9025555555555556, 0.9347777777777778, 0.916], "val_acc": [0.508, 0.5, 0.52, 0.486, 0.536, 0.524, 0.524, 0.538, 0.53, 0.546, 0.598, 0.568, 0.572, 0.56, 0.62, 0.608, 0.582, 0.588, 0.594, 0.61, 0.618, 0.626, 0.632, 0.58, 0.622, 0.606, 0.598, 0.65, 0.64, 0.624, 0.6, 0.632, 0.62, 0.618, 0.608, 0.63, 0.634, 0.596, 0.63, 0.616, 0.628, 0.614, 0.612, 0.632, 0.602, 0.628, 0.648, 0.652, 0.614, 0.622], "val_same_acc": [0.498, 0.494, 0.52, 0.546, 0.508, 0.556, 0.52, 0.532, 0.532, 0.57, 0.582, 0.582, 0.594, 0.582, 0.602, 0.614, 0.596, 0.566, 0.594, 0.636, 0.644, 0.612, 0.63, 0.654, 0.6, 0.632, 0.616, 0.69, 0.648, 0.66, 0.662, 0.648, 0.638, 0.674, 0.646, 0.69, 0.676, 0.614, 0.664, 0.67, 0.626, 0.674, 0.67, 0.668, 0.66, 0.64, 0.692, 0.694, 0.61, 0.642], "val_tre": [0.25633602127432825, 0.29470103612542153, 0.3173871642053127, 0.32591839861869815, 0.33455874407291414, 0.3444766681492329, 0.3517530900835991, 0.3546512795686722, 0.3629691895842552, 0.37140207839012146, 0.3772867166697979, 0.38848341038823125, 0.39334050261974335, 0.41172417044639587, 0.41499094939231873, 0.42147668239474295, 0.4337600200772285, 0.44199954602122304, 0.45299749591946603, 0.4578404571712017, 0.4647700753211975, 0.471630163282156, 0.4751690610051155, 0.4885770261287689, 0.493298515021801, 0.4948560266792774, 0.4987256095409393, 0.5102960558831692, 0.5090047884583473, 0.510729871481657, 0.5190574961602687, 0.5180712405443192, 0.5254904857873917, 0.5273345997929573, 0.5298865022361279, 0.5335929306447506, 0.5398195767402649, 0.5402879483401776, 0.5459680155813694, 0.5467929720878602, 0.5468201187252998, 0.5507783033549786, 0.5473250425457954, 0.5527532941997051, 0.5607433926463127, 0.5589669682085514, 0.5609841901659965, 0.5578600949645043, 0.5605509672164917, 0.563900175690651], "val_tre_std": [0.0008983922636733366, 0.0012393818670073098, 0.002705310356240471, 0.004119969486292012, 0.004415815432919283, 0.004880650390603528, 0.008861252828759958, 0.009476131311003444, 0.009022034651017733, 0.01143623739051007, 0.017494352103482703, 0.017762730295086442, 0.021317160074330523, 0.023186749583176555, 0.022208067779347494, 0.02261467769505962, 0.021649295420653976, 0.02183235733200545, 0.02308747601341338, 0.024546957098724283, 0.027185148108284336, 0.02425879832978174, 0.025027398060597974, 0.024505227514392356, 0.025782036503912583, 0.02413022162730152, 0.02493159414026219, 0.02113466113404446, 0.02219720340775683, 0.023728708218024107, 0.023085432824290267, 0.022559277009476516, 0.021082738316435375, 0.024827830885288418, 0.022480336133616313, 0.022559989002424452, 0.02072817771662128, 0.01973465625196693, 0.018798697877029417, 0.019722927256330926, 0.023218060008900417, 
0.020582324265838555, 0.02138545180626131, 0.020373384768127337, 0.01968535742858713, 0.021034499364920817, 0.019202968664080812, 0.020055075653409696, 0.019806053228292417, 0.021419522391281927], "test_acc": [0.49675, 0.5085, 0.515, 0.53325, 0.539, 0.52825, 0.541, 0.5665, 0.5155, 0.57875, 0.58325, 0.579, 0.5835, 0.58175, 0.60325, 0.61875, 0.60825, 0.6055, 0.60975, 0.632, 0.638, 0.6335, 0.64175, 0.62575, 0.63975, 0.644, 0.632, 0.65275, 0.6365, 0.6455, 0.63525, 0.64425, 0.63775, 0.646, 0.64475, 0.682, 0.66075, 0.62875, 0.6405, 0.6415, 0.6535, 0.65975, 0.6585, 0.6655, 0.654, 0.62825, 0.64875, 0.67, 0.6345, 0.648], "test_same_acc": [0.4985, 0.51025, 0.51725, 0.5185, 0.5445, 0.54175, 0.53425, 0.5655, 0.522, 0.57925, 0.5795, 0.57975, 0.59125, 0.59425, 0.6055, 0.62575, 0.613, 0.59075, 0.6055, 0.636, 0.63925, 0.63925, 0.64675, 0.618, 0.63275, 0.64875, 0.6225, 0.63025, 0.64375, 0.64175, 0.642, 0.6535, 0.6365, 0.64925, 0.6525, 0.66225, 0.66975, 0.6335, 0.65175, 0.63025, 0.63575, 0.65775, 0.665, 0.66, 0.657, 0.61475, 0.64725, 0.662, 0.6175, 0.63975], "test_acc_ci": [0.010955877061536423, 0.010956264336984343, 0.010956360212081153, 0.010956483478830012, 0.010955374028159649, 0.010956264336984343, 0.010955877061536423, 0.010956675224349538, 0.010955697287113908, 0.0109567245297979, 0.010956683784339127, 0.010955659277286785, 0.010955659277286785, 0.010953376733769135, 0.010954702824443708, 0.010951121805550332, 0.010950020034228245, 0.010948714963267377, 0.01094729390705758, 0.010922705819616081, 0.010919382294228, 0.01092869042417709, 0.010911272425798927, 0.010930996998787437, 0.010903721834624864, 0.010910012236330214, 0.01091202021840479, 0.010886866255809107, 0.010906127598740076, 0.010874212600804713, 0.010906391433323625, 0.010866250967876408, 0.0108775585496011, 0.010902360982259758, 0.010861949694777408, 0.010840536777386764, 0.010819544733808119, 0.010893535577207014, 0.0108775585496011, 0.010881481312303026, 0.010882125315741406, 0.010857547206753235, 0.010835660275076688, 0.010835249299739252, 0.010826880986115069, 0.01087854868407891, 0.010847660668175189, 0.010826880986115069, 0.010896743733221406, 0.010856057197338035], "best_epoch": 48, "best_val_acc": 0.652, "best_val_same_acc": 0.694, "best_val_tre": 0.5578600949645043, "best_val_tre_std": 0.020055075653409696, "best_test_acc": 0.67, "best_test_same_acc": 0.662, "best_test_acc_ci": 0.010826880986115069, "lowest_val_tre": 0.25633602127432825, "lowest_val_tre_std": 0.0008983922636733366, "has_same": true}
--------------------------------------------------------------------------------
/shapeworld/exp/lsl/args.json:
--------------------------------------------------------------------------------
1 | {"exp_dir": "exp/lsl", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 27140, "language_filter": null, "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true}
2 |
--------------------------------------------------------------------------------
/shapeworld/exp/lsl/metrics.json:
--------------------------------------------------------------------------------
1 | {"train_acc": [0.5018888888888889, 0.5018888888888889, 0.5592222222222222, 0.6594444444444445, 0.7117777777777777, 0.731, 0.7423333333333333, 0.7373333333333333, 0.7316666666666667, 0.7376666666666667, 0.7403333333333333, 0.7303333333333333, 0.7413333333333333, 0.7401111111111112, 0.7433333333333333, 0.7292222222222222, 0.7354444444444445, 0.7391111111111112, 0.7423333333333333, 0.7348888888888889, 0.7391111111111112, 0.742, 0.745, 0.7401111111111112, 0.7373333333333333, 0.7472222222222222, 0.7374444444444445, 0.714, 0.7431111111111111, 0.742, 0.7372222222222222, 0.7266666666666667, 0.7314444444444445, 0.7421111111111112, 0.7452222222222222, 0.7514444444444445, 0.756, 0.7337777777777778, 0.742, 0.7391111111111112, 0.7505555555555555, 0.7317777777777777, 0.7293333333333333, 0.7471111111111111, 0.7375555555555555, 0.7288888888888889, 0.7352222222222222, 0.7392222222222222, 0.7356666666666667, 0.7437777777777778], "val_acc": [0.508, 0.508, 0.556, 0.6, 0.598, 0.612, 0.608, 0.614, 0.618, 0.618, 0.636, 0.65, 0.656, 0.636, 0.656, 0.66, 0.668, 0.664, 0.654, 0.646, 0.656, 0.654, 0.66, 0.642, 0.65, 0.636, 0.644, 0.652, 0.652, 0.662, 0.654, 0.638, 0.642, 0.656, 0.646, 0.666, 0.65, 0.666, 0.676, 0.67, 0.664, 0.668, 0.65, 0.668, 0.662, 0.656, 0.658, 0.65, 0.664, 0.652], "val_same_acc": [0.496, 0.496, 0.528, 0.558, 0.59, 0.592, 0.61, 0.628, 0.638, 0.624, 0.644, 0.656, 0.666, 0.664, 0.676, 0.656, 0.658, 0.662, 0.654, 0.64, 0.676, 0.682, 0.674, 0.674, 0.678, 0.686, 0.684, 0.674, 0.664, 0.684, 0.694, 0.668, 0.696, 0.71, 0.698, 0.7, 0.69, 0.678, 0.68, 0.694, 0.696, 0.706, 0.702, 0.698, 0.656, 0.692, 0.682, 0.676, 0.68, 0.692], "val_tre": [0.23315593773126603, 0.3193679120838642, 0.4544656649827957, 0.6009624934494495, 0.6937140434384346, 0.7510887870192527, 0.7626189323961735, 0.7726548528671264, 0.7821315121650696, 0.7757832805812359, 0.777156808435917, 0.7755475530028343, 0.803969556093216, 0.7948731675744056, 0.8089491415023804, 0.8003908374905586, 0.79257049202919, 0.7978935792148113, 0.7970038573145867, 0.7915700083374977, 0.8110601940453053, 0.8116079128980637, 0.8045996649861336, 0.8138662700951099, 0.8021689679026603, 0.8072018190920353, 0.8115154400467872, 0.7561183496713638, 0.7906409577429294, 0.7920836460888385, 0.7970328702330589, 0.7686191479265689, 0.7872865536510945, 0.8017647383213043, 0.7931095496416092, 0.7974768998026848, 0.8039976232647896, 0.7989465886652469, 0.7998599541783333, 0.7994992446899414, 0.7942561790943146, 0.7959962756037712, 0.7920953050553798, 0.799251141756773, 0.7694647860527039, 0.7820107011795044, 0.7913144878745079, 0.7904651002883911, 0.7779531913697719, 0.7935415287613868], "val_tre_std": [0.015967926829254738, 0.041638463279630916, 0.10019827073131268, 0.14631836833822712, 0.2240175849950952, 0.2284532373612079, 0.19094390253159696, 0.16651079912256633, 0.1366198991062094, 0.1531613575168732, 0.15832926271828068, 0.1355022313367676, 0.1371991570253776, 0.15200430030145073, 0.1289743287693288, 0.11889629150246303, 0.1299349850688645, 0.1439676228356837, 0.15671544266725537, 0.13916557490241635, 0.1254910527876445, 0.12804438160833864, 0.1348672604639425, 0.12044834017050647, 0.12676008889430834, 0.11461523484218875, 0.12100130478055854, 0.16064779698548673, 0.15085657104567965, 0.12438815929620986, 0.12298713167116866, 0.1302325547150731, 0.13835974209656793, 0.13045702495739636, 0.12149457598918338, 0.12178719797982543, 0.10589989430645402, 0.10695060304222803, 0.11247619327198298, 0.12526679163881363, 0.11674564159243002, 0.12351190996349973, 
0.13275382956110487, 0.11753092526746389, 0.12913086971281487, 0.13101288147859969, 0.11308533867954541, 0.1238697781895269, 0.12148779349419715, 0.12295913230178925], "test_acc": [0.496, 0.496, 0.532, 0.558, 0.57675, 0.601, 0.61175, 0.6135, 0.62775, 0.61775, 0.635, 0.6445, 0.65125, 0.646, 0.65825, 0.6525, 0.65325, 0.663, 0.65875, 0.65725, 0.66775, 0.67025, 0.66625, 0.664, 0.654, 0.67625, 0.6745, 0.66525, 0.66375, 0.6715, 0.66725, 0.652, 0.661, 0.68025, 0.66775, 0.67025, 0.67625, 0.6735, 0.6715, 0.66925, 0.6675, 0.672, 0.6785, 0.68, 0.6535, 0.6715, 0.67175, 0.667, 0.6665, 0.67025], "test_same_acc": [0.49675, 0.49675, 0.53075, 0.56775, 0.59, 0.61025, 0.61975, 0.6265, 0.628, 0.632, 0.65075, 0.64925, 0.65575, 0.654, 0.67175, 0.66725, 0.66675, 0.6685, 0.676, 0.66425, 0.6755, 0.68, 0.67975, 0.6695, 0.6645, 0.6785, 0.67675, 0.66475, 0.66725, 0.67275, 0.6695, 0.654, 0.66725, 0.6735, 0.66875, 0.66725, 0.67725, 0.6755, 0.67725, 0.6745, 0.67025, 0.672, 0.6735, 0.67425, 0.65825, 0.66575, 0.673, 0.67475, 0.667, 0.6755], "test_acc_ci": [0.010956445129323426, 0.010956445129323426, 0.01093514040247655, 0.010869758131939963, 0.010803330145000428, 0.010709462226081894, 0.01065909238103789, 0.010636499424152667, 0.010592344504257542, 0.010609518073262093, 0.01049988332539343, 0.010473347031721758, 0.010427627129409645, 0.010452057213773755, 0.010342947113854927, 0.010381525923334921, 0.010380601138662442, 0.010337187015213566, 0.010324606027441192, 0.010375035512589824, 0.010291046665518283, 0.01026269256807759, 0.010279985126448383, 0.01032946120993249, 0.010386137775299342, 0.010244120270763858, 0.010258588775844122, 0.010342947113854929, 0.010339110307468433, 0.010287035875162245, 0.010316795164772585, 0.010431155170929057, 0.010349630191309977, 0.010248270856435975, 0.010317774391190184, 0.010313852562815702, 0.010249306410069903, 0.010267803560158325, 0.010268823261302872, 0.010289042924241059, 0.010312870053815038, 0.010288039813297768, 0.010255502172005037, 0.010246197238335546, 0.01041069513033952, 0.010314834250819303, 0.010285025516345352, 0.01029703806183482, 0.01032946120993249, 0.010280994861727876], "best_epoch": 50, "best_val_acc": 0.652, "best_val_same_acc": 0.692, "best_val_tre": 0.7935415287613868, "best_val_tre_std": 0.12295913230178925, "best_test_acc": 0.67025, "best_test_same_acc": 0.6755, "best_test_acc_ci": 0.010280994861727876, "lowest_val_tre": 0.23315593773126603, "lowest_val_tre_std": 0.015967926829254738, "has_same": true}
--------------------------------------------------------------------------------
/shapeworld/exp/lsl_color/args.json:
--------------------------------------------------------------------------------
1 | {"exp_dir": "exp/lsl_color", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 19626, "language_filter": "color", "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true}
2 |
--------------------------------------------------------------------------------
/shapeworld/exp/lsl_color/metrics.json:
--------------------------------------------------------------------------------
1 | {"train_acc": [0.5018888888888889, 0.583, 0.7208888888888889, 0.7158888888888889, 0.7237777777777777, 0.7438888888888889, 0.7416666666666667, 0.7338888888888889, 0.7194444444444444, 0.7316666666666667, 0.7446666666666667, 0.7274444444444444, 0.7347777777777778, 0.7404444444444445, 0.7362222222222222, 0.7321111111111112, 0.7328888888888889, 0.7116666666666667, 0.7422222222222222, 0.7313333333333333, 0.7286666666666667, 0.7286666666666667, 0.737, 0.7194444444444444, 0.7313333333333333, 0.7153333333333334, 0.7331111111111112, 0.731, 0.7397777777777778, 0.7474444444444445, 0.7346666666666667, 0.6967777777777778, 0.725, 0.7418888888888889, 0.739, 0.7433333333333333, 0.7352222222222222, 0.7178888888888889, 0.7315555555555555, 0.7408888888888889, 0.7422222222222222, 0.7294444444444445, 0.734, 0.7356666666666667, 0.733, 0.7255555555555555, 0.7358888888888889, 0.7362222222222222, 0.7361111111111112, 0.7376666666666667], "val_acc": [0.508, 0.546, 0.602, 0.62, 0.608, 0.602, 0.596, 0.584, 0.606, 0.604, 0.61, 0.606, 0.616, 0.632, 0.592, 0.624, 0.59, 0.616, 0.614, 0.598, 0.624, 0.612, 0.628, 0.644, 0.632, 0.604, 0.636, 0.616, 0.638, 0.644, 0.63, 0.59, 0.616, 0.626, 0.63, 0.634, 0.648, 0.614, 0.614, 0.632, 0.62, 0.62, 0.622, 0.622, 0.622, 0.628, 0.626, 0.622, 0.634, 0.606], "val_same_acc": [0.496, 0.564, 0.588, 0.602, 0.638, 0.626, 0.624, 0.636, 0.636, 0.634, 0.646, 0.642, 0.648, 0.648, 0.656, 0.642, 0.652, 0.626, 0.638, 0.636, 0.652, 0.66, 0.664, 0.676, 0.638, 0.628, 0.666, 0.65, 0.668, 0.662, 0.66, 0.624, 0.65, 0.668, 0.658, 0.666, 0.654, 0.642, 0.648, 0.68, 0.66, 0.654, 0.666, 0.652, 0.642, 0.662, 0.658, 0.646, 0.652, 0.64], "val_tre": [0.2928761223256588, 0.4909118238091469, 0.7093970262408257, 0.715656586676836, 0.6946930028498173, 0.7530982179045678, 0.7181237662732601, 0.7177722745537758, 0.7063919916450977, 0.7170701187551022, 0.7364676860868931, 0.7208060621917248, 0.707125977486372, 0.7182179855704307, 0.7051682096123696, 0.7211604817807674, 0.7094560405910015, 0.6675868063867092, 0.7088536221086978, 0.691121347218752, 0.7120009001791477, 0.6966433502137661, 0.705272063344717, 0.6911544860005379, 0.6795149468779564, 0.6730550227761268, 0.7166172302365303, 0.7124211880266667, 0.7084594193398952, 0.7274633083939552, 0.7160733468532562, 0.6490570981502533, 0.6786254914104939, 0.7213582392036915, 0.7089775923788547, 0.7222011366784573, 0.7164143627583981, 0.6961094659864903, 0.7130949632823467, 0.7321216926574707, 0.7255135977864265, 0.7286434670984745, 0.714307487398386, 0.7320686898827553, 0.7254576664865017, 0.7269528369903564, 0.7190864905118942, 0.7051229429543018, 0.7255788600146771, 0.7198136151731014], "val_tre_std": [0.028786290772653104, 0.15274888657758995, 0.29936371313875904, 0.3729337300219388, 0.2503601848631358, 0.34096694884861295, 0.29398723347551503, 0.28951914089576947, 0.29075415453489756, 0.2656755109916801, 0.2658799580899334, 0.24215173114625418, 0.22278727569008222, 0.2368307482422677, 0.23235498110132288, 0.2072491654701949, 0.2702409115734452, 0.2236184128671816, 0.22949241611567125, 0.27024222622417127, 0.2521140637309202, 0.22977934149599544, 0.2523585546825122, 0.2574569513302243, 0.225301321506544, 0.24954425982964126, 0.21412330289882422, 0.17697085712411145, 0.20319146035285715, 0.14752056211130987, 0.22221709431604936, 0.23647851952807025, 0.1968089448080926, 0.2107266092599971, 0.1918881409223694, 0.1775148583175112, 0.18596143152343486, 0.22458536367131698, 0.2174860410072009, 0.1585495236993038, 0.17890690690619446, 0.19183306641129158, 0.23175687856064675, 
0.1876181925216055, 0.21742256350551445, 0.17311539572774087, 0.20677547101924948, 0.20651109041952515, 0.1764090059441681, 0.2239317514391409], "test_acc": [0.496, 0.55175, 0.59975, 0.60625, 0.6155, 0.62075, 0.621, 0.616, 0.63225, 0.624, 0.6355, 0.6345, 0.63875, 0.631, 0.6295, 0.6435, 0.63125, 0.61675, 0.63425, 0.6255, 0.63925, 0.63825, 0.63875, 0.6365, 0.62925, 0.61675, 0.637, 0.64575, 0.63775, 0.642, 0.63775, 0.60725, 0.621, 0.64625, 0.6385, 0.645, 0.63325, 0.62075, 0.62625, 0.642, 0.6445, 0.6385, 0.628, 0.635, 0.63825, 0.64125, 0.6415, 0.63175, 0.6405, 0.6355], "test_same_acc": [0.49675, 0.545, 0.6035, 0.603, 0.62375, 0.628, 0.62875, 0.634, 0.63175, 0.63975, 0.64275, 0.64375, 0.64825, 0.64525, 0.64525, 0.6505, 0.6435, 0.63, 0.6425, 0.63475, 0.63825, 0.635, 0.64575, 0.6405, 0.6345, 0.63675, 0.64825, 0.6515, 0.64825, 0.65975, 0.6395, 0.61725, 0.6275, 0.656, 0.648, 0.6565, 0.6465, 0.62625, 0.64025, 0.65375, 0.64675, 0.6485, 0.6375, 0.64575, 0.63975, 0.64125, 0.653, 0.64925, 0.644, 0.6445], "test_acc_ci": [0.010956445129323426, 0.01090533192855105, 0.010728031831229574, 0.01071417487359036, 0.010638527645866931, 0.010612338039841879, 0.010609518073262093, 0.010608811196359372, 0.010568017562438093, 0.010568766924143751, 0.0105240373121191, 0.0105240373121191, 0.010495789705877305, 0.010530360715420671, 0.010535070864587241, 0.010472504867509013, 0.010535070864587243, 0.010617941807237173, 0.010528784496411491, 0.010579177945231616, 0.010526414379431392, 0.010539753280882573, 0.01050395749170283, 0.010527995229387216, 0.010568766924143751, 0.010598835638290652, 0.010501515325031668, 0.010461485297120815, 0.010499066158473331, 0.010446007842562392, 0.010527205190451784, 0.010677052495305058, 0.010613041146980444, 0.01044427234166531, 0.010497429489522661, 0.010446874407568035, 0.010519262307637119, 0.010617243971483373, 0.01056048194390294, 0.010466587748491624, 0.01048172558178638, 0.010495789705877305, 0.010563505359846228, 0.010516063496949559, 0.010524830440439406, 0.010510435751551883, 0.01047081818615432, 0.010515261858365679, 0.01050395749170283, 0.01051846376615901], "best_epoch": 49, "best_val_acc": 0.634, "best_val_same_acc": 0.652, "best_val_tre": 0.7255788600146771, "best_val_tre_std": 0.1764090059441681, "best_test_acc": 0.6405, "best_test_same_acc": 0.644, "best_test_acc_ci": 0.01050395749170283, "lowest_val_tre": 0.2928761223256588, "lowest_val_tre_std": 0.028786290772653104, "has_same": true}
--------------------------------------------------------------------------------
/shapeworld/exp/lsl_nocolor/args.json:
--------------------------------------------------------------------------------
1 | {"exp_dir": "exp/lsl_nocolor", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 12126, "language_filter": "nocolor", "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true}
2 |
--------------------------------------------------------------------------------
/shapeworld/exp/lsl_nocolor/metrics.json:
--------------------------------------------------------------------------------
1 | {"train_acc": [0.5018888888888889, 0.5067777777777778, 0.6631111111111111, 0.7022222222222222, 0.7382222222222222, 0.7432222222222222, 0.7513333333333333, 0.7447777777777778, 0.7344444444444445, 0.7235555555555555, 0.7504444444444445, 0.7398888888888889, 0.7394444444444445, 0.7472222222222222, 0.7412222222222222, 0.7385555555555555, 0.7404444444444445, 0.7384444444444445, 0.7408888888888889, 0.7381111111111112, 0.7401111111111112, 0.7348888888888889, 0.7365555555555555, 0.7384444444444445, 0.7391111111111112, 0.7394444444444445, 0.7406666666666667, 0.6986666666666667, 0.7175555555555555, 0.7244444444444444, 0.7346666666666667, 0.7128888888888889, 0.7334444444444445, 0.7233333333333334, 0.7347777777777778, 0.755, 0.7448888888888889, 0.7327777777777778, 0.735, 0.7406666666666667, 0.7381111111111112, 0.73, 0.723, 0.7344444444444445, 0.7138888888888889, 0.7175555555555555, 0.7363333333333333, 0.7253333333333334, 0.732, 0.7188888888888889], "val_acc": [0.508, 0.514, 0.574, 0.592, 0.616, 0.622, 0.646, 0.632, 0.632, 0.626, 0.656, 0.644, 0.646, 0.636, 0.634, 0.65, 0.666, 0.63, 0.642, 0.654, 0.638, 0.638, 0.63, 0.656, 0.624, 0.646, 0.664, 0.628, 0.634, 0.666, 0.652, 0.652, 0.664, 0.658, 0.656, 0.658, 0.668, 0.664, 0.662, 0.658, 0.654, 0.654, 0.676, 0.668, 0.66, 0.664, 0.676, 0.67, 0.676, 0.662], "val_same_acc": [0.496, 0.496, 0.54, 0.568, 0.602, 0.602, 0.628, 0.63, 0.61, 0.616, 0.614, 0.642, 0.648, 0.632, 0.648, 0.638, 0.654, 0.65, 0.64, 0.652, 0.648, 0.658, 0.646, 0.654, 0.66, 0.65, 0.658, 0.62, 0.614, 0.642, 0.648, 0.634, 0.676, 0.648, 0.65, 0.67, 0.636, 0.652, 0.668, 0.666, 0.648, 0.65, 0.646, 0.632, 0.642, 0.656, 0.662, 0.644, 0.66, 0.65], "val_tre": [0.22808635076880454, 0.3463576367199421, 0.5833195138275623, 0.6936306474804879, 0.7367003444433212, 0.7617647814750671, 0.7925184162259102, 0.7806620350778103, 0.7809249736964703, 0.7618764308989048, 0.818745451271534, 0.7906182879209518, 0.8013895196616649, 0.8116443670988083, 0.7867026439905167, 0.7947599971592426, 0.8079485739171505, 0.8126739595234395, 0.8161718902587891, 0.795997316300869, 0.7976047129631042, 0.7890185303986073, 0.7939921219050884, 0.7666756071150302, 0.7984343985021114, 0.7855962689816952, 0.801971653997898, 0.7558098890483379, 0.7291252992153168, 0.7725845266282558, 0.7714625637233258, 0.726695333570242, 0.7933139767348766, 0.778794666916132, 0.8004181494414806, 0.7938655286729336, 0.7891755924224854, 0.7856353769600392, 0.7903081247210503, 0.7963232119977475, 0.7696316801607609, 0.7731537413299083, 0.7955225827097893, 0.7584440202414989, 0.7596146790981293, 0.7835552912950515, 0.8013896183669567, 0.7690100253224373, 0.7866484541594982, 0.7879996562898159], "val_tre_std": [0.01763054536880929, 0.07339547584061559, 0.14941119368980174, 0.17516040573277214, 0.2652266204437812, 0.2193908594978303, 0.16969182617088535, 0.19375694303927413, 0.17319022049681249, 0.20916169809256696, 0.14892911600273737, 0.1340359423374055, 0.16795556135270545, 0.14105599331769797, 0.16546023395342999, 0.1476439768847445, 0.16017509978664035, 0.18668384015479503, 0.1601517815583093, 0.13709783292609365, 0.15121442764377227, 0.15114236661248562, 0.17220838823519422, 0.15274215369714983, 0.11693935023993977, 0.14640735146688152, 0.12193917723328035, 0.1359785426625704, 0.16940021888863163, 0.15788490043491232, 0.15381887433112926, 0.16534711985816136, 0.11927814837344929, 0.13760702346949716, 0.13427542712608062, 0.1268173391039503, 0.1446572349584176, 0.13625064891802438, 0.14825136077253176, 0.11018452950263992, 0.1267134292548735, 
0.14440096747254103, 0.13868456348510233, 0.14252962444754882, 0.15683233923197384, 0.14174051286031775, 0.12008535457174874, 0.1366931669427143, 0.11642106634144173, 0.14012168432103977], "test_acc": [0.496, 0.5, 0.55975, 0.582, 0.59375, 0.603, 0.61475, 0.6135, 0.61525, 0.623, 0.63175, 0.63575, 0.63725, 0.642, 0.6355, 0.64925, 0.64175, 0.64325, 0.64125, 0.6435, 0.64475, 0.645, 0.64175, 0.641, 0.6405, 0.6515, 0.65, 0.63525, 0.6285, 0.64125, 0.64475, 0.62175, 0.64925, 0.6335, 0.642, 0.6515, 0.641, 0.648, 0.64825, 0.64875, 0.63925, 0.63875, 0.64875, 0.64, 0.63025, 0.641, 0.6475, 0.6325, 0.64525, 0.64975], "test_same_acc": [0.49675, 0.50075, 0.574, 0.59875, 0.60675, 0.61575, 0.62725, 0.62475, 0.62325, 0.63475, 0.62975, 0.629, 0.63525, 0.63325, 0.63475, 0.6385, 0.64025, 0.635, 0.644, 0.642, 0.639, 0.636, 0.636, 0.636, 0.62875, 0.63775, 0.639, 0.6275, 0.61925, 0.63, 0.637, 0.61675, 0.64225, 0.6265, 0.63475, 0.645, 0.641, 0.63475, 0.64175, 0.64525, 0.6385, 0.64475, 0.64475, 0.63, 0.62125, 0.63925, 0.64475, 0.63075, 0.64125, 0.6455], "test_acc_ci": [0.010956445129323426, 0.010956730008167355, 0.010858287988761166, 0.010776265539224384, 0.01073424240398455, 0.010691371283510595, 0.010631057887153093, 0.010641221468744789, 0.010640549134678153, 0.010586523071664038, 0.010575476863361764, 0.010565764897624544, 0.01054208410076015, 0.010533503898127867, 0.010549035062358783, 0.010493324187161809, 0.010512047555067468, 0.0105240373121191, 0.010501515325031666, 0.010500699714185716, 0.010506392663368098, 0.010515261858365677, 0.010525622796152014, 0.010527995229387216, 0.01055210442148271, 0.010488372095891478, 0.010489199395092077, 0.010571756742465981, 0.010615145947978057, 0.010545953426403656, 0.01051285229359164, 0.010640549134678153, 0.010480891249674332, 0.010579915878682589, 0.010528784496411491, 0.01046404006048811, 0.010512047555067468, 0.010509628686203663, 0.010485885513393706, 0.010472504867509013, 0.010525622796152014, 0.01050720283365178, 0.010474188411877075, 0.010549803552673384, 0.01060455408715991, 0.010517664450669406, 0.010478383558396543, 0.01057026335868104, 0.010497429489522661, 0.010468282278954603], "best_epoch": 47, "best_val_acc": 0.676, "best_val_same_acc": 0.662, "best_val_tre": 0.8013896183669567, "best_val_tre_std": 0.12008535457174874, "best_test_acc": 0.6475, "best_test_same_acc": 0.64475, "best_test_acc_ci": 0.010478383558396543, "lowest_val_tre": 0.22808635076880454, "lowest_val_tre_std": 0.01763054536880929, "has_same": true}
--------------------------------------------------------------------------------
/shapeworld/exp/lsl_shuffle_captions/args.json:
--------------------------------------------------------------------------------
1 | {"exp_dir": "exp/lsl_shuffle_captions", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 5974, "language_filter": null, "shuffle_words": false, "shuffle_captions": true, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true}
2 |
--------------------------------------------------------------------------------
/shapeworld/exp/lsl_shuffle_captions/metrics.json:
--------------------------------------------------------------------------------
1 | {"train_acc": [0.5071111111111111, 0.501, 0.5021111111111111, 0.5645555555555556, 0.6984444444444444, 0.736, 0.7335555555555555, 0.7392222222222222, 0.7366666666666667, 0.7367777777777778, 0.7414444444444445, 0.7402222222222222, 0.7327777777777778, 0.7438888888888889, 0.7367777777777778, 0.7236666666666667, 0.7268888888888889, 0.7038888888888889, 0.7293333333333333, 0.7333333333333333, 0.7113333333333334, 0.6705555555555556, 0.7296666666666667, 0.7068888888888889, 0.7284444444444444, 0.6974444444444444, 0.7138888888888889, 0.7088888888888889, 0.7338888888888889, 0.696, 0.7006666666666667, 0.7127777777777777, 0.665, 0.6768888888888889, 0.726, 0.729, 0.7235555555555555, 0.6732222222222223, 0.7087777777777777, 0.6856666666666666, 0.7183333333333334, 0.6925555555555556, 0.7375555555555555, 0.7103333333333334, 0.7212222222222222, 0.7194444444444444, 0.7066666666666667, 0.7207777777777777, 0.6978888888888889, 0.697], "val_acc": [0.508, 0.508, 0.508, 0.54, 0.566, 0.554, 0.542, 0.556, 0.56, 0.578, 0.56, 0.546, 0.588, 0.564, 0.572, 0.58, 0.56, 0.572, 0.582, 0.592, 0.564, 0.55, 0.566, 0.562, 0.566, 0.568, 0.556, 0.566, 0.564, 0.54, 0.526, 0.562, 0.538, 0.566, 0.536, 0.55, 0.544, 0.522, 0.562, 0.538, 0.564, 0.542, 0.552, 0.55, 0.548, 0.55, 0.544, 0.546, 0.532, 0.55], "val_same_acc": [0.496, 0.496, 0.496, 0.526, 0.532, 0.544, 0.546, 0.534, 0.544, 0.552, 0.554, 0.552, 0.56, 0.566, 0.562, 0.568, 0.572, 0.58, 0.58, 0.55, 0.57, 0.554, 0.56, 0.566, 0.554, 0.566, 0.542, 0.522, 0.552, 0.514, 0.544, 0.552, 0.512, 0.54, 0.584, 0.582, 0.562, 0.534, 0.554, 0.554, 0.546, 0.538, 0.57, 0.55, 0.576, 0.564, 0.554, 0.574, 0.546, 0.564], "val_tre": [0.22047525447607041, 0.2645132866203785, 0.36374358320236205, 0.5220565661787987, 0.7167646891772746, 0.8003292497396469, 0.7903731316030026, 0.8116375431418419, 0.7994353666901588, 0.7786327120959758, 0.8107034644186497, 0.813195310741663, 0.7700015254616738, 0.8113658610582352, 0.7617959608137608, 0.7662459227144718, 0.7579508483409881, 0.75589797565341, 0.7805133513510227, 0.7729206332266331, 0.7705274262428283, 0.707739619165659, 0.7790121876001358, 0.7549196770191192, 0.7852870355248451, 0.7266519095599652, 0.7746121415793896, 0.7715683530569076, 0.7978267765641213, 0.7414474821686745, 0.7512542349100113, 0.7495681802928448, 0.6815140551030636, 0.7298193091452122, 0.793954403668642, 0.8172262015938759, 0.7798480809032917, 0.7092128167152405, 0.7335295847356319, 0.7364193704426288, 0.786985212892294, 0.7563742088973522, 0.8047082170248031, 0.7670680065453053, 0.7921611217856407, 0.8007868621647358, 0.768711370229721, 0.7819871633350849, 0.7565498836040497, 0.7600137263834477], "val_tre_std": [0.014784197950176159, 0.02652211147234335, 0.04997577132833228, 0.09928714512269435, 0.14315726153916938, 0.17905031274741376, 0.3028480897883202, 0.24420796565707084, 0.2854344250567698, 0.20748745243769307, 0.16452433254581292, 0.19958105132499038, 0.19181740719145532, 0.18428430107650812, 0.2091974202371749, 0.2142233483216199, 0.16718732490884747, 0.17163516532810213, 0.1424397817852027, 0.13710612071647124, 0.19242126031642523, 0.16510685661208033, 0.16180709629451084, 0.14522477034842432, 0.13575573003599034, 0.20591636374702857, 0.17189024005737322, 0.2046076913501664, 0.11446174753005237, 0.19160172738826214, 0.1564791769508523, 0.22557180313439085, 0.18962065680253348, 0.2226163961731732, 0.14819683260006034, 0.12277138594077944, 0.18841027141919195, 0.21871579157749851, 0.21611514095686526, 0.2226257086464708, 0.154523188492496, 0.18645111716967322, 
0.12550173796621417, 0.17387476825860076, 0.12255660439773507, 0.18413762094468844, 0.14161339259122727, 0.1449093187359931, 0.16743659468361968, 0.16200972926336132], "test_acc": [0.496, 0.496, 0.496, 0.5205, 0.54575, 0.55475, 0.5495, 0.5565, 0.55175, 0.55725, 0.55975, 0.56525, 0.5665, 0.57075, 0.56825, 0.571, 0.5675, 0.57, 0.5795, 0.5745, 0.56625, 0.536, 0.5745, 0.5555, 0.5655, 0.559, 0.564, 0.54425, 0.562, 0.55275, 0.54975, 0.55675, 0.53125, 0.532, 0.56625, 0.569, 0.55375, 0.5315, 0.552, 0.5415, 0.56625, 0.547, 0.56625, 0.564, 0.5645, 0.562, 0.559, 0.5555, 0.552, 0.56075], "test_same_acc": [0.49675, 0.49675, 0.49675, 0.5275, 0.55825, 0.555, 0.5605, 0.55775, 0.56325, 0.563, 0.57075, 0.57375, 0.57425, 0.577, 0.57625, 0.58325, 0.57325, 0.57275, 0.57775, 0.5835, 0.581, 0.54825, 0.577, 0.564, 0.57325, 0.552, 0.575, 0.56125, 0.569, 0.55025, 0.55575, 0.55525, 0.5335, 0.542, 0.56825, 0.571, 0.57475, 0.54475, 0.5535, 0.55675, 0.56825, 0.557, 0.57425, 0.55975, 0.56625, 0.574, 0.5645, 0.56825, 0.566, 0.566], "test_acc_ci": [0.010956445129323426, 0.010956445129323426, 0.010956445129323426, 0.010944103654479887, 0.01089731798196235, 0.01089054591133406, 0.010890243110234041, 0.01088498873894112, 0.010884040552570539, 0.010877227105143801, 0.01086303449720657, 0.010850369300166697, 0.010847660668175189, 0.010836480101807737, 0.010841739988927054, 0.010825601269531176, 0.01084766066817519, 0.010844522747538268, 0.010820418481827538, 0.010819106793076773, 0.010837297097149038, 0.010917778052189695, 0.010830262341582499, 0.01087821933900489, 0.010850753426461914, 0.01088902493109461, 0.010850369300166697, 0.010895586881279043, 0.01086231199837309, 0.010898458127184778, 0.010895586881279043, 0.010887795589558063, 0.01093374057090596, 0.010926692372351296, 0.010857175760182755, 0.010848825742908767, 0.010865896391347562, 0.010924835092433891, 0.010895586881279043, 0.010903721834624864, 0.010857175760182755, 0.010897317981962348, 0.010848049731979476, 0.01087251336154042, 0.010862673599159416, 0.010854932298268838, 0.010872854610795638, 0.01087251336154042, 0.010880184915708004, 0.010868363686492783], "best_epoch": 13, "best_val_acc": 0.588, "best_val_same_acc": 0.56, "best_val_tre": 0.7700015254616738, "best_val_tre_std": 0.19181740719145532, "best_test_acc": 0.5665, "best_test_same_acc": 0.57425, "best_test_acc_ci": 0.010847660668175189, "lowest_val_tre": 0.22047525447607041, "lowest_val_tre_std": 0.014784197950176159, "has_same": true}
--------------------------------------------------------------------------------
/shapeworld/exp/lsl_shuffle_words/args.json:
--------------------------------------------------------------------------------
1 | {"exp_dir": "exp/lsl_shuffle_words", "predict_concept_hyp": true, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 27075, "language_filter": null, "shuffle_words": true, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 20.0, "save_checkpoint": false, "cuda": true, "predict_hyp": true, "use_hyp": true, "encode_hyp": false, "decode_hyp": true}
2 |
--------------------------------------------------------------------------------
/shapeworld/exp/lsl_shuffle_words/metrics.json:
--------------------------------------------------------------------------------
1 | {"train_acc": [0.5034444444444445, 0.5135555555555555, 0.6038888888888889, 0.6691111111111111, 0.6986666666666667, 0.703, 0.7335555555555555, 0.737, 0.7327777777777778, 0.7257777777777777, 0.7482222222222222, 0.7406666666666667, 0.7463333333333333, 0.7477777777777778, 0.7313333333333333, 0.747, 0.7385555555555555, 0.7303333333333333, 0.7364444444444445, 0.7438888888888889, 0.7351111111111112, 0.7403333333333333, 0.7396666666666667, 0.7406666666666667, 0.7463333333333333, 0.7496666666666667, 0.7395555555555555, 0.74, 0.7352222222222222, 0.7293333333333333, 0.7483333333333333, 0.7407777777777778, 0.7425555555555555, 0.7446666666666667, 0.7411111111111112, 0.758, 0.7344444444444445, 0.7464444444444445, 0.7391111111111112, 0.7352222222222222, 0.7374444444444445, 0.737, 0.744, 0.743, 0.7445555555555555, 0.7363333333333333, 0.7406666666666667, 0.7398888888888889, 0.7377777777777778, 0.7367777777777778], "val_acc": [0.508, 0.52, 0.568, 0.61, 0.62, 0.62, 0.63, 0.64, 0.644, 0.652, 0.656, 0.648, 0.648, 0.646, 0.654, 0.648, 0.66, 0.65, 0.634, 0.656, 0.648, 0.654, 0.66, 0.668, 0.66, 0.658, 0.648, 0.652, 0.652, 0.638, 0.658, 0.652, 0.66, 0.656, 0.656, 0.664, 0.66, 0.66, 0.646, 0.652, 0.654, 0.65, 0.656, 0.656, 0.662, 0.652, 0.654, 0.658, 0.668, 0.65], "val_same_acc": [0.496, 0.5, 0.542, 0.562, 0.566, 0.572, 0.61, 0.6, 0.62, 0.62, 0.62, 0.622, 0.654, 0.638, 0.646, 0.654, 0.646, 0.666, 0.66, 0.656, 0.65, 0.652, 0.666, 0.674, 0.674, 0.668, 0.67, 0.678, 0.676, 0.652, 0.676, 0.684, 0.68, 0.678, 0.67, 0.686, 0.68, 0.682, 0.678, 0.674, 0.672, 0.68, 0.682, 0.678, 0.664, 0.664, 0.668, 0.682, 0.688, 0.666], "val_tre": [0.27602067959308624, 0.40746447333693503, 0.5171051873266697, 0.6125595195889473, 0.6366198836266994, 0.6640413806140423, 0.7124214672148228, 0.7231019198596478, 0.7267296032011509, 0.7232981150150299, 0.736225801974535, 0.7374791561663151, 0.7319063286185264, 0.7416931557953358, 0.7135514880418777, 0.7375768918097019, 0.7422215968072414, 0.7392168205976486, 0.7333029879629612, 0.7477176522910595, 0.7457649510502815, 0.7435495406389236, 0.7467366912662983, 0.7510937738120556, 0.7458743584156037, 0.7536228042840958, 0.7528315424025058, 0.7484392536580563, 0.7481056715548039, 0.7274645104706288, 0.7511067868769169, 0.7526578099429607, 0.7544187552034854, 0.7587365226745606, 0.7553361043334007, 0.7578833720088005, 0.7530133697390556, 0.7546444187760353, 0.7501327090263367, 0.7505363914966583, 0.7546211395263672, 0.7538118029236793, 0.7556851161122322, 0.7571718983650207, 0.7616352689862251, 0.7572485668361187, 0.7550260179936886, 0.7595631908476352, 0.763034375667572, 0.7599161682724953], "val_tre_std": [0.027511252715802494, 0.07593431201113533, 0.1262049245646053, 0.14781622666731317, 0.19731004650994827, 0.1538107448595129, 0.15554684902972157, 0.14695279508565937, 0.17908736485187154, 0.18875130684038025, 0.1769762953390025, 0.16766753936312187, 0.17139012632148498, 0.17359321816174428, 0.20311660648997165, 0.17161296839295997, 0.16578064885892926, 0.17290399315845784, 0.16886487327310604, 0.14806971197739538, 0.16859308914903975, 0.13572207758448554, 0.16111478218014064, 0.1637356360401312, 0.14536515798208788, 0.18004281884501797, 0.1591164335824455, 0.1388791819160653, 0.1534606178792528, 0.15306302122473106, 0.14378495034669006, 0.1367711261502094, 0.12677138741279045, 0.1316132040423189, 0.15562621157817208, 0.14923601261638875, 0.1393943019165844, 0.13612622115365428, 0.1349727064430412, 0.11432319623566965, 0.12727456126632894, 0.1289669480511882, 0.11890626436770418, 
0.1310952292714821, 0.12584475311577023, 0.1257270419068187, 0.12828189650193889, 0.13452211085484508, 0.11821392361223679, 0.1296130167067139], "test_acc": [0.496, 0.50625, 0.548, 0.57, 0.58325, 0.59225, 0.59775, 0.60475, 0.612, 0.61975, 0.62475, 0.6275, 0.6395, 0.6455, 0.637, 0.646, 0.65125, 0.65, 0.65675, 0.6535, 0.65925, 0.655, 0.65825, 0.65775, 0.66075, 0.6675, 0.6625, 0.6605, 0.6605, 0.6565, 0.6645, 0.666, 0.663, 0.664, 0.6585, 0.666, 0.6655, 0.66575, 0.664, 0.66325, 0.6605, 0.6695, 0.663, 0.664, 0.669, 0.66825, 0.6675, 0.6695, 0.667, 0.67075], "test_same_acc": [0.49675, 0.50525, 0.5515, 0.57875, 0.59125, 0.59175, 0.60225, 0.60675, 0.613, 0.6225, 0.64475, 0.64375, 0.6445, 0.64975, 0.64275, 0.647, 0.66125, 0.6615, 0.65925, 0.67175, 0.666, 0.67575, 0.67625, 0.67825, 0.6805, 0.67325, 0.67875, 0.67725, 0.68, 0.66025, 0.67175, 0.68125, 0.6785, 0.68075, 0.67325, 0.6795, 0.67525, 0.6745, 0.6735, 0.6725, 0.66925, 0.677, 0.67775, 0.67475, 0.6765, 0.67325, 0.6735, 0.67625, 0.67475, 0.673], "test_acc_ci": [0.010956445129323426, 0.010956008551817583, 0.010902360982259758, 0.010834837616313176, 0.010788625838701608, 0.010769660496041646, 0.010735362127101257, 0.010708869846417036, 0.010675788905275336, 0.010630374322989525, 0.010551338232068006, 0.010545953426403656, 0.010505581716402, 0.010468282278954603, 0.010519262307637119, 0.010475868820770904, 0.010407995349129437, 0.010411593460537154, 0.010395301207757281, 0.010360994311448829, 0.01036099431144883, 0.010340070730989945, 0.010325578699884088, 0.010319730384074962, 0.010299028591176694, 0.010301015822329125, 0.010299028591176696, 0.010312870053815038, 0.010302008201680875, 0.010392560607082116, 0.01031875279754656, 0.010274924008082734, 0.010298033738898899, 0.010285025516345352, 0.010336224146025226, 0.010282003768113489, 0.010301015822329125, 0.01030299975720057, 0.0103138525628157, 0.010320707151008355, 0.01034390427966515, 0.010277963168230366, 0.010301015822329125, 0.010308931803144057, 0.01028200376811349, 0.010298033738898899, 0.010300022618907202, 0.010280994861727876, 0.01029703806183482, 0.010289042924241059], "best_epoch": 49, "best_val_acc": 0.668, "best_val_same_acc": 0.688, "best_val_tre": 0.763034375667572, "best_val_tre_std": 0.11821392361223679, "best_test_acc": 0.667, "best_test_same_acc": 0.67475, "best_test_acc_ci": 0.01029703806183482, "lowest_val_tre": 0.27602067959308624, "lowest_val_tre_std": 0.027511252715802494, "has_same": true}
2 |
--------------------------------------------------------------------------------
/shapeworld/exp/meta/args.json:
--------------------------------------------------------------------------------
1 | {"exp_dir": "exp/meta", "predict_concept_hyp": false, "predict_image_hyp": false, "infer_hyp": false, "backbone": "vgg16_fixed", "multimodal_concept": false, "comparison": "dotp", "dropout": 0.0, "debug_bilinear": false, "poe": false, "predict_hyp_task": "generate", "n_infer": 10, "oracle": false, "max_train": null, "noise": 0.0, "class_noise_weight": 0.0, "noise_at_test": false, "noise_type": "gaussian", "fixed_noise_colors": null, "fixed_noise_colors_max_rgb": 0.2, "batch_size": 100, "epochs": 50, "data_dir": null, "lr": 0.0001, "tre_err": "cos", "tre_comp": "add", "optimizer": "adam", "seed": 29493, "language_filter": null, "shuffle_words": false, "shuffle_captions": false, "log_interval": 10, "pred_lambda": 1.0, "hypo_lambda": 10.0, "save_checkpoint": false, "cuda": true, "predict_hyp": false, "use_hyp": false, "encode_hyp": false, "decode_hyp": false}
2 |
--------------------------------------------------------------------------------
/shapeworld/exp/meta/metrics.json:
--------------------------------------------------------------------------------
1 | {"train_acc": [0.587, 0.7234444444444444, 0.7327777777777778, 0.7338888888888889, 0.7413333333333333, 0.7468888888888889, 0.7552222222222222, 0.7471111111111111, 0.7431111111111111, 0.7497777777777778, 0.7494444444444445, 0.7463333333333333, 0.7562222222222222, 0.7497777777777778, 0.7491111111111111, 0.7525555555555555, 0.751, 0.7422222222222222, 0.7553333333333333, 0.7503333333333333, 0.7524444444444445, 0.7457777777777778, 0.7596666666666667, 0.7496666666666667, 0.7512222222222222, 0.7546666666666667, 0.7551111111111111, 0.7454444444444445, 0.7604444444444445, 0.7524444444444445, 0.7546666666666667, 0.7572222222222222, 0.7493333333333333, 0.7554444444444445, 0.7573333333333333, 0.7665555555555555, 0.7581111111111111, 0.7556666666666667, 0.7588888888888888, 0.7476666666666667, 0.7614444444444445, 0.7563333333333333, 0.762, 0.764, 0.7618888888888888, 0.7621111111111111, 0.7532222222222222, 0.7506666666666667, 0.7588888888888888, 0.7618888888888888], "val_acc": [0.534, 0.568, 0.54, 0.558, 0.56, 0.592, 0.598, 0.606, 0.59, 0.608, 0.606, 0.614, 0.62, 0.596, 0.6, 0.6, 0.596, 0.618, 0.6, 0.57, 0.568, 0.572, 0.582, 0.586, 0.568, 0.594, 0.592, 0.59, 0.576, 0.584, 0.6, 0.586, 0.582, 0.594, 0.602, 0.606, 0.594, 0.568, 0.596, 0.594, 0.602, 0.612, 0.618, 0.604, 0.602, 0.58, 0.61, 0.594, 0.614, 0.614], "val_same_acc": [0.526, 0.524, 0.564, 0.586, 0.59, 0.596, 0.596, 0.588, 0.598, 0.604, 0.584, 0.606, 0.596, 0.598, 0.596, 0.608, 0.606, 0.602, 0.602, 0.602, 0.612, 0.606, 0.608, 0.6, 0.596, 0.604, 0.6, 0.584, 0.576, 0.58, 0.592, 0.58, 0.612, 0.614, 0.606, 0.604, 0.608, 0.604, 0.62, 0.594, 0.606, 0.628, 0.61, 0.59, 0.614, 0.62, 0.614, 0.602, 0.61, 0.606], "val_tre": [0.5916588901281357, 0.8042377699017524, 0.8100116618275642, 0.8228306672573089, 0.8091322653889657, 0.8141303354799747, 0.8227340990900993, 0.8060084843039512, 0.7999683973491192, 0.8017535305321216, 0.8079626688361168, 0.8147974437475205, 0.8114711083471775, 0.8146977169215679, 0.801983368396759, 0.8079127200245857, 0.7854678380787372, 0.8079566982984543, 0.8099119906425476, 0.7922585190832615, 0.8039774939119816, 0.7961225943863391, 0.807337353438139, 0.8064042346477509, 0.8044357794523239, 0.8018924917876721, 0.8004497699439526, 0.8031283156871796, 0.8019032092988491, 0.8100885750055313, 0.8107224105894566, 0.8069514398574829, 0.8026845241487026, 0.8086875425577164, 0.8073043225705624, 0.8142727278470993, 0.8092543596625328, 0.8049768908321857, 0.8047672483623027, 0.8046509600281715, 0.8093694306910038, 0.81609526014328, 0.8113720493614673, 0.8151067448556423, 0.8141982303857803, 0.8105539740622043, 0.8153765279650688, 0.8105866264998913, 0.812785368680954, 0.8175703119635582], "val_tre_std": [0.20776802624853982, 0.22506850861303512, 0.3550965570239999, 0.25240233315544186, 0.30489556283502167, 0.25192126585263014, 0.22163285096208185, 0.22398180211467097, 0.22136366888011189, 0.21504156556522225, 0.21378987075781997, 0.1961432996474801, 0.18039981992759827, 0.20156448897379214, 0.19800972594737676, 0.186366711354458, 0.2288897358079121, 0.19004879167233124, 0.17280907742916762, 0.22070847274461852, 0.17811587832274378, 0.18698121777982987, 0.17709791856239163, 0.18664346244800895, 0.16421413078859573, 0.163887529613401, 0.18290036932823, 0.18182176754039484, 0.18873458891343609, 0.18661134863337522, 0.17477719968532573, 0.17169934653941996, 0.16874678702261794, 0.16338142096155675, 0.15784144370700903, 0.1539296292378941, 0.15841589817970592, 0.16560336048870408, 0.15963311797982846, 0.1611074459988127, 0.15862874099906135, 
0.14242522707771324, 0.1613770586649978, 0.15857943745800235, 0.14755391734663065, 0.1512508923237696, 0.15404908131241488, 0.15800266933740134, 0.15596154811266771, 0.14921550713938078], "test_acc": [0.5335, 0.55, 0.556, 0.56, 0.563, 0.5795, 0.58525, 0.583, 0.594, 0.59175, 0.58925, 0.59925, 0.5935, 0.5975, 0.601, 0.5965, 0.6025, 0.59175, 0.59625, 0.603, 0.605, 0.60125, 0.5995, 0.60775, 0.5975, 0.6005, 0.6085, 0.5945, 0.604, 0.59925, 0.60025, 0.6045, 0.60925, 0.61075, 0.60525, 0.59725, 0.601, 0.60675, 0.597, 0.59525, 0.60625, 0.60125, 0.59825, 0.6015, 0.597, 0.60075, 0.6015, 0.593, 0.60825, 0.59775], "test_same_acc": [0.53775, 0.55625, 0.56175, 0.56925, 0.57175, 0.579, 0.58425, 0.5825, 0.5915, 0.59425, 0.5975, 0.603, 0.59875, 0.60125, 0.59475, 0.60825, 0.6025, 0.60275, 0.607, 0.60675, 0.61175, 0.60225, 0.6155, 0.613, 0.60425, 0.617, 0.6065, 0.60775, 0.61275, 0.61275, 0.61275, 0.6105, 0.603, 0.6125, 0.61125, 0.61125, 0.60775, 0.61625, 0.6075, 0.60625, 0.6035, 0.61425, 0.6135, 0.61475, 0.61175, 0.608, 0.614, 0.60625, 0.616, 0.6175], "test_acc_ci": [0.010928886433295707, 0.010894711930421795, 0.010880510063727481, 0.010864828448800976, 0.010856803609805003, 0.010818228777738989, 0.010798191213694077, 0.010805637440128184, 0.01076657092520641, 0.01076553529556241, 0.010763976436330348, 0.010730300426683076, 0.010752345999681882, 0.010738148696208066, 0.010744762286662045, 0.010724607049765274, 0.010724033697727736, 0.010747487077801024, 0.010728031831229574, 0.01071300110832044, 0.010696260256364136, 0.010727462858826408, 0.010700499462641917, 0.010686435138617322, 0.01073143034953286, 0.010694432415397274, 0.010700499462641917, 0.010730300426683076, 0.010696260256364137, 0.01070768288660063, 0.010705300161602193, 0.010700499462641917, 0.010707088306205145, 0.010680198522119098, 0.010696868064414928, 0.010715930029050209, 0.010715345709396173, 0.010680825508826552, 0.01072517967157194, 0.010731994217641937, 0.010713001108320441, 0.010699291938605098, 0.01070827673329724, 0.010697475136539231, 0.010715345709396173, 0.010715345709396173, 0.010699291938605098, 0.010737036251772414, 0.010677683180206978, 0.010699896068274446], "best_epoch": 43, "best_val_acc": 0.618, "best_val_same_acc": 0.61, "best_val_tre": 0.8113720493614673, "best_val_tre_std": 0.1613770586649978, "best_test_acc": 0.59825, "best_test_same_acc": 0.6135, "best_test_acc_ci": 0.01070827673329724, "lowest_val_tre": 0.5916588901281357, "lowest_val_tre_std": 0.20776802624853982, "has_same": true}
--------------------------------------------------------------------------------
/shapeworld/lsl/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset utilities
3 | """
4 |
5 | import os
6 | import json
7 | import logging
8 |
9 | import torch
10 | import numpy as np
11 | import torch.utils.data as data
12 | from torchvision import transforms
13 |
14 | from utils import next_random, OrderedCounter
15 |
16 | # Set your data directory here!
17 | DATA_DIR = '/u/scr/muj/shapeworld_4k'
18 | SPLIT_OPTIONS = ['train', 'val', 'test', 'val_same', 'test_same']
19 |
20 | logging.getLogger(__name__).setLevel(logging.INFO)
21 |
22 | SOS_TOKEN = '<sos>'
23 | EOS_TOKEN = '<eos>'
24 | PAD_TOKEN = '<pad>'
25 | UNK_TOKEN = '<unk>'
26 | N_EX = 4 # number of examples per task
27 |
28 | random = next_random()
29 | COLORS = {
30 | 'black', 'red', 'green', 'blue', 'yellow', 'magenta', 'cyan', 'white'
31 | }
32 | SHAPES = {
33 | 'square', 'rectangle', 'triangle', 'pentagon', 'cross', 'circle',
34 | 'semicircle', 'ellipse'
35 | }
36 |
37 |
38 | def get_max_hint_length(data_dir=None):
39 | """
40 | Get the maximum number of words in a sentence across all splits
41 | """
42 | if data_dir is None:
43 | data_dir = DATA_DIR
44 | max_len = 0
45 | for split in ['train', 'val', 'test', 'val_same', 'test_same']:
46 | for tf in ['hints.json', 'test_hints.json']:
47 | hints_file = os.path.join(data_dir, 'shapeworld', split, tf)
48 | if os.path.exists(hints_file):
49 | with open(hints_file) as fp:
50 | hints = json.load(fp)
51 | split_max_len = max([len(hint.split()) for hint in hints])
52 | if split_max_len > max_len:
53 | max_len = split_max_len
54 | if max_len == 0:
55 | raise RuntimeError("Can't find any splits in {}".format(data_dir))
56 | return max_len
57 |
58 |
59 | def get_black_mask(imgs):
60 | if len(imgs.shape) == 4:
61 | # Then color is 1st dim
62 | col_dim = 1
63 | else:
64 | col_dim = 0
65 | total = imgs.sum(dim=col_dim)
66 |
67 | # Put dim back in
68 | is_black = total == 0.0
69 | is_black = is_black.unsqueeze(col_dim).expand_as(imgs)
70 |
71 | return is_black
72 |
73 |
74 | class ShapeWorld(data.Dataset):
75 | r"""Loader for ShapeWorld data as in L3.
76 |
77 | @param split: string [default: train]
78 | train|val|test|val_same|test_same
79 | @param vocab: ?Object [default: None]
80 | initialize with a vocabulary
81 | important to do this for validation/test set.
82 | @param augment: boolean [default: False]
83 | negatively sample data from other concepts.
84 | @param max_size: limit size to this many training examples
85 | @param precomputed_features: load precomputed VGG features rather than raw image data
86 | @param noise: amount of uniform noise to add to examples
87 |     @param class_noise_weight: how much of the noise added to examples
88 |                                should be shared across the pos/neg classes
89 |                                (between 0.0 and 1.0)
90 |
91 | NOTE: for now noise/class_noise_weight has no impact on val/test datasets
92 | """
93 |
94 | def __init__(self,
95 | split='train',
96 | vocab=None,
97 | augment=False,
98 | max_size=None,
99 | precomputed_features=True,
100 | preprocess=False,
101 | noise=0.0,
102 | class_noise_weight=0.5,
103 | fixed_noise_colors=None,
104 | fixed_noise_colors_max_rgb=0.2,
105 | noise_type='gaussian',
106 | data_dir=None,
107 | language_filter=None,
108 | shuffle_words=False,
109 | shuffle_captions=False):
110 | super(ShapeWorld, self).__init__()
111 | self.split = split
112 | assert self.split in SPLIT_OPTIONS
113 | self.vocab = vocab
114 | self.augment = augment
115 | self.max_size = max_size
116 |
117 | assert noise_type in ('gaussian', 'normal')
118 | self.noise_type = noise_type
119 |
120 | # Positive class noise
121 | if precomputed_features:
122 | self.image_dim = (4608, )
123 | else:
124 | self.image_dim = (3, 64, 64)
125 |
126 | self.noise = noise
127 | self.fixed_noise_colors = fixed_noise_colors
128 | self.fixed_noise_colors_max_rgb = fixed_noise_colors_max_rgb
129 |         if not (0.0 <= class_noise_weight <= 1.0):
130 | raise ValueError(
131 | "Class noise weight must be between 0 and 1, got {}".format(
132 | class_noise_weight))
133 | self.class_noise_weight = class_noise_weight
134 |
135 | if data_dir is None:
136 | data_dir = DATA_DIR
137 | self.data_dir = data_dir
138 | split_dir = os.path.join(data_dir, 'shapeworld', split)
139 | if not os.path.exists(split_dir):
140 | raise RuntimeError("Can't find {}".format(split_dir))
141 |
142 | self.precomputed_features = precomputed_features
143 | if self.precomputed_features:
144 | in_features_name = 'inputs.feats.npz'
145 | ex_features_name = 'examples.feats.npz'
146 | else:
147 | in_features_name = 'inputs.npz'
148 | ex_features_name = 'examples.npz'
149 |
150 | self.preprocess = None
151 | if preprocess:
152 | self.preprocess = transforms.Compose([
153 | transforms.ToPILImage(),
154 | transforms.Resize((224, 224)),
155 | transforms.ToTensor(),
156 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
157 | std=[0.229, 0.224, 0.225])
158 | ])
159 | # hints = language
160 | # examples = images with positive labels (pre-training)
161 | # input = test time input
162 | # label = test time label
163 | labels = np.load(os.path.join(split_dir, 'labels.npz'))['arr_0']
164 | in_features = np.load(os.path.join(split_dir, in_features_name))['arr_0']
165 | ex_features = np.load(os.path.join(split_dir, ex_features_name))['arr_0']
166 | with open(os.path.join(split_dir, 'hints.json')) as fp:
167 | hints = json.load(fp)
168 |
169 | test_hints = os.path.join(split_dir, 'test_hints.json')
170 | if self.fixed_noise_colors is not None:
171 | assert os.path.exists(test_hints)
172 | if os.path.exists(test_hints):
173 | with open(test_hints, 'r') as fp:
174 | test_hints = json.load(fp)
175 | self.test_hints = test_hints
176 | else:
177 | self.test_hints = None
178 |
179 | if self.test_hints is not None:
180 | for a, b, label in zip(hints, test_hints, labels):
181 | if label:
182 | assert a == b, (a, b, label)
183 | # else: # XXX: What?/
184 | # assert a != b, (a, b, label)
185 |
186 | if not self.precomputed_features:
187 | # Bring channel to first dim
188 | in_features = np.transpose(in_features, (0, 3, 1, 2))
189 | ex_features = np.transpose(ex_features, (0, 1, 4, 2, 3))
190 |
191 | if self.max_size is not None:
192 | labels = labels[:self.max_size]
193 | in_features = in_features[:self.max_size]
194 | ex_features = ex_features[:self.max_size]
195 | hints = hints[:self.max_size]
196 |
197 | n_data = len(hints)
198 |
199 | self.in_features = in_features
200 | self.ex_features = ex_features
201 | self.hints = hints
202 |
203 | if self.vocab is None:
204 | self.create_vocab(hints, test_hints)
205 |
206 | self.w2i, self.i2w = self.vocab['w2i'], self.vocab['i2w']
207 | self.vocab_size = len(self.w2i)
208 |
209 | # Language processing
210 | self.language_filter = language_filter
211 | if self.language_filter is not None:
212 | assert self.language_filter in ['color', 'nocolor']
213 | self.shuffle_words = shuffle_words
214 | self.shuffle_captions = shuffle_captions
215 |
216 | # this is the maximum number of tokens in a sentence
217 | max_length = get_max_hint_length(data_dir)
218 |
219 | hints, hint_lengths = [], []
220 | for hint in self.hints:
221 | hint_tokens = hint.split()
222 | # Hint processing
223 | if self.language_filter == 'color':
224 | hint_tokens = [t for t in hint_tokens if t in COLORS]
225 | elif self.language_filter == 'nocolor':
226 | hint_tokens = [t for t in hint_tokens if t not in COLORS]
227 | if self.shuffle_words:
228 | random.shuffle(hint_tokens)
229 |
230 | hint = [SOS_TOKEN, *hint_tokens, EOS_TOKEN]
231 | hint_length = len(hint)
232 |
233 | hint.extend([PAD_TOKEN] * (max_length + 2 - hint_length))
234 | hint = [self.w2i.get(w, self.w2i[UNK_TOKEN]) for w in hint]
235 |
236 | hints.append(hint)
237 | hint_lengths.append(hint_length)
238 |
239 | hints = np.array(hints)
240 | hint_lengths = np.array(hint_lengths)
241 |
242 | if self.test_hints is not None:
243 | test_hints, test_hint_lengths = [], []
244 | for test_hint in self.test_hints:
245 | test_hint_tokens = test_hint.split()
246 |
247 | if self.language_filter == 'color':
248 | test_hint_tokens = [
249 | t for t in test_hint_tokens if t in COLORS
250 | ]
251 | elif self.language_filter == 'nocolor':
252 | test_hint_tokens = [
253 | t for t in test_hint_tokens if t not in COLORS
254 | ]
255 | if self.shuffle_words:
256 | random.shuffle(test_hint_tokens)
257 |
258 | test_hint = [SOS_TOKEN, *test_hint_tokens, EOS_TOKEN]
259 | test_hint_length = len(test_hint)
260 |
261 | test_hint.extend([PAD_TOKEN] * (max_length + 2 - test_hint_length))
262 |
263 | test_hint = [
264 | self.w2i.get(w, self.w2i[UNK_TOKEN]) for w in test_hint
265 | ]
266 |
267 | test_hints.append(test_hint)
268 | test_hint_lengths.append(test_hint_length)
269 |
270 | test_hints = np.array(test_hints)
271 | test_hint_lengths = np.array(test_hint_lengths)
272 |
273 | data = []
274 | for i in range(n_data):
275 | if self.shuffle_captions:
276 | hint_i = random.randint(len(hints))
277 | test_hint_i = random.randint(len(test_hints))
278 | else:
279 | hint_i = i
280 | test_hint_i = i
281 | if self.test_hints is not None:
282 | th = test_hints[test_hint_i]
283 | thl = test_hint_lengths[test_hint_i]
284 | else:
285 | th = hints[test_hint_i]
286 | thl = hint_lengths[test_hint_i]
287 | data_i = (ex_features[i], in_features[i], labels[i], hints[hint_i],
288 | hint_lengths[hint_i], th, thl)
289 | data.append(data_i)
290 |
291 | self.data = data
292 | self.max_length = max_length
293 |
294 | def create_vocab(self, hints, test_hints):
295 | w2i = dict()
296 | i2w = dict()
297 | w2c = OrderedCounter()
298 |
299 | special_tokens = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN]
300 | for st in special_tokens:
301 | i2w[len(w2i)] = st
302 | w2i[st] = len(w2i)
303 |
304 | for hint in hints:
305 | hint_tokens = hint.split()
306 | w2c.update(hint_tokens)
307 |
308 | if test_hints is not None:
309 | for hint in test_hints:
310 | hint_tokens = hint.split()
311 | w2c.update(hint_tokens)
312 |
313 | for w, c in list(w2c.items()):
314 | i2w[len(w2i)] = w
315 | w2i[w] = len(w2i)
316 |
317 | assert len(w2i) == len(i2w)
318 | vocab = dict(w2i=w2i, i2w=i2w)
319 | self.vocab = vocab
320 |
321 | logging.info('Created vocab with %d words.' % len(w2c))
322 |
323 | def __len__(self):
324 | return len(self.data)
325 |
326 | def sample_train(self, n_batch):
327 | assert self.split == 'train'
328 | n_train = len(self.data)
329 | batch_examples = []
330 | batch_image = []
331 | batch_label = []
332 | batch_hint = []
333 | batch_hint_length = []
334 | if self.test_hints is not None:
335 | batch_test_hint = []
336 | batch_test_hint_length = []
337 |
338 | for _ in range(n_batch):
339 | index = random.randint(n_train)
340 | examples, image, label, hint, hint_length, test_hint, test_hint_length = \
341 | self.__getitem__(index)
342 |
343 | batch_examples.append(examples)
344 | batch_image.append(image)
345 | batch_label.append(label)
346 | batch_hint.append(hint)
347 | batch_hint_length.append(hint_length)
348 | if self.test_hints is not None:
349 | batch_test_hint.append(test_hint)
350 | batch_test_hint_length.append(test_hint_length)
351 |
352 | batch_examples = torch.stack(batch_examples)
353 | batch_image = torch.stack(batch_image)
354 | batch_label = torch.from_numpy(np.array(batch_label)).long()
355 | batch_hint = torch.stack(batch_hint)
356 | batch_hint_length = torch.from_numpy(
357 | np.array(batch_hint_length)).long()
358 | if self.test_hints is not None:
359 | batch_test_hint = torch.stack(batch_test_hint)
360 | batch_test_hint_length = torch.from_numpy(
361 | np.array(batch_test_hint_length)).long()
362 | else:
363 | batch_test_hint = None
364 | batch_test_hint_length = None
365 |
366 | return (
367 | batch_examples, batch_image, batch_label, batch_hint,
368 | batch_hint_length, batch_test_hint, batch_test_hint_length
369 | )
370 |
371 | def __getitem__(self, index):
372 | if self.split == 'train' and self.augment:
373 | examples, image, label, hint, hint_length, test_hint, test_hint_length = self.data[
374 | index]
375 |
376 | # tie a language to a concept; convert to pytorch.
377 | hint = torch.from_numpy(hint).long()
378 | test_hint = torch.from_numpy(test_hint).long()
379 |
380 | # in training, pick whether to show positive or negative example.
381 | sample_label = random.randint(2)
382 | n_train = len(self.data)
383 |
384 | if sample_label == 0:
385 | # if we are training, we need to negatively sample data and
386 | # return a tuple (example_z, hint_z, 1) or...
387 | # return a tuple (example_z, hint_other_z, 0).
388 | # Sample a new test hint as well.
389 | examples2, image2, _, support_hint2, support_hint_length2, query_hint2, query_hint_length2 = self.data[
390 | random.randint(n_train)]
391 |
392 | # pick either an example or an image.
393 | swap = random.randint(N_EX + 1)
394 | if swap == N_EX:
395 | feats = image2
396 | # Use the QUERY hint of the new example
397 | test_hint = query_hint2
398 | test_hint_length = query_hint_length2
399 | else:
400 | feats = examples2[swap, ...]
401 | # Use the SUPPORT hint of the new example
402 | test_hint = support_hint2
403 | test_hint_length = support_hint_length2
404 |
405 | test_hint = torch.from_numpy(test_hint).long()
406 |
407 | feats = torch.from_numpy(feats).float()
408 | examples = torch.from_numpy(examples).float()
409 |
410 | if self.preprocess is not None:
411 | feats = self.preprocess(feats)
412 | examples = torch.stack(
413 | [self.preprocess(e) for e in examples])
414 | return examples, feats, 0, hint, hint_length, test_hint, test_hint_length
415 | else: # sample_label == 1
416 | swap = random.randint((N_EX + 1 if label == 1 else N_EX))
417 | # pick either an example or an image.
418 | if swap == N_EX:
419 | feats = image
420 | else:
421 | feats = examples[swap, ...]
422 | if label == 1:
423 | examples[swap, ...] = image
424 | else:
425 | examples[swap, ...] = examples[random.randint(N_EX
426 | ), ...]
427 |
428 | # This is a positive example, so whatever example we've chosen,
429 | # assume the query hint matches the support hint.
430 | test_hint = hint
431 | test_hint_length = hint_length
432 |
433 | feats = torch.from_numpy(feats).float()
434 | examples = torch.from_numpy(examples).float()
435 |
436 | if self.preprocess is not None:
437 | feats = self.preprocess(feats)
438 | examples = torch.stack(
439 | [self.preprocess(e) for e in examples])
440 | return examples, feats, 1, hint, hint_length, test_hint, test_hint_length
441 |
442 | else: # val, val_same, test, test_same
443 | examples, image, label, hint, hint_length, test_hint, test_hint_length = self.data[
444 | index]
445 |
446 | # no fancy stuff. just return image.
447 | image = torch.from_numpy(image).float()
448 |
449 | # NOTE: we provide the oracle text.
450 | hint = torch.from_numpy(hint).long()
451 | test_hint = torch.from_numpy(test_hint).long()
452 | examples = torch.from_numpy(examples).float()
453 |
454 | if self.preprocess is not None:
455 | image = self.preprocess(image)
456 | examples = torch.stack([self.preprocess(e) for e in examples])
457 | return examples, image, label, hint, hint_length, test_hint, test_hint_length
458 |
459 | def to_text(self, hints):
460 | texts = []
461 | for hint in hints:
462 | text = []
463 | for tok in hint:
464 | i = tok.item()
465 | w = self.vocab['i2w'].get(i, UNK_TOKEN)
466 | if w == PAD_TOKEN:
467 | break
468 | text.append(w)
469 | texts.append(text)
470 |
471 | return texts
472 |
473 |
474 | def extract_features(hints):
475 | """
476 | Extract features from hints
477 | """
478 | all_feats = []
479 | for hint in hints:
480 | feats = []
481 | for maybe_rel in ['above', 'below', 'left', 'right']:
482 | if maybe_rel in hint:
483 | rel = maybe_rel
484 | rel_idx = hint.index(rel)
485 | break
486 | else:
487 | raise RuntimeError("Didn't find relation: {}".format(hint))
488 | # Add relation
489 | feats.append('rel:{}'.format(rel))
490 | fst, snd = hint[:rel_idx], hint[rel_idx:]
491 |         # fst: [<sos>, a, ..., is]
492 | fst_shape = fst[2:fst.index('is')]
493 |         # snd: [..., a, ..., ., <eos>]
494 | try:
495 | snd_shape = snd[snd.index('a') + 1:-2]
496 | except ValueError:
497 | # Use "an"
498 | snd_shape = snd[snd.index('an') + 1:-2]
499 |
500 | for name, fragment in [('fst', fst_shape), ('snd', snd_shape)]:
501 | for feat in fragment:
502 | if feat != 'shape':
503 | if feat in COLORS:
504 | feats.append('{}:color:{}'.format(name, feat))
505 | else:
506 | assert feat in SHAPES, hint
507 | feats.append('{}:shape:{}'.format(name, feat))
508 | all_feats.append(feats)
509 | return all_feats
510 |
--------------------------------------------------------------------------------
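The constructor arguments documented in the `ShapeWorld` docstring plus `sample_train` are the whole interface the training script needs. A minimal usage sketch, not part of the repository (the data path and batch size are placeholders, and it assumes a downloaded ShapeWorld split):

```python
# Hypothetical usage sketch for the ShapeWorld loader above; paths are placeholders.
from datasets import ShapeWorld

DATA = '/path/to/shapeworld_4k'  # assumed download location

train = ShapeWorld(split='train', augment=True,
                   precomputed_features=True, data_dir=DATA)
val = ShapeWorld(split='val', vocab=train.vocab,  # reuse the training vocab
                 precomputed_features=True, data_dir=DATA)

# Draw one random training batch of 100 tasks.
(examples, image, label, hint, hint_length,
 test_hint, test_hint_length) = train.sample_train(100)
print(examples.shape, image.shape, label.shape)  # e.g. (100, 4, 4608), (100, 4608), (100,)

# Map a batch of encoded hints back to token strings.
print(train.to_text(hint)[:2])
```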
/shapeworld/lsl/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Models
3 | """
4 |
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 | import torch.nn.utils.rnn as rnn_utils
11 |
12 |
13 | class ExWrapper(nn.Module):
14 | """
15 | Wrap around a model and allow training on examples
16 | i.e. tensor inputs of shape
17 | (batch_size, n_ex, *img_dims)
18 | """
19 |
20 | def __init__(self, model):
21 | super(ExWrapper, self).__init__()
22 | self.model = model
23 |
24 | def forward(self, x):
25 | batch_size = x.shape[0]
26 | if len(x.shape) == 5:
27 | n_ex = x.shape[1]
28 | img_dim = x.shape[2:]
29 | # Flatten out examples first
30 | x_flat = x.view(batch_size * n_ex, *img_dim)
31 | else:
32 | x_flat = x
33 |
34 | x_enc = self.model(x_flat)
35 |
36 | if len(x.shape) == 5:
37 | x_enc = x_enc.view(batch_size, n_ex, -1)
38 |
39 | return x_enc
40 |
41 |
42 | class Identity(nn.Module):
43 | def forward(self, x):
44 | return x
45 |
46 |
47 | class ImageRep(nn.Module):
48 | r"""Two fully-connected layers to form a final image
49 | representation.
50 |
51 | VGG-16 -> FC -> ReLU -> FC
52 |
53 |     The paper uses a hidden dimension of 512.
54 | """
55 |
56 | def __init__(self, backbone=None, hidden_size=512):
57 | super(ImageRep, self).__init__()
58 | if backbone is None:
59 | self.backbone = Identity()
60 | self.backbone.final_feat_dim = 4608
61 | else:
62 | self.backbone = backbone
63 | self.model = nn.Sequential(
64 | nn.Linear(self.backbone.final_feat_dim, hidden_size), nn.ReLU(),
65 | nn.Linear(hidden_size, hidden_size))
66 |
67 | def forward(self, x):
68 | x_enc = self.backbone(x)
69 | return self.model(x_enc)
70 |
71 |
72 | class TextRep(nn.Module):
73 | r"""Deterministic Bowman et. al. model to form
74 | text representation.
75 |
76 | Again, this uses 512 hidden dimensions.
77 | """
78 |
79 | def __init__(self, embedding_module):
80 | super(TextRep, self).__init__()
81 | self.embedding = embedding_module
82 | self.embedding_dim = embedding_module.embedding_dim
83 | self.gru = nn.GRU(self.embedding_dim, 512)
84 |
85 | def forward(self, seq, length):
86 | batch_size = seq.size(0)
87 |
88 | if batch_size > 1:
89 | sorted_lengths, sorted_idx = torch.sort(length, descending=True)
90 | seq = seq[sorted_idx]
91 |
92 | # reorder from (B,L,D) to (L,B,D)
93 | seq = seq.transpose(0, 1)
94 |
95 | # embed your sequences
96 | embed_seq = self.embedding(seq)
97 |
98 | packed = rnn_utils.pack_padded_sequence(
99 | embed_seq,
100 | sorted_lengths.data.cpu().tolist()
101 | if batch_size > 1 else length.data.tolist())
102 |
103 | _, hidden = self.gru(packed)
104 | hidden = hidden[-1, ...]
105 |
106 | if batch_size > 1:
107 | _, reversed_idx = torch.sort(sorted_idx)
108 | hidden = hidden[reversed_idx]
109 |
110 | return hidden
111 |
112 |
113 | class MultimodalDeepRep(nn.Module):
114 | def __init__(self):
115 | super(MultimodalDeepRep, self).__init__()
116 | self.model = nn.Sequential(nn.Linear(512 * 2, 512 * 2), nn.ReLU(),
117 | nn.Linear(512 * 2, 512), nn.ReLU(),
118 | nn.Linear(512, 512))
119 |
120 | def forward(self, x, y):
121 | xy = torch.cat([x, y], dim=1)
122 | return self.model(xy)
123 |
124 |
125 | class MultimodalRep(nn.Module):
126 | r"""Concat Image and Text representations."""
127 |
128 | def __init__(self):
129 | super(MultimodalRep, self).__init__()
130 | self.model = nn.Sequential(nn.Linear(512 * 2, 512), nn.ReLU(),
131 | nn.Linear(512, 512))
132 |
133 | def forward(self, x, y):
134 | xy = torch.cat([x, y], dim=1)
135 | return self.model(xy)
136 |
137 |
138 | class MultimodalSumExp(nn.Module):
139 | def forward(self, x, y):
140 | return x + y
141 |
142 |
143 | class MultimodalLinearRep(nn.Module):
144 | def __init__(self):
145 | super(MultimodalLinearRep, self).__init__()
146 | self.model = nn.Linear(512 * 2, 512)
147 |
148 | def forward(self, x, y):
149 | xy = torch.cat([x, y], dim=1)
150 | return self.model(xy)
151 |
152 |
153 | class MultimodalWeightedRep(nn.Module):
154 | def __init__(self):
155 | super(MultimodalWeightedRep, self).__init__()
156 | self.model = nn.Sequential(nn.Linear(512 * 2, 512), nn.ReLU(),
157 | nn.Linear(512, 1), nn.Sigmoid())
158 |
159 | def forward(self, x, y):
160 | xy = torch.cat([x, y], dim=1)
161 | w = self.model(xy)
162 | out = w * x + (1. - w) * y
163 | return out
164 |
165 |
166 | class MultimodalSingleWeightRep(nn.Module):
167 | def __init__(self):
168 | super(MultimodalSingleWeightRep, self).__init__()
169 | self.w = nn.Parameter(torch.normal(torch.zeros(1), 1))
170 |
171 | def forward(self, x, y):
172 | w = torch.sigmoid(self.w)
173 | out = w * x + (1. - w) * y
174 | return out
175 |
176 |
177 | class TextProposal(nn.Module):
178 | r"""Reverse proposal model, estimating:
179 |
180 | argmax_lambda log q(w_i|x_1, y_1, ..., x_n, y_n; lambda)
181 |
182 | approximation to the distribution of descriptions.
183 |
184 | Because they use only positive labels, it actually simplifies to
185 |
186 | argmax_lambda log q(w_i|x_1, ..., x_4; lambda)
187 |
188 | https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/model.py
189 | """
190 |
191 | def __init__(self, embedding_module):
192 | super(TextProposal, self).__init__()
193 | self.embedding = embedding_module
194 | self.embedding_dim = embedding_module.embedding_dim
195 | self.vocab_size = embedding_module.num_embeddings
196 | self.gru = nn.GRU(self.embedding_dim, 512)
197 | self.outputs2vocab = nn.Linear(512, self.vocab_size)
198 |
199 | def forward(self, feats, seq, length):
200 | # feats is from example images
201 | batch_size = seq.size(0)
202 |
203 | if batch_size > 1:
204 | # BUGFIX? dont we need to sort feats too?
205 | sorted_lengths, sorted_idx = torch.sort(length, descending=True)
206 | seq = seq[sorted_idx]
207 | feats = feats[sorted_idx]
208 |
209 | feats = feats.unsqueeze(0)
210 | # reorder from (B,L,D) to (L,B,D)
211 | seq = seq.transpose(0, 1)
212 |
213 | # embed your sequences
214 | embed_seq = self.embedding(seq)
215 |
216 |         packed_input = rnn_utils.pack_padded_sequence(
217 |             embed_seq, sorted_lengths.cpu() if batch_size > 1 else length.cpu())
218 |
219 | # shape = (seq_len, batch, hidden_dim)
220 | packed_output, _ = self.gru(packed_input, feats)
221 | output = rnn_utils.pad_packed_sequence(packed_output)
222 | output = output[0].contiguous()
223 |
224 | # reorder from (L,B,D) to (B,L,D)
225 | output = output.transpose(0, 1)
226 |
227 | if batch_size > 1:
228 | _, reversed_idx = torch.sort(sorted_idx)
229 | output = output[reversed_idx]
230 |
231 | max_length = output.size(1)
232 | output_2d = output.view(batch_size * max_length, 512)
233 | outputs_2d = self.outputs2vocab(output_2d)
234 | outputs = outputs_2d.view(batch_size, max_length, self.vocab_size)
235 |
236 | return outputs
237 |
238 | def sample(self, feats, sos_index, eos_index, pad_index, greedy=False):
239 | """Generate from image features using greedy search."""
240 | with torch.no_grad():
241 | batch_size = feats.size(0)
242 |
243 | # initialize hidden states using image features
244 | states = feats.unsqueeze(0)
245 |
246 | # first input is SOS token
247 | inputs = np.array([sos_index for _ in range(batch_size)])
248 | inputs = torch.from_numpy(inputs)
249 | inputs = inputs.unsqueeze(1)
250 | inputs = inputs.to(feats.device)
251 |
252 | # save SOS as first generated token
253 | inputs_npy = inputs.squeeze(1).cpu().numpy()
254 | sampled_ids = [[w] for w in inputs_npy]
255 |
256 | # (B,L,D) to (L,B,D)
257 | inputs = inputs.transpose(0, 1)
258 |
259 | # compute embeddings
260 | inputs = self.embedding(inputs)
261 |
262 | for i in range(20): # like in jacobs repo
263 | outputs, states = self.gru(inputs,
264 | states) # outputs: (L=1,B,H)
265 | outputs = outputs.squeeze(0) # outputs: (B,H)
266 | outputs = self.outputs2vocab(outputs) # outputs: (B,V)
267 |
268 | if greedy:
269 | predicted = outputs.max(1)[1]
270 | predicted = predicted.unsqueeze(1)
271 | else:
272 | outputs = F.softmax(outputs, dim=1)
273 | predicted = torch.multinomial(outputs, 1)
274 |
275 | predicted_npy = predicted.squeeze(1).cpu().numpy()
276 | predicted_lst = predicted_npy.tolist()
277 |
278 | for w, so_far in zip(predicted_lst, sampled_ids):
279 | if so_far[-1] != eos_index:
280 | so_far.append(w)
281 |
282 | inputs = predicted.transpose(0, 1) # inputs: (L=1,B)
283 | inputs = self.embedding(inputs) # inputs: (L=1,B,E)
284 |
285 | sampled_lengths = [len(text) for text in sampled_ids]
286 | sampled_lengths = np.array(sampled_lengths)
287 |
288 | max_length = max(sampled_lengths)
289 | padded_ids = np.ones((batch_size, max_length)) * pad_index
290 |
291 | for i in range(batch_size):
292 | padded_ids[i, :sampled_lengths[i]] = sampled_ids[i]
293 |
294 | sampled_lengths = torch.from_numpy(sampled_lengths).long()
295 | sampled_ids = torch.from_numpy(padded_ids).long()
296 |
297 | return sampled_ids, sampled_lengths
298 |
299 |
300 | class EmbedImageRep(nn.Module):
301 | def __init__(self, z_dim):
302 | super(EmbedImageRep, self).__init__()
303 | self.z_dim = z_dim
304 | self.model = nn.Sequential(nn.Linear(self.z_dim, 512), nn.ReLU(),
305 | nn.Linear(512, 512))
306 |
307 | def forward(self, x):
308 | return self.model(x)
309 |
310 |
311 | class EmbedTextRep(nn.Module):
312 | def __init__(self, z_dim):
313 | super(EmbedTextRep, self).__init__()
314 | self.z_dim = z_dim
315 | self.model = nn.Sequential(nn.Linear(self.z_dim, 512), nn.ReLU(),
316 | nn.Linear(512, 512))
317 |
318 | def forward(self, x):
319 | return self.model(x)
320 |
321 |
322 | class Scorer(nn.Module):
323 | def __init__(self):
324 | super(Scorer, self).__init__()
325 |
326 | def forward(self, x, y):
327 | raise NotImplementedError
328 |
329 | def score(self, x, y):
330 | raise NotImplementedError
331 |
332 | def batchwise_score(self, x, y):
333 | raise NotImplementedError
334 |
335 |
336 | class DotPScorer(Scorer):
337 | def __init__(self):
338 | super(DotPScorer, self).__init__()
339 |
340 | def score(self, x, y):
341 | return torch.sum(x * y, dim=1)
342 |
343 | def batchwise_score(self, y, x):
344 | # REVERSED
345 | bw_scores = torch.einsum('ijk,ik->ij', (x, y))
346 | return torch.sum(bw_scores, dim=1)
347 |
348 |
349 | class BilinearScorer(DotPScorer):
350 | def __init__(self, hidden_size, dropout=0.0, identity_debug=False):
351 | super(BilinearScorer, self).__init__()
352 | self.bilinear = nn.Linear(hidden_size, hidden_size, bias=False)
353 | self.dropout_p = dropout
354 | if self.dropout_p > 0.0:
355 | self.dropout = nn.Dropout(p=self.dropout_p)
356 | else:
357 | self.dropout = lambda x: x
358 | if identity_debug:
359 | # Set this as identity matrix to make sure we get the same output
360 | # as DotPScorer
361 | self.bilinear.weight = nn.Parameter(
362 | torch.eye(hidden_size, dtype=torch.float32))
363 | self.bilinear.weight.requires_grad = False
364 |
365 | def score(self, x, y):
366 | wy = self.bilinear(y)
367 | wy = self.dropout(wy)
368 | return super(BilinearScorer, self).score(x, wy)
369 |
370 | def batchwise_score(self, x, y):
371 | """
372 | x: (batch_size, h)
373 | y: (batch_size, n_examples, h)
374 | """
375 | batch_size, n_examples, h = y.shape
376 | wy = self.bilinear(y.view(batch_size * n_examples,
377 | -1)).unsqueeze(1).view_as(y)
378 | wy = self.dropout(wy)
379 | # wy: (batch_size, n_examples, h)
380 | return super(BilinearScorer, self).batchwise_score(x, wy)
381 |
--------------------------------------------------------------------------------
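How these modules fit together is easiest to see with dummy tensors. The sketch below is not from the repository (the vocabulary size, batch size, and mean-pooled prototype are illustrative assumptions); it wires a feature-level `ImageRep` through `ExWrapper`, encodes hints with `TextRep`, scores a query against the support prototype with `DotPScorer`, and decodes caption logits with `TextProposal`:

```python
# Illustrative sketch only; dimensions and the mean-pooling prototype are assumptions.
import torch
import torch.nn as nn
from models import ExWrapper, ImageRep, TextRep, TextProposal, DotPScorer

vocab_size = 20
embedding = nn.Embedding(vocab_size, 512)

image_model = ExWrapper(ImageRep(backbone=None))  # expects 4608-d VGG features
hint_model = TextRep(embedding)
proposal_model = TextProposal(embedding)
scorer = DotPScorer()

support = torch.randn(8, 4, 4608)                 # (batch, n_ex, feat_dim)
query = torch.randn(8, 4608)                      # (batch, feat_dim)
hints = torch.randint(0, vocab_size, (8, 10))     # padded token ids
hint_lengths = torch.full((8,), 10, dtype=torch.long)

proto = image_model(support).mean(1)              # prototype over the 4 examples
query_rep = image_model(query)
score = scorer.score(proto, query_rep)            # (batch,) membership scores

hint_rep = hint_model(hints, hint_lengths)        # (batch, 512) hint encodings
hint_logits = proposal_model(proto, hints, hint_lengths)  # (batch, 10, vocab_size)
print(score.shape, hint_rep.shape, hint_logits.shape)
```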
/shapeworld/lsl/tre.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import optim
4 | from tqdm import trange
5 | from torch.nn.modules.distance import CosineSimilarity
6 |
7 |
8 | def flatten(l):
9 | if not isinstance(l, tuple):
10 | return (l, )
11 |
12 | out = ()
13 | for ll in l:
14 | out = out + flatten(ll)
15 | return out
16 |
17 |
18 | class L1Dist(nn.Module):
19 | def forward(self, pred, target):
20 | return torch.norm(pred - target, p=1, dim=1)
21 |
22 |
23 | class L2Dist(nn.Module):
24 | def forward(self, pred, target):
25 | return torch.norm(pred - target, p=2, dim=1)
26 |
27 |
28 | class CosDist(nn.Module):
29 | def __init__(self):
30 | super().__init__()
31 | self.cossim = CosineSimilarity()
32 |
33 | def forward(self, x, y):
34 | return 1 - self.cossim(x, y)
35 |
36 |
37 | class AddComp(nn.Module):
38 | def forward(self, embs, embs_mask):
39 | """
40 | embs: (batch_size, max_feats, h)
41 | embs_mask: (batch_size, max_feats)
42 | """
43 | embs_mask_exp = embs_mask.float().unsqueeze(2).expand_as(embs)
44 | embs_zeroed = embs * embs_mask_exp
45 | composed = embs_zeroed.sum(1)
46 | return composed
47 |
48 |
49 | class MulComp(nn.Module):
50 | def forward(self, embs, embs_mask):
51 | """
52 | embs: (batch_size, max_feats, h)
53 | embs_mask: (batch_size, max_feats)
54 | """
55 | raise NotImplementedError
56 |
57 |
58 | class Objective(nn.Module):
59 | def __init__(self, vocab, repr_size, comp_fn, err_fn, zero_init):
60 | super().__init__()
61 | self.emb = nn.Embedding(len(vocab), repr_size)
62 | if zero_init:
63 | self.emb.weight.data.zero_()
64 | self.comp = comp_fn
65 | self.err = err_fn
66 |
67 | def compose(self, feats, feats_mask):
68 | """
69 | Input:
70 | batch_size, max_feats
71 | Output:
72 | batch_size, h
73 | """
74 | embs = self.emb(feats)
75 | # Compose embeddings
76 | composed = self.comp(embs, feats_mask)
77 | return composed
78 |
79 | def forward(self, rep, feats, feats_mask):
80 | return self.err(self.compose(feats, feats_mask), rep)
81 |
82 |
83 | def tre(reps,
84 | feats,
85 | feats_mask,
86 | vocab,
87 | comp_fn,
88 | err_fn,
89 | quiet=False,
90 | steps=400,
91 | include_pred=False,
92 | zero_init=True):
93 |
94 | obj = Objective(vocab, reps.shape[1], comp_fn, err_fn, zero_init)
95 | obj = obj.to(reps.device)
96 | opt = optim.Adam(obj.parameters(), lr=0.001)
97 |
98 | if not quiet:
99 | ranger = trange(steps, desc='TRE')
100 | else:
101 | ranger = range(steps)
102 | for t in ranger:
103 | opt.zero_grad()
104 | loss = obj(reps, feats, feats_mask)
105 | total_loss = loss.sum()
106 | total_loss.backward()
107 | if not quiet and t % 100 == 0:
108 | print(total_loss.item())
109 | opt.step()
110 |
111 | final_losses = [l.item() for l in loss]
112 | if include_pred:
113 | lexicon = {
114 | k: obj.emb(torch.LongTensor([v])).data.cpu().numpy()
115 | for k, v in vocab.items()
116 | }
117 | composed = [obj.compose(f, fm) for f, fm in zip(feats, feats_mask)]
118 | return final_losses, lexicon, composed
119 | else:
120 | return final_losses
121 |
--------------------------------------------------------------------------------
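The `tre` entry point expects a matrix of representations, integer feature indices with a mask, and a vocabulary mapping feature strings (as produced by `extract_features` above) to indices. A toy sketch, not from the repository, using the additive composition and cosine error that the experiment configs select (`"tre_comp": "add"`, `"tre_err": "cos"`):

```python
# Toy example; the vocabulary entries and representations are made up.
import torch
from tre import tre, AddComp, CosDist

vocab = {'rel:above': 0, 'fst:color:red': 1, 'snd:shape:circle': 2}
reps = torch.randn(2, 16)                        # learned concept representations
feats = torch.tensor([[0, 1, 2], [0, 2, 2]])     # feature indices per example
feats_mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)

losses = tre(reps, feats, feats_mask, vocab, AddComp(), CosDist(),
             quiet=True, steps=100)
print(sum(losses) / len(losses))                 # mean per-example TRE
```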
/shapeworld/lsl/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities
3 | """
4 |
5 | from collections import Counter, OrderedDict
6 | import json
7 | import os
8 | import shutil
9 |
10 | import numpy as np
11 | import torch
12 |
13 | random_counter = [0]
14 |
15 |
16 | def next_random():
17 | random = np.random.RandomState(random_counter[0])
18 | random_counter[0] += 1
19 | return random
20 |
21 |
22 | class OrderedCounter(Counter, OrderedDict):
23 | """Counter that remembers the order elements are first encountered"""
24 |
25 | def __repr__(self):
26 | return '%s(%r)' % (self.__class__.__name__, OrderedDict(self))
27 |
28 | def __reduce__(self):
29 | return self.__class__, (OrderedDict(self), )
30 |
31 |
32 | class AverageMeter(object):
33 | """Computes and stores the average and current value"""
34 |
35 | def __init__(self, raw=False):
36 | self.raw = raw
37 | self.reset()
38 |
39 | def reset(self):
40 | self.val = 0
41 | self.avg = 0
42 | self.sum = 0
43 | self.count = 0
44 | if self.raw:
45 | self.raw_scores = []
46 |
47 | def update(self, val, n=1, raw_scores=None):
48 | self.val = val
49 | self.sum += val * n
50 | self.count += n
51 | self.avg = self.sum / self.count
52 | if self.raw:
53 | self.raw_scores.extend(list(raw_scores))
54 |
55 |
56 | def save_checkpoint(state, is_best, folder='./',
57 | filename='checkpoint.pth.tar'):
58 | if not os.path.isdir(folder):
59 | os.mkdir(folder)
60 | torch.save(state, os.path.join(folder, filename))
61 | if is_best:
62 | shutil.copyfile(os.path.join(folder, filename),
63 | os.path.join(folder, 'model_best.pth.tar'))
64 |
65 |
66 | def merge_args_with_dict(args, dic):
67 | for k, v in list(dic.items()):
68 | setattr(args, k, v)
69 |
70 |
71 | def make_output_and_sample_dir(out_dir):
72 | if not os.path.exists(out_dir):
73 | os.makedirs(out_dir)
74 |
75 | sample_dir = os.path.join(out_dir, 'samples')
76 | if not os.path.exists(sample_dir):
77 | os.makedirs(sample_dir)
78 |
79 | return out_dir, sample_dir
80 |
81 |
82 | def save_defaultdict_to_fs(d, out_path):
83 | d = dict(d)
84 | with open(out_path, 'w') as fp:
85 | d_str = json.dumps(d, ensure_ascii=True)
86 | fp.write(d_str)
87 |
88 |
89 | def idx2word(idx, i2w):
90 | sent_str = [str()] * len(idx)
91 | for i, sent in enumerate(idx):
92 | for word_id in sent:
93 | sent_str[i] += str(i2w[word_id.item()]) + " "
94 | sent_str[i] = sent_str[i].strip()
95 |
96 | return sent_str
97 |
--------------------------------------------------------------------------------
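Two of these helpers do most of the bookkeeping in the training loop: `AverageMeter` keeps count-weighted running averages of metrics, and `next_random` hands out independently seeded `RandomState` streams so runs stay reproducible. A small standalone sketch, not part of the repository:

```python
# Standalone illustration of the helpers above.
from utils import AverageMeter, next_random

acc_meter = AverageMeter(raw=True)
acc_meter.update(0.75, n=100, raw_scores=[1] * 75 + [0] * 25)
acc_meter.update(0.65, n=100, raw_scores=[1] * 65 + [0] * 35)
print(acc_meter.avg)          # 0.7, weighted by the per-update counts

rng_a = next_random()         # RandomState seeded with 0
rng_b = next_random()         # RandomState seeded with 1, independent of rng_a
print(rng_a.randint(10), rng_b.randint(10))
```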
/shapeworld/run_l3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python lsl/train.py --cuda \
4 | --infer_hyp \
5 | --hypo_lambda 1.0 \
6 | --batch_size 100 \
7 | --seed $RANDOM \
8 | exp/l3
9 |
--------------------------------------------------------------------------------
/shapeworld/run_lang_ablation.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | python lsl/train.py --cuda \
6 | --predict_concept_hyp \
7 | --hypo_lambda 20.0 \
8 | --seed "$RANDOM" \
9 | --batch_size 100 \
10 | --language_filter color \
11 | exp/lsl_color
12 |
13 | python lsl/train.py --cuda \
14 | --predict_concept_hyp \
15 | --hypo_lambda 20.0 \
16 | --seed "$RANDOM" \
17 | --batch_size 100 \
18 | --language_filter nocolor \
19 | exp/lsl_nocolor
20 |
21 | python lsl/train.py --cuda \
22 | --predict_concept_hyp \
23 | --hypo_lambda 20.0 \
24 | --seed "$RANDOM" \
25 | --batch_size 100 \
26 | --shuffle_words \
27 | exp/lsl_shuffle_words
28 |
29 | python lsl/train.py --cuda \
30 | --predict_concept_hyp \
31 | --hypo_lambda 20.0 \
32 | --seed "$RANDOM" \
33 | --batch_size 100 \
34 | --shuffle_captions \
35 | exp/lsl_shuffle_captions
36 |
--------------------------------------------------------------------------------
/shapeworld/run_lsl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | HYPO_LAMBDA=20
4 |
5 | python lsl/train.py --cuda \
6 | --predict_concept_hyp \
7 | --hypo_lambda $HYPO_LAMBDA \
8 | --batch_size 100 \
9 | --seed $RANDOM \
10 | exp/lsl
11 |
--------------------------------------------------------------------------------
/shapeworld/run_lsl_img.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | HYPO_LAMBDA=20
4 |
5 | python lsl/train.py --cuda \
6 |     --predict_image_hyp \
7 | --hypo_lambda $HYPO_LAMBDA \
8 | --batch_size 100 \
9 | --seed $RANDOM \
10 | exp/lsl_img
11 |
--------------------------------------------------------------------------------
/shapeworld/run_meta.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python lsl/train.py --cuda \
4 | --batch_size 100 \
5 | --seed $RANDOM \
6 | exp/meta
7 |
--------------------------------------------------------------------------------
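Each of the scripts above writes an `args.json` and a `metrics.json` (in the format shown earlier) into its `exp/` subdirectory. A short post-hoc summary sketch, not part of the repository:

```python
# Hypothetical summary of exp/*/metrics.json files produced by the run scripts.
import glob
import json

for path in sorted(glob.glob('exp/*/metrics.json')):
    with open(path) as fp:
        m = json.load(fp)
    print('{}: best epoch {}, test acc {:.3f} +/- {:.3f}'.format(
        path, m['best_epoch'], m['best_test_acc'], m['best_test_acc_ci']))
```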