├── .gitignore
├── LICENSE
├── README.md
├── asset
│   ├── classification_summary.png
│   ├── model.png
│   └── tb.png
├── download.py
├── get_dataset_script
│   ├── download_miniimagenet.py
│   ├── get_aircraft.py
│   ├── get_bird.py
│   ├── get_cifar.py
│   ├── get_miniimagenet.py
│   ├── proc_aircraft.py
│   ├── proc_bird.py
│   ├── proc_cifar.py
│   ├── proc_miniimagenet.py
│   └── resize_dataset.py
├── main.py
├── maml
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── aircraft.py
│   │   ├── bird.py
│   │   ├── cifar100.py
│   │   ├── metadataset.py
│   │   ├── miniimagenet.py
│   │   ├── multimodal_few_shot.py
│   │   └── omniglot.py
│   ├── metalearner.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── conv_embedding_model.py
│   │   ├── conv_net.py
│   │   ├── fully_connected.py
│   │   ├── gated_conv_net.py
│   │   ├── gated_net.py
│   │   ├── gru_embedding_model.py
│   │   ├── lstm_embedding_model.py
│   │   ├── model.py
│   │   └── simple_embedding_model.py
│   ├── sampler.py
│   ├── trainer.py
│   └── utils.py
├── miniimagenet-data
│   ├── download_mini_imagenet.py
│   ├── test.csv
│   ├── train.csv
│   └── val.csv
├── requirements.txt
└── visualization.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # project specific
2 | data/
3 | saves/
4 | logs/
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 |
111 | *.txt
112 | *.sw[opn]
113 | *.pt
114 | *.hdf5
115 | train_dir/
116 | *tgz
117 | *tar.gz
118 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Tristan Deleu
4 | Copyright (c) 2018 Risto Vuorio
5 | Copyright (c) 2019 Hexiang Hu
6 | Copyright (c) 2019 Shao-Hua Sun
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal Model-Agnostic Meta-Learning for Few-shot Classification
2 |
3 | This project is an implementation of [**Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation**](https://arxiv.org/abs/1910.13616), published at [**NeurIPS 2019**](https://neurips.cc/Conferences/2019/). Please visit our [project page](https://vuoristo.github.io/MMAML/) for more information, and contact [Shao-Hua Sun](http://shaohua0116.github.io/) with any questions.
4 |
5 | Model-agnostic meta-learners aim to acquire meta-prior parameters from a distribution of tasks and adapt to novel tasks with few gradient updates. Yet, seeking a common initialization shared across the entire task distribution substantially limits the diversity of the task distributions that they are able to learn from. We propose a multimodal MAML (MMAML) framework, which is able to modulate its meta-learned prior according to the identified mode, allowing more efficient fast adaptation. An illustration of the proposed framework is as follows.
6 |
7 | 
8 | ![An illustration of the MMAML framework](asset/model.png)
9 | 
10 | 
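Concretely, MMAML first embeds the support set of a task to identify its mode, emits modulation vectors for the blocks of the prior network, and only then runs the usual MAML gradient adaptation. Below is a minimal, self-contained PyTorch sketch of that idea (FiLM-style modulation plus a MAML inner loop); all names and dimensions are illustrative and do not correspond to this repo's actual `GatedConvModel`/`ConvEmbeddingModel` classes.

```python
import torch
import torch.nn.functional as F


class TaskEncoder(torch.nn.Module):
    """Embeds a support set and emits FiLM-style modulation (gamma, beta)."""

    def __init__(self, in_dim, hidden_dim, block_dim):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim), torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 2 * block_dim))

    def forward(self, support_x):
        z = self.net(support_x).mean(dim=0)  # mean-pool into one task embedding
        gamma, beta = z.chunk(2, dim=-1)
        return 1.0 + gamma, beta             # identity-centered modulation


def forward_pass(params, gamma, beta, x):
    h = F.linear(x, params['w1'], params['b1'])
    h = F.relu(h * gamma + beta)             # modulate the meta-learned prior
    return F.linear(h, params['w2'], params['b2'])


def inner_adapt(params, gamma, beta, x_s, y_s, fast_lr=0.05, num_updates=5):
    """Standard MAML inner loop on the (modulated) base network."""
    adapted = dict(params)
    for _ in range(num_updates):
        loss = F.cross_entropy(forward_pass(adapted, gamma, beta, x_s), y_s)
        grads = torch.autograd.grad(loss, list(adapted.values()),
                                    create_graph=True)
        adapted = {k: p - fast_lr * g
                   for (k, p), g in zip(adapted.items(), grads)}
    return adapted


# Toy 5-way task with 32-d inputs (sizes chosen only for illustration).
in_dim, hid, n_way = 32, 64, 5
params = {'w1': (0.1 * torch.randn(hid, in_dim)).requires_grad_(),
          'b1': torch.zeros(hid, requires_grad=True),
          'w2': (0.1 * torch.randn(n_way, hid)).requires_grad_(),
          'b2': torch.zeros(n_way, requires_grad=True)}
encoder = TaskEncoder(in_dim, hid, hid)

x_s, y_s = torch.randn(n_way, in_dim), torch.arange(n_way)    # support set
x_q, y_q = torch.randn(n_way, in_dim), torch.arange(n_way)    # query set
gamma, beta = encoder(x_s)                            # 1. identify the mode
adapted = inner_adapt(params, gamma, beta, x_s, y_s)  # 2. gradient adaptation
outer_loss = F.cross_entropy(forward_pass(adapted, gamma, beta, x_q), y_q)
outer_loss.backward()  # meta-gradients reach both the prior and the encoder
```

In the actual code, the modulation is produced by `ConvEmbeddingModel` and applied inside `GatedConvModel`; the sketch above only mirrors the control flow.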
11 | We evaluate our model and baselines (MAML and Multi-MAML) on multiple multimodal settings based on the following five datasets: (a) [Omniglot](https://www.omniglot.com/), (b) [Mini-ImageNet](https://openreview.net/forum?id=rJY0-Kcll), (c) [FC100](https://arxiv.org/abs/1805.10123) (derived from CIFAR-100), (d) [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), and (e) [FGVC-Aircraft](https://arxiv.org/abs/1306.5151).
12 |
13 | 
14 | ![Summary of the few-shot classification datasets](asset/classification_summary.png)
15 | 
16 | 
17 | ## Datasets
18 |
19 | Run the following command to download and preprocess the datasets:
20 |
21 | ```bash
22 | python download.py --dataset aircraft bird cifar miniimagenet
23 | ```
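
Assuming the scripts succeed, the processed data lands under `./data` (layout inferred from the `get_dataset_script` files):

```
data/
├── aircraft/{train,val,test}/<variant>/*.jpg
├── bird/{train,val,test}/<class>/*.jpg
├── cifar100.pth
└── miniimagenet/{train,val,test}/<class>/*.jpg
```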
24 |
25 | ## Getting started
26 |
27 | Please first install the following prerequisites: `wget`, `unzip`.
28 |
29 | To avoid any conflict with your existing Python setup, and to keep this project self-contained, it is suggested to work in a virtual environment with [`virtualenv`](http://docs.python-guide.org/en/latest/dev/virtualenvs/). To install `virtualenv`:
30 | ```
31 | pip install --upgrade virtualenv
32 | ```
33 | Create a virtual environment, activate it and install the requirements in [`requirements.txt`](requirements.txt).
34 | ```
35 | virtualenv mmaml_venv
36 | source mmaml_venv/bin/activate
37 | pip install -r requirements.txt
38 | ```
39 |
40 | ## Usage
41 |
42 | After downloading the datasets, you can start training models with the following commands.
43 |
44 | ### Training command
45 |
46 | ```bash
47 | python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s
48 | ```
49 | - Selected arguments (see `trainer.py` for more details)
50 |   - `--output-folder`: a nickname for the training run
51 |   - `--dataset`: choose among `omniglot`, `miniimagenet`, `cifar`, `bird` (CUB), and `aircraft`. You can also add your own datasets.
52 |   - Checkpoints: specify the path to a pre-trained checkpoint (see the example command after this list)
53 |     - `--checkpoint`: load all the parameters (e.g. `train_dir/mmaml_5mode_5w1s/maml_gatedconv_60000.pt`).
54 |   - Hyperparameters
55 |     - `--num-batches`: number of batches
56 |     - `--meta-batch-size`: number of tasks per batch
57 |     - `--slow-lr`: learning rate for the global update of MAML
58 |     - `--fast-lr`: learning rate for the adapted models
59 |     - `--num-updates`: how many update steps in the inner loop
60 |     - `--num-classes-per-batch`: how many classes per task (`N`-way)
61 |     - `--num-samples-per-class`: how many samples per class for training (`K`-shot)
62 |     - `--num-val-samples`: how many samples per class for validation
63 |     - `--max_steps`: the maximum number of training iterations
64 |   - Logging
65 |     - `--log-interval`: number of batches between TensorBoard writes
66 |     - `--save-interval`: number of batches between model saves
67 |   - Model
68 |     - `--maml-model`: set to `True` to train a MAML model
69 |     - `--mmaml-model`: set to `True` to train an MMAML (our) model
70 |
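For example, to evaluate a trained MMAML model from a saved checkpoint (the checkpoint path is illustrative and depends on your `--output-folder` and save interval):

```bash
python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --eval --checkpoint train_dir/mmaml_5mode_5w1s/maml_gatedconv_60000.pt --output-folder mmaml_5mode_5w1s
```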
71 | ### Interpret TensorBoard
72 |
73 | Launch TensorBoard and navigate to the specified port; you can see the different accuracies and losses in the **scalars** tab.
74 |
75 | 
76 | ![TensorBoard scalars](asset/tb.png)
77 | 
78 | 
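Training logs are written under `./train_dir/<output-folder>/`, so a typical launch command is (port is arbitrary):

```bash
tensorboard --logdir ./train_dir --port 6006
```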
79 | You can reproduce our results with the following training commands.
80 |
81 | ### 2 Modes (Omniglot and Mini-ImageNet)
82 |
83 | | Setup | Method | Command |
84 | | :---: | :----: | ---------------------------------------- |
85 | | 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --output-folder maml_2mode_5w1s``` |
86 | | 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --output-folder mmaml_2mode_5w1s``` |
87 | | 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_2mode_5w5s``` |
88 | | 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_2mode_5w5s``` |
89 | | 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_2mode_20w1s``` |
90 | | 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_2mode_20w1s``` |
91 |
92 | ### 3 Modes (Omniglot, Mini-ImageNet, and FC100)
93 |
94 | | Setup | Method | Command |
95 | | :---: | :----: | ---------------------------------------- |
96 | | 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --output-folder maml_3mode_5w1s``` |
97 | | 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --output-folder mmaml_3mode_5w1s``` |
98 | | 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_3mode_5w5s``` |
99 | | 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_3mode_5w5s``` |
100 | | 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_3mode_20w1s``` |
101 | | 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_3mode_20w1s``` |
102 |
103 | ### 5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB)
104 |
105 | | Setup | Method | Command |
106 | | :---: | :----: | ---------------------------------------- |
107 | | 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --output-folder maml_5mode_5w1s``` |
108 | | 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s``` |
109 | | 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_5mode_5w5s``` |
110 | | 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_5mode_5w5s``` |
111 | | 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_5mode_20w1s``` |
112 | | 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_5mode_20w1s``` |
113 | 
114 | ### Multi-MAML
115 | 
116 | | Setup | Dataset | Command |
117 | | :---: | :-----------: | ---------------------------------------- |
118 | | 5w1s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-updates 1 --num-batches 600000 --output-folder multi_omniglot_5w1s``` |
119 | | 5w1s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_miniimagenet_5w1s``` |
120 | | 5w1s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_cifar_5w1s``` |
121 | | 5w1s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_bird_5w1s``` |
122 | | 5w1s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_aircraft_5w1s``` |
123 | | 5w5s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-updates 1 --num-batches 600000 --num-samples-per-class 5 --output-folder multi_omniglot_5w5s``` |
124 | | 5w5s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_miniimagenet_5w5s``` |
125 | | 5w5s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_cifar_5w5s``` |
126 | | 5w5s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_bird_5w5s``` |
127 | | 5w5s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_aircraft_5w5s``` |
128 | | 20w1s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.1 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_omniglot_20w1s``` |
129 | | 20w1s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_miniimagenet_20w1s``` |
130 | | 20w1s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_cifar_20w1s``` |
131 | | 20w1s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_bird_20w1s``` |
132 | | 20w1s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_aircraft_20w1s``` |
133 | 
134 | 
135 | ## Results
136 | 
137 | ### 2 Modes (Omniglot and Mini-ImageNet)
138 | 
139 | | Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot |
140 | | :----------: | :----------: | :----------: | :-----------: |
141 | | MAML | 66.80% | 77.79% | 44.69% |
142 | | Multi-MAML | 66.85% | 73.07% | 53.15% |
143 | | MMAML (Ours) | 69.93% | 78.73% | 47.80% |
144 | 
145 | ### 3 Modes (Omniglot, Mini-ImageNet, and FC100)
146 | 
147 | | Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot |
148 | | :----------: | :----------: | :----------: | :-----------: |
149 | | MAML | 54.55% | 67.97% | 28.22% |
150 | | Multi-MAML | 55.90% | 62.20% | 39.77% |
151 | | MMAML (Ours) | 57.47% | 70.15% | 36.27% |
152 | 
153 | ### 5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB)
154 | 
155 | | Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot |
156 | | :----------: | :----------: | :----------: | :-----------: |
157 | | MAML | 44.09% | 54.41% | 28.85% |
158 | | Multi-MAML | 45.46% | 55.92% | 33.78% |
159 | | MMAML (Ours) | 49.06% | 60.83% | 33.97% |
160 | 
161 | Please check out [our paper](https://arxiv.org/abs/1910.13616) for more comprehensive results.
162 | 
163 | ## Related work
164 | - \[MAML\] [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) in ICML 2017
165 | - [Probabilistic Model-Agnostic Meta-Learning](https://arxiv.org/abs/1806.02817) in NeurIPS 2018
166 | - [Bayesian Model-Agnostic Meta-Learning](https://arxiv.org/abs/1806.03836) in NeurIPS 2018
167 | - [Gradient-Based Meta-Learning with Learned Layerwise Metric and Subspace](https://arxiv.org/abs/1801.05558) in ICML 2018
168 | - [Reptile: A Scalable Meta-Learning Algorithm](https://openai.com/blog/reptile/)
169 | - [Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples](https://arxiv.org/abs/1903.03096) in Meta-Learning Workshop at NeurIPS 2018
170 | - [TADAM: Task dependent adaptive metric for improved few-shot learning](https://arxiv.org/abs/1805.10123) in NeurIPS 2018
171 | - [ProMP: Proximal Meta-Policy Search](https://arxiv.org/abs/1810.06784) in ICLR 2019
172 | 
173 | ## Cite the paper
174 | If you find this useful, please cite:
175 | ```
176 | @inproceedings{vuorio2019multimodal,
177 |   title={Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation},
178 |   author={Vuorio, Risto and Sun, Shao-Hua and Hu, Hexiang and Lim, Joseph J.},
179 |   booktitle={Neural Information Processing Systems},
180 |   year={2019},
181 | }
182 | ```
183 | 
184 | ## Authors
185 | [Shao-Hua Sun](http://shaohua0116.github.io/), [Risto Vuorio](https://vuoristo.github.io/), [Hexiang Hu](http://hexianghu.com/)
186 | 
--------------------------------------------------------------------------------
/asset/classification_summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/asset/classification_summary.png
--------------------------------------------------------------------------------
/asset/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/asset/model.png
--------------------------------------------------------------------------------
/asset/tb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/asset/tb.png
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import argparse
3 |
4 |
5 | parser = argparse.ArgumentParser(description='Download datasets for MMAML.')
6 | parser.add_argument('--dataset', metavar='N', type=str, nargs='+',
7 | choices=['aircraft', 'bird', 'cifar', 'miniimagenet'])
8 |
9 |
10 | def download(dataset):
11 | cmd = ['python', 'get_dataset_script/get_{}.py'.format(dataset)]
12 | print(' '.join(cmd))
13 | subprocess.call(cmd)
14 | return
15 |
16 |
17 | if __name__ == '__main__':
18 | args = parser.parse_args()
19 |     if args.dataset:  # argparse leaves this as None when --dataset is omitted
20 | for dataset in args.dataset:
21 | download(dataset)
22 |
--------------------------------------------------------------------------------
/get_dataset_script/download_miniimagenet.py:
--------------------------------------------------------------------------------
1 | # https://drive.google.com/drive/folders/0B-r7apOz1BHAWXYwT1lGb3J1Yjg
2 | # Download files on Google Drive
3 | import requests
4 |
5 | def download_file_from_google_drive(file_id, destination):
6 |     URL = "https://drive.google.com/uc?export=download"
7 | 
8 |     session = requests.Session()
9 | 
10 |     response = session.get(URL, params={'id': file_id}, stream=True)
11 |     token = get_confirm_token(response)
12 | 
13 |     if token:
14 |         params = {'id': file_id, 'confirm': token}
15 |         response = session.get(URL, params=params, stream=True)
16 | 
17 |     save_response_content(response, destination)
18 |
19 | def get_confirm_token(response):
20 | for key, value in response.cookies.items():
21 | if key.startswith('download_warning'):
22 | return value
23 |
24 | return None
25 |
26 | def save_response_content(response, destination):
27 | CHUNK_SIZE = 32768
28 |
29 | with open(destination, "wb") as f:
30 | for chunk in response.iter_content(CHUNK_SIZE):
31 | if chunk: # filter out keep-alive new chunks
32 | f.write(chunk)
33 |
34 | if __name__ == "__main__":
35 | file_id = "1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk"
36 | destination = './mini-imagenet.zip'
37 | download_file_from_google_drive(file_id, destination)
38 |
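If the cookie-based confirmation above ever breaks, the third-party `gdown` package is an alternative; a minimal sketch, assuming `pip install gdown` (not a dependency of this repo):

```python
# alternative download via the gdown package (assumption: pip install gdown)
import gdown

url = 'https://drive.google.com/uc?id=1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk'
gdown.download(url, './mini-imagenet.zip', quiet=False)
```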
--------------------------------------------------------------------------------
/get_dataset_script/get_aircraft.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 |
4 | cmds = []
5 | cmds.append(['wget', 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'])
6 | cmds.append(['tar', 'xvzf', 'fgvc-aircraft-2013b.tar.gz'])
7 | cmds.append(['python', 'get_dataset_script/proc_aircraft.py'])
8 | cmds.append(['rm', '-rf', 'fgvc-aircraft-2013b.tar.gz', 'fgvc-aircraft-2013b'])
9 |
10 | for cmd in cmds:
11 | print(' '.join(cmd))
12 | subprocess.call(cmd)
13 |
--------------------------------------------------------------------------------
/get_dataset_script/get_bird.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 |
4 |
5 | cmds = []
6 | cmds.append(['wget', 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'])
7 | cmds.append(['tar', 'xvzf', 'CUB_200_2011.tgz'])
8 | cmds.append(['python', 'get_dataset_script/proc_bird.py'])
9 | cmds.append(['rm', '-rf', 'CUB_200_2011', 'CUB_200_2011.tgz'])
10 |
11 | for cmd in cmds:
12 | print(' '.join(cmd))
13 | subprocess.call(cmd)
14 |
--------------------------------------------------------------------------------
/get_dataset_script/get_cifar.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 |
4 | cmds = []
5 | cmds.append(['wget', 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'])
6 | cmds.append(['tar', '-xzvf', 'cifar-100-python.tar.gz'])
7 | # make sure ./data exists so the extracted folder is moved *into* it
8 | cmds.append(['mkdir', '-p', 'data'])
9 | cmds.append(['mv', 'cifar-100-python', './data'])
10 | cmds.append(['python', 'get_dataset_script/proc_cifar.py'])
11 | cmds.append(['rm', '-rf', 'cifar-100-python.tar.gz'])
12 | 
13 | for cmd in cmds:
14 |     print(' '.join(cmd))
15 |     subprocess.call(cmd)
16 | 
--------------------------------------------------------------------------------
/get_dataset_script/get_miniimagenet.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 |
4 | cmds = []
5 | cmds.append(['python', 'get_dataset_script/download_miniimagenet.py'])
6 | cmds.append(['unzip', 'mini-imagenet.zip'])
7 | cmds.append(['rm', '-rf', 'mini-imagenet.zip'])
8 | cmds.append(['mkdir', 'miniimagenet'])
9 | cmds.append(['mv', 'images', 'miniimagenet'])
10 | cmds.append(['python', 'get_dataset_script/proc_miniimagenet.py'])
11 | # make sure ./data exists so the processed folder is moved *into* it
12 | cmds.append(['mkdir', '-p', 'data'])
13 | cmds.append(['mv', 'miniimagenet', './data'])
14 | 
15 | for cmd in cmds:
16 |     print(' '.join(cmd))
17 |     subprocess.call(cmd)
18 | 
--------------------------------------------------------------------------------
/get_dataset_script/proc_aircraft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import subprocess
4 |
5 |
6 | source_dir = './fgvc-aircraft-2013b/data'
7 | target_dir = './data/aircraft'
8 |
9 | percentage_train_class = 70
10 | percentage_val_class = 15
11 | percentage_test_class = 15
12 | train_val_test_ratio = [
13 | percentage_train_class, percentage_val_class, percentage_test_class]
14 |
15 | with open(os.path.join(source_dir, 'variants.txt')) as f:
16 | lines = f.readlines()
17 | classes = [line.strip() for line in lines]
18 |
19 | rs = np.random.RandomState(123)
20 | rs.shuffle(classes)
21 | num_train, num_val, num_test = [
22 | int(float(ratio)/np.sum(train_val_test_ratio)*len(classes))
23 | for ratio in train_val_test_ratio]
24 |
25 | classes = {
26 | 'train': classes[:num_train],
27 | 'val': classes[num_train:num_train+num_val],
28 | 'test': classes[num_train+num_val:]
29 | }
30 |
31 | if not os.path.exists(target_dir):
32 | os.makedirs(target_dir)
33 | for k in classes.keys():
34 | target_dir_k = os.path.join(target_dir, k)
35 | if not os.path.exists(target_dir_k):
36 | os.makedirs(target_dir_k)
37 | for c in classes[k]:
38 | c = c.replace('/', '-')
39 | target_dir_k_c = os.path.join(target_dir_k, c)
40 | if not os.path.exists(target_dir_k_c):
41 | os.makedirs(target_dir_k_c)
42 |
43 | lines = []
44 | with open(os.path.join(source_dir, 'images_variant_trainval.txt')) as f:
45 | lines += f.readlines()
46 | with open(os.path.join(source_dir, 'images_variant_test.txt')) as f:
47 | lines += f.readlines()
48 | lines = [line.strip() for line in lines]
49 |
50 | for i, line in enumerate(lines):
51 | image_num, image_class = line.split(' ', 1)
52 | image_class = image_class.replace('/', '-')
53 |     image_k = list(classes.keys())[np.argmax([image_class in classes[k] for k in list(classes.keys())])]  # pick the split whose class list contains this variant
54 | image_source_path = os.path.join(source_dir, 'images', '{}.jpg'.format(image_num))
55 | image_target_path = os.path.join(target_dir, image_k, image_class)
56 | cmd = ['mv', image_source_path, image_target_path]
57 | subprocess.call(cmd)
58 |     print('{}/{} {}'.format(i + 1, len(lines), ' '.join(cmd)))
59 |
60 | # resize images
61 | cmd = ['python', 'get_dataset_script/resize_dataset.py', './data/aircraft']
62 | subprocess.call(cmd)
63 |
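After the script finishes, the per-split class counts can be sanity-checked with a few lines (a sketch; FGVC-Aircraft has 100 variants, so roughly a 70/15/15 class split is expected):

```python
import os

root = './data/aircraft'
for split in ('train', 'val', 'test'):
    num_classes = len(os.listdir(os.path.join(root, split)))
    print('{}: {} classes'.format(split, num_classes))
```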
--------------------------------------------------------------------------------
/get_dataset_script/proc_bird.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import subprocess
4 | from imageio import imread, imwrite
5 |
6 |
7 | source_dir = './CUB_200_2011/images'
8 | target_dir = './data/bird'
9 |
10 | percentage_train_class = 70
11 | percentage_val_class = 15
12 | percentage_test_class = 15
13 | train_val_test_ratio = [
14 | percentage_train_class, percentage_val_class, percentage_test_class]
15 |
16 | classes = sorted(os.listdir(source_dir))  # sort so the seeded shuffle gives a reproducible split
17 |
18 | rs = np.random.RandomState(123)
19 | rs.shuffle(classes)
20 | num_train, num_val, num_test = [
21 | int(float(ratio)/np.sum(train_val_test_ratio)*len(classes))
22 | for ratio in train_val_test_ratio]
23 |
24 | classes = {
25 | 'train': classes[:num_train],
26 | 'val': classes[num_train:num_train+num_val],
27 | 'test': classes[num_train+num_val:]
28 | }
29 |
30 | if not os.path.exists(target_dir):
31 | os.makedirs(target_dir)
32 | for k in classes.keys():
33 | target_dir_k = os.path.join(target_dir, k)
34 | if not os.path.exists(target_dir_k):
35 | os.makedirs(target_dir_k)
36 | cmd = ['mv'] + [os.path.join(source_dir, c) for c in classes[k]] + [target_dir_k]
37 | subprocess.call(cmd)
38 |
39 | _ids = []
40 |
41 | for root, dirnames, filenames in os.walk(target_dir):
42 | for filename in filenames:
43 |         if filename.endswith(('.jpg', '.webp', '.JPEG', '.png', '.jpeg')):
44 | _ids.append(os.path.join(root, filename))
45 |
46 | for path in _ids:
47 |     try:
48 |         img = imread(path)
49 |     except Exception:
50 |         print('Failed to read {}'.format(path))
51 |         continue
52 |     if len(img.shape) < 3:  # grayscale: tile to 3 channels
53 |         print(path)
54 |         img = np.tile(np.expand_dims(img, axis=-1), [1, 1, 3])
55 |         imwrite(path, img)
56 |     elif img.shape[-1] == 1:
57 |         print(path)
58 |         img = np.tile(img, [1, 1, 3])
59 |         imwrite(path, img)
60 |
61 | # resize images
62 | cmd = ['python', 'get_dataset_script/resize_dataset.py', './data/bird']
63 | subprocess.call(cmd)
64 |
--------------------------------------------------------------------------------
/get_dataset_script/proc_cifar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pickle
4 | import os
5 |
6 | img_size = 32
7 | classes = {  # FC100-style split of CIFAR-100 coarse labels (superclasses)
8 | 'train': [1, 2, 3, 4, 5, 6, 9, 10, 15, 17, 18, 19],
9 | 'val': [8, 11, 13, 16],
10 | 'test': [0, 7, 12, 14]
11 | }
12 |
13 | def _get_file_path(filename=""):
14 | return os.path.join('./data', "cifar-100-python/", filename)
15 |
16 | def _unpickle(filename):
17 | """
18 | Unpickle the given file and return the data.
19 | Note that the appropriate dir-name is prepended the filename.
20 | """
21 |
22 | # Create full path for the file.
23 | file_path = _get_file_path(filename)
24 |
25 | print("Loading data: " + file_path)
26 |
27 | with open(file_path, mode='rb') as file:
28 | # In Python 3.X it is important to set the encoding,
29 | # otherwise an exception is raised here.
30 | data = pickle.load(file, encoding='latin1')
31 |
32 | return data
33 |
34 | meta = _unpickle('meta')
35 | train = _unpickle('train')
36 | test = _unpickle('test')
37 | 
38 | data = np.concatenate([train['data'], test['data']])
39 | labels = np.array(train['fine_labels'] + test['fine_labels'])
40 | filts = np.array(train['coarse_labels'] + test['coarse_labels'])
41 | 
42 | cifar_data = {}
43 | cifar_label = {}
44 | for k, v in classes.items():
45 |     data_filter = np.zeros_like(filts)
46 |     for i in v:
47 |         data_filter += (filts == i)
48 |     assert data_filter.max() == 1
49 | 
50 |     cifar_data[k] = data[data_filter == 1]
51 |     cifar_label[k] = labels[data_filter == 1]
52 | 
53 | torch.save({'data': cifar_data, 'label': cifar_label}, './data/cifar100.pth')
54 | 
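The resulting file can be inspected to verify the splits (a sketch; each row is a flat 3x32x32 CIFAR image):

```python
import torch

cifar = torch.load('./data/cifar100.pth')
for split in ('train', 'val', 'test'):
    data, labels = cifar['data'][split], cifar['label'][split]
    print(split, data.shape, '{} fine classes'.format(len(set(labels.tolist()))))
```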
--------------------------------------------------------------------------------
/get_dataset_script/proc_miniimagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for converting from csv file datafiles to a directory for each image (which is how it is loaded by MAML code)
3 |
4 | Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the
5 | csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'.
6 | Then run this script from the miniImagenet directory:
7 | cd data/miniImagenet/
8 | python proc_images.py
9 | """
10 |
11 | from __future__ import print_function
12 | import csv
13 | import glob
14 | import os
15 | from tqdm import tqdm
16 |
17 | from PIL import Image
18 |
19 | path_to_images = 'miniimagenet/images/'
20 |
21 | all_images = glob.glob(path_to_images + '*')
22 |
23 | # Resize images
24 | for i, image_file in enumerate(tqdm(all_images)):
25 | im = Image.open(image_file)
26 |     if im.size != (84, 84):
27 | im = im.resize((84, 84), resample=Image.LANCZOS)
28 | im.save(image_file)
29 |
30 | # Put in correct directory
31 | for datatype in ['train', 'val', 'test']:
32 | print('Processing {} data'.format(datatype))
33 | os.system('mkdir miniimagenet/' + datatype)
34 |
35 | with open(os.path.join('miniimagenet-data', datatype + '.csv'), 'r') as f:
36 | reader = csv.reader(f, delimiter=',')
37 | last_label = ''
38 | for i, row in enumerate(reader):
39 | if i == 0: # skip the headers
40 | continue
41 | label = row[1]
42 | image_name = row[0]
43 | if label != last_label:
44 | cur_dir = 'miniimagenet/' + datatype + '/' + label + '/'
45 | os.system('mkdir ' + cur_dir)
46 | last_label = label
47 | os.system('mv miniimagenet/images/' + image_name + ' ' + cur_dir)
48 |
--------------------------------------------------------------------------------
/get_dataset_script/resize_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from imageio import imread, imwrite
4 | from skimage import img_as_ubyte
5 | from skimage.transform import resize
6 | 
7 | target_dir = './data/aircraft' if len(sys.argv) < 2 else sys.argv[1]
8 | img_size = [84, 84]
9 | 
10 | _ids = []
11 | 
12 | for root, dirnames, filenames in os.walk(target_dir):
13 |     for filename in filenames:
14 |         if filename.endswith(('.jpg', '.webp', '.JPEG', '.png', '.jpeg')):
15 |             _ids.append(os.path.join(root, filename))
16 | 
17 | for i, path in enumerate(_ids):
18 |     img = imread(path)
19 |     print('{}/{} size: {}'.format(i + 1, len(_ids), img.shape))
20 |     # resize() returns floats in [0, 1]; convert back to uint8 before writing
21 |     imwrite(path, img_as_ubyte(resize(img, img_size)))
22 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | import torch
6 | import numpy as np
7 | from tensorboardX import SummaryWriter
8 |
9 | from maml.datasets.omniglot import OmniglotMetaDataset
10 | from maml.datasets.miniimagenet import MiniimagenetMetaDataset
11 | from maml.datasets.cifar100 import Cifar100MetaDataset
12 | from maml.datasets.bird import BirdMetaDataset
13 | from maml.datasets.aircraft import AircraftMetaDataset
14 | from maml.datasets.multimodal_few_shot import MultimodalFewShotDataset
15 | from maml.models.fully_connected import FullyConnectedModel, MultiFullyConnectedModel
16 | from maml.models.conv_net import ConvModel
17 | from maml.models.gated_conv_net import GatedConvModel
18 | from maml.models.gated_net import GatedNet
19 | from maml.models.simple_embedding_model import SimpleEmbeddingModel
20 | from maml.models.lstm_embedding_model import LSTMEmbeddingModel
21 | from maml.models.gru_embedding_model import GRUEmbeddingModel
22 | from maml.models.conv_embedding_model import ConvEmbeddingModel
23 | from maml.metalearner import MetaLearner
24 | from maml.trainer import Trainer
25 | from maml.utils import optimizer_to_device, get_git_revision_hash
26 |
27 |
28 | def main(args):
29 | is_training = not args.eval
30 | run_name = 'train' if is_training else 'eval'
31 |
32 | if is_training:
33 | writer = SummaryWriter('./train_dir/{0}/{1}'.format(
34 | args.output_folder, run_name))
35 | with open('./train_dir/{}/config.txt'.format(
36 | args.output_folder), 'w') as config_txt:
37 | for k, v in sorted(vars(args).items()):
38 | config_txt.write('{}: {}\n'.format(k, v))
39 | else:
40 | writer = None
41 |
42 | save_folder = './train_dir/{0}'.format(args.output_folder)
43 | if not os.path.exists(save_folder):
44 | os.makedirs(save_folder)
45 |
46 | config_name = '{0}_config.json'.format(run_name)
47 | with open(os.path.join(save_folder, config_name), 'w') as f:
48 | config = {k: v for (k, v) in vars(args).items() if k != 'device'}
49 | config.update(device=args.device.type)
50 | try:
51 | config.update({'git_hash': get_git_revision_hash()})
52 |         except Exception:
53 | pass
54 | json.dump(config, f, indent=2)
55 |
56 | _num_tasks = 1
57 | if args.dataset == 'omniglot':
58 | dataset = OmniglotMetaDataset(
59 | root='data',
60 | img_side_len=28, # args.img_side_len,
61 | num_classes_per_batch=args.num_classes_per_batch,
62 | num_samples_per_class=args.num_samples_per_class,
63 | num_total_batches=args.num_batches,
64 | num_val_samples=args.num_val_samples,
65 | meta_batch_size=args.meta_batch_size,
66 | train=is_training,
67 | num_train_classes=args.num_train_classes,
68 | num_workers=args.num_workers,
69 | device=args.device)
70 | loss_func = torch.nn.CrossEntropyLoss()
71 | collect_accuracies = True
72 | elif args.dataset == 'cifar':
73 | dataset = Cifar100MetaDataset(
74 | root='data',
75 | img_side_len=32,
76 | num_classes_per_batch=args.num_classes_per_batch,
77 | num_samples_per_class=args.num_samples_per_class,
78 | num_total_batches=args.num_batches,
79 | num_val_samples=args.num_val_samples,
80 | meta_batch_size=args.meta_batch_size,
81 | train=is_training,
82 | num_train_classes=args.num_train_classes,
83 | num_workers=args.num_workers,
84 | device=args.device)
85 | loss_func = torch.nn.CrossEntropyLoss()
86 | collect_accuracies = True
87 | elif args.dataset == 'multimodal_few_shot':
88 | dataset_list = []
89 | if 'omniglot' in args.multimodal_few_shot:
90 | dataset_list.append(OmniglotMetaDataset(
91 | root='data',
92 | img_side_len=args.common_img_side_len,
93 | img_channel=args.common_img_channel,
94 | num_classes_per_batch=args.num_classes_per_batch,
95 | num_samples_per_class=args.num_samples_per_class,
96 | num_total_batches=args.num_batches,
97 | num_val_samples=args.num_val_samples,
98 | meta_batch_size=args.meta_batch_size,
99 | train=is_training,
100 | num_train_classes=args.num_train_classes,
101 | num_workers=args.num_workers,
102 | device=args.device)
103 | )
104 | if 'miniimagenet' in args.multimodal_few_shot:
105 | dataset_list.append( MiniimagenetMetaDataset(
106 | root='data',
107 | img_side_len=args.common_img_side_len,
108 | img_channel=args.common_img_channel,
109 | num_classes_per_batch=args.num_classes_per_batch,
110 | num_samples_per_class=args.num_samples_per_class,
111 | num_total_batches=args.num_batches,
112 | num_val_samples=args.num_val_samples,
113 | meta_batch_size=args.meta_batch_size,
114 | train=is_training,
115 | num_train_classes=args.num_train_classes,
116 | num_workers=args.num_workers,
117 | device=args.device)
118 | )
119 | if 'cifar' in args.multimodal_few_shot:
120 | dataset_list.append(Cifar100MetaDataset(
121 | root='data',
122 | img_side_len=args.common_img_side_len,
123 | img_channel=args.common_img_channel,
124 | num_classes_per_batch=args.num_classes_per_batch,
125 | num_samples_per_class=args.num_samples_per_class,
126 | num_total_batches=args.num_batches,
127 | num_val_samples=args.num_val_samples,
128 | meta_batch_size=args.meta_batch_size,
129 | train=is_training,
130 | num_train_classes=args.num_train_classes,
131 | num_workers=args.num_workers,
132 | device=args.device)
133 | )
134 |         if 'doublemnist' in args.multimodal_few_shot:  # NOTE: DoubleMNISTMetaDataset is not imported/defined in this repo
135 | dataset_list.append( DoubleMNISTMetaDataset(
136 | root='data',
137 | img_side_len=args.common_img_side_len,
138 | img_channel=args.common_img_channel,
139 | num_classes_per_batch=args.num_classes_per_batch,
140 | num_samples_per_class=args.num_samples_per_class,
141 | num_total_batches=args.num_batches,
142 | num_val_samples=args.num_val_samples,
143 | meta_batch_size=args.meta_batch_size,
144 | train=is_training,
145 | num_train_classes=args.num_train_classes,
146 | num_workers=args.num_workers,
147 | device=args.device)
148 | )
149 |         if 'triplemnist' in args.multimodal_few_shot:  # NOTE: TripleMNISTMetaDataset is not imported/defined in this repo
150 | dataset_list.append( TripleMNISTMetaDataset(
151 | root='data',
152 | img_side_len=args.common_img_side_len,
153 | img_channel=args.common_img_channel,
154 | num_classes_per_batch=args.num_classes_per_batch,
155 | num_samples_per_class=args.num_samples_per_class,
156 | num_total_batches=args.num_batches,
157 | num_val_samples=args.num_val_samples,
158 | meta_batch_size=args.meta_batch_size,
159 | train=is_training,
160 | num_train_classes=args.num_train_classes,
161 | num_workers=args.num_workers,
162 | device=args.device)
163 | )
164 | if 'bird' in args.multimodal_few_shot:
165 | dataset_list.append( BirdMetaDataset(
166 | root='data',
167 | img_side_len=args.common_img_side_len,
168 | img_channel=args.common_img_channel,
169 | num_classes_per_batch=args.num_classes_per_batch,
170 | num_samples_per_class=args.num_samples_per_class,
171 | num_total_batches=args.num_batches,
172 | num_val_samples=args.num_val_samples,
173 | meta_batch_size=args.meta_batch_size,
174 | train=is_training,
175 | num_train_classes=args.num_train_classes,
176 | num_workers=args.num_workers,
177 | device=args.device)
178 | )
179 | if 'aircraft' in args.multimodal_few_shot:
180 | dataset_list.append( AircraftMetaDataset(
181 | root='data',
182 | img_side_len=args.common_img_side_len,
183 | img_channel=args.common_img_channel,
184 | num_classes_per_batch=args.num_classes_per_batch,
185 | num_samples_per_class=args.num_samples_per_class,
186 | num_total_batches=args.num_batches,
187 | num_val_samples=args.num_val_samples,
188 | meta_batch_size=args.meta_batch_size,
189 | train=is_training,
190 | num_train_classes=args.num_train_classes,
191 | num_workers=args.num_workers,
192 | device=args.device)
193 | )
194 | assert len(dataset_list) > 0
195 | print('Multimodal Few Shot Datasets: {}'.format(
196 | ' '.join([dataset.name for dataset in dataset_list])))
197 | dataset = MultimodalFewShotDataset(
198 | dataset_list,
199 | num_total_batches=args.num_batches,
200 | mix_meta_batch=args.mix_meta_batch,
201 | mix_mini_batch=args.mix_mini_batch,
202 | txt_file=args.sample_embedding_file+'.txt' if args.num_sample_embedding > 0 else None,
203 | train=is_training,
204 | )
205 | loss_func = torch.nn.CrossEntropyLoss()
206 | collect_accuracies = True
207 |     elif args.dataset == 'doublemnist':  # NOTE: DoubleMNISTMetaDataset is not imported/defined in this repo
208 | dataset = DoubleMNISTMetaDataset(
209 | root='data',
210 | img_side_len=64,
211 | num_classes_per_batch=args.num_classes_per_batch,
212 | num_samples_per_class=args.num_samples_per_class,
213 | num_total_batches=args.num_batches,
214 | num_val_samples=args.num_val_samples,
215 | meta_batch_size=args.meta_batch_size,
216 | train=is_training,
217 | num_train_classes=args.num_train_classes,
218 | num_workers=args.num_workers,
219 | device=args.device)
220 | loss_func = torch.nn.CrossEntropyLoss()
221 | collect_accuracies = True
222 |     elif args.dataset == 'triplemnist':  # NOTE: TripleMNISTMetaDataset is not imported/defined in this repo
223 | dataset = TripleMNISTMetaDataset(
224 | root='data',
225 | img_side_len=84,
226 | num_classes_per_batch=args.num_classes_per_batch,
227 | num_samples_per_class=args.num_samples_per_class,
228 | num_total_batches=args.num_batches,
229 | num_val_samples=args.num_val_samples,
230 | meta_batch_size=args.meta_batch_size,
231 | train=is_training,
232 | num_train_classes=args.num_train_classes,
233 | num_workers=args.num_workers,
234 | device=args.device)
235 | loss_func = torch.nn.CrossEntropyLoss()
236 | collect_accuracies = True
237 | elif args.dataset == 'miniimagenet':
238 | dataset = MiniimagenetMetaDataset(
239 | root='data',
240 | img_side_len=84,
241 | num_classes_per_batch=args.num_classes_per_batch,
242 | num_samples_per_class=args.num_samples_per_class,
243 | num_total_batches=args.num_batches,
244 | num_val_samples=args.num_val_samples,
245 | meta_batch_size=args.meta_batch_size,
246 | train=is_training,
247 | num_train_classes=args.num_train_classes,
248 | num_workers=args.num_workers,
249 | device=args.device)
250 | loss_func = torch.nn.CrossEntropyLoss()
251 | collect_accuracies = True
252 |     elif args.dataset == 'sinusoid':  # NOTE: the regression datasets below are not imported/defined in this repo
253 | dataset = SinusoidMetaDataset(
254 | num_total_batches=args.num_batches,
255 | num_samples_per_function=args.num_samples_per_class,
256 | num_val_samples=args.num_val_samples,
257 | meta_batch_size=args.meta_batch_size,
258 | amp_range=args.amp_range,
259 | phase_range=args.phase_range,
260 | input_range=args.input_range,
261 | oracle=args.oracle,
262 | train=is_training,
263 | device=args.device)
264 | loss_func = torch.nn.MSELoss()
265 | collect_accuracies = False
266 |     elif args.dataset == 'linear':  # NOTE: LinearMetaDataset is not imported/defined in this repo
267 | dataset = LinearMetaDataset(
268 | num_total_batches=args.num_batches,
269 | num_samples_per_function=args.num_samples_per_class,
270 | num_val_samples=args.num_val_samples,
271 | meta_batch_size=args.meta_batch_size,
272 | slope_range=args.slope_range,
273 | intersect_range=args.intersect_range,
274 | input_range=args.input_range,
275 | oracle=args.oracle,
276 | train=is_training,
277 | device=args.device)
278 | loss_func = torch.nn.MSELoss()
279 | collect_accuracies = False
280 |     elif args.dataset == 'mixed':  # NOTE: MixedFunctionsMetaDataset is not imported/defined in this repo
281 | dataset = MixedFunctionsMetaDataset(
282 | num_total_batches=args.num_batches,
283 | num_samples_per_function=args.num_samples_per_class,
284 | num_val_samples=args.num_val_samples,
285 | meta_batch_size=args.meta_batch_size,
286 | amp_range=args.amp_range,
287 | phase_range=args.phase_range,
288 | slope_range=args.slope_range,
289 | intersect_range=args.intersect_range,
290 | input_range=args.input_range,
291 | noise_std=args.noise_std,
292 | oracle=args.oracle,
293 | task_oracle=args.task_oracle,
294 | train=is_training,
295 | device=args.device)
296 | loss_func = torch.nn.MSELoss()
297 | collect_accuracies = False
298 | _num_tasks=2
299 |     elif args.dataset == 'many':  # NOTE: ManyFunctionsMetaDataset is not imported/defined in this repo
300 | dataset = ManyFunctionsMetaDataset(
301 | num_total_batches=args.num_batches,
302 | num_samples_per_function=args.num_samples_per_class,
303 | num_val_samples=args.num_val_samples,
304 | meta_batch_size=args.meta_batch_size,
305 | amp_range=args.amp_range,
306 | phase_range=args.phase_range,
307 | slope_range=args.slope_range,
308 | intersect_range=args.intersect_range,
309 | input_range=args.input_range,
310 | noise_std=args.noise_std,
311 | oracle=args.oracle,
312 | task_oracle=args.task_oracle,
313 | train=is_training,
314 | device=args.device)
315 | loss_func = torch.nn.MSELoss()
316 | collect_accuracies = False
317 | _num_tasks=3
318 |     elif args.dataset == 'multisinusoids':  # NOTE: MultiSinusoidsMetaDataset is not imported/defined in this repo
319 | dataset = MultiSinusoidsMetaDataset(
320 | num_total_batches=args.num_batches,
321 | num_samples_per_function=args.num_samples_per_class,
322 | num_val_samples=args.num_val_samples,
323 | meta_batch_size=args.meta_batch_size,
324 | amp_range=args.amp_range,
325 | phase_range=args.phase_range,
326 | slope_range=args.slope_range,
327 | intersect_range=args.intersect_range,
328 | input_range=args.input_range,
329 | noise_std=args.noise_std,
330 | oracle=args.oracle,
331 | task_oracle=args.task_oracle,
332 | train=is_training,
333 | device=args.device)
334 | loss_func = torch.nn.MSELoss()
335 | collect_accuracies = False
336 | else:
337 | raise ValueError('Unrecognized dataset {}'.format(args.dataset))
338 |
339 | embedding_model = None
340 |
341 | if args.model_type == 'fc':
342 | model = FullyConnectedModel(
343 | input_size=np.prod(dataset.input_size),
344 | output_size=dataset.output_size,
345 | hidden_sizes=args.hidden_sizes,
346 | disable_norm=args.disable_norm,
347 | bias_transformation_size=args.bias_transformation_size)
348 | elif args.model_type == 'multi':
349 | model = MultiFullyConnectedModel(
350 | input_size=np.prod(dataset.input_size),
351 | output_size=dataset.output_size,
352 | hidden_sizes=args.hidden_sizes,
353 | disable_norm=args.disable_norm,
354 | num_tasks=_num_tasks,
355 | bias_transformation_size=args.bias_transformation_size)
356 | elif args.model_type == 'conv':
357 | model = ConvModel(
358 | input_channels=dataset.input_size[0],
359 | output_size=dataset.output_size,
360 | num_channels=args.num_channels,
361 | img_side_len=dataset.input_size[1],
362 | use_max_pool=args.use_max_pool,
363 | verbose=args.verbose)
364 | elif args.model_type == 'gatedconv':
365 | model = GatedConvModel(
366 | input_channels=dataset.input_size[0],
367 | output_size=dataset.output_size,
368 | use_max_pool=args.use_max_pool,
369 | num_channels=args.num_channels,
370 | img_side_len=dataset.input_size[1],
371 | condition_type=args.condition_type,
372 | condition_order=args.condition_order,
373 | verbose=args.verbose)
374 | elif args.model_type == 'gated':
375 | model = GatedNet(
376 | input_size=np.prod(dataset.input_size),
377 | output_size=dataset.output_size,
378 | hidden_sizes=args.hidden_sizes,
379 | condition_type=args.condition_type,
380 | condition_order=args.condition_order)
381 | else:
382 | raise ValueError('Unrecognized model type {}'.format(args.model_type))
383 | model_parameters = list(model.parameters())
384 |
385 | if args.embedding_type == '':
386 | embedding_model = None
387 | elif args.embedding_type == 'simple':
388 | embedding_model = SimpleEmbeddingModel(
389 | num_embeddings=dataset.num_tasks,
390 | embedding_dims=args.embedding_dims)
391 | embedding_parameters = list(embedding_model.parameters())
392 | elif args.embedding_type == 'GRU':
393 | embedding_model = GRUEmbeddingModel(
394 | input_size=np.prod(dataset.input_size),
395 | output_size=dataset.output_size,
396 | embedding_dims=args.embedding_dims,
397 | hidden_size=args.embedding_hidden_size,
398 | num_layers=args.embedding_num_layers)
399 | embedding_parameters = list(embedding_model.parameters())
400 | elif args.embedding_type == 'LSTM':
401 | embedding_model = LSTMEmbeddingModel(
402 | input_size=np.prod(dataset.input_size),
403 | output_size=dataset.output_size,
404 | embedding_dims=args.embedding_dims,
405 | hidden_size=args.embedding_hidden_size,
406 | num_layers=args.embedding_num_layers)
407 | embedding_parameters = list(embedding_model.parameters())
408 | elif args.embedding_type == 'ConvGRU':
409 | embedding_model = ConvEmbeddingModel(
410 | input_size=np.prod(dataset.input_size),
411 | output_size=dataset.output_size,
412 | embedding_dims=args.embedding_dims,
413 | hidden_size=args.embedding_hidden_size,
414 | num_layers=args.embedding_num_layers,
415 | convolutional=args.conv_embedding,
416 | num_conv=args.num_conv_embedding_layer,
417 | num_channels=args.num_channels,
418 | rnn_aggregation=(not args.no_rnn_aggregation),
419 | embedding_pooling=args.embedding_pooling,
420 | batch_norm=args.conv_embedding_batch_norm,
421 | avgpool_after_conv=args.conv_embedding_avgpool_after_conv,
422 | linear_before_rnn=args.linear_before_rnn,
423 | num_sample_embedding=args.num_sample_embedding,
424 | sample_embedding_file=args.sample_embedding_file+'.'+args.sample_embedding_file_type,
425 | img_size=dataset.input_size,
426 | verbose=args.verbose)
427 | embedding_parameters = list(embedding_model.parameters())
428 | else:
429 | raise ValueError('Unrecognized embedding type {}'.format(
430 | args.embedding_type))
431 |
432 | optimizers = None
433 | if embedding_model:
434 | optimizers = ( torch.optim.Adam(model_parameters, lr=args.slow_lr),
435 | torch.optim.Adam(embedding_parameters, lr=args.slow_lr) )
436 | else:
437 | optimizers = ( torch.optim.Adam(model_parameters, lr=args.slow_lr), )
438 |
439 | if args.checkpoint != '':
440 | checkpoint = torch.load(args.checkpoint)
441 | model.load_state_dict(checkpoint['model_state_dict'])
442 | model.to(args.device)
443 |         # original behavior preserved: the optimizer state is restored only
444 |         # when the checkpoint has no 'optimizer' key
445 |         if 'optimizer' not in checkpoint:
446 |             optimizers[0].load_state_dict(checkpoint['optimizers'][0])
447 |             optimizer_to_device(optimizers[0], args.device)
448 |
449 | if embedding_model:
450 | embedding_model.load_state_dict(
451 | checkpoint['embedding_model_state_dict'])
452 | optimizers[1].load_state_dict(checkpoint['optimizers'][1])
453 | optimizer_to_device(optimizers[1], args.device)
454 |
455 | meta_learner = MetaLearner(
456 | model, embedding_model, optimizers, fast_lr=args.fast_lr,
457 | loss_func=loss_func, first_order=args.first_order,
458 | num_updates=args.num_updates,
459 | inner_loop_grad_clip=args.inner_loop_grad_clip,
460 | collect_accuracies=collect_accuracies, device=args.device,
461 | alternating=args.alternating, embedding_schedule=args.embedding_schedule,
462 | classifier_schedule=args.classifier_schedule, embedding_grad_clip=args.embedding_grad_clip)
463 |
464 | trainer = Trainer(
465 | meta_learner=meta_learner, meta_dataset=dataset, writer=writer,
466 | log_interval=args.log_interval, save_interval=args.save_interval,
467 | model_type=args.model_type, save_folder=save_folder,
468 | total_iter=args.num_batches//args.meta_batch_size
469 | )
470 |
471 | if is_training:
472 | trainer.train()
473 | else:
474 | trainer.eval()
475 |
476 |
477 | if __name__ == '__main__':
478 |
479 | def str2bool(arg):
480 |         return arg.lower() == 'true'  # anything other than 'true' is treated as False
481 |
482 | parser = argparse.ArgumentParser(
483 | description='Model-Agnostic Meta-Learning (MAML)')
484 |
485 | parser.add_argument('--mmaml-model', type=str2bool, default=False,
486 | help='gated_conv + ConvGRU')
487 | parser.add_argument('--maml-model', type=str2bool, default=False,
488 | help='conv')
489 |
490 | # Model
491 | parser.add_argument('--hidden-sizes', type=int,
492 | default=[256, 128, 64, 64], nargs='+',
493 | help='number of hidden units per layer')
494 | parser.add_argument('--model-type', type=str, default='gatedconv',
495 | help='type of the model')
496 | parser.add_argument('--condition-type', type=str, default='affine',
497 | choices=['affine', 'sigmoid', 'softmax'],
498 | help='type of the conditional layers')
499 | parser.add_argument('--condition-order', type=str, default='low2high',
500 | help='order of the conditional layers to be used')
501 | parser.add_argument('--use-max-pool', type=str2bool, default=False,
502 | help='choose whether to use max pooling with convolutional model')
503 | parser.add_argument('--num-channels', type=int, default=32,
504 | help='number of channels in convolutional layers')
505 | parser.add_argument('--disable-norm', action='store_true',
506 | help='disable batchnorm after linear layers in a fully connected model')
507 | parser.add_argument('--bias-transformation-size', type=int, default=0,
508 | help='size of bias transformation vector that is concatenated with '
509 | 'input')
510 |
511 | # Embedding
512 | parser.add_argument('--embedding-type', type=str, default='',
513 | help='type of the embedding')
514 | parser.add_argument('--embedding-hidden-size', type=int, default=128,
515 | help='number of hidden units per layer in recurrent embedding model')
516 | parser.add_argument('--embedding-num-layers', type=int, default=2,
517 | help='number of layers in recurrent embedding model')
518 | parser.add_argument('--embedding-dims', type=int, nargs='+', default=0,
519 | help='dimensions of the embeddings')
520 |
521 | # Randomly sampled embedding vectors
522 | parser.add_argument('--num-sample-embedding', type=int, default=0,
523 | help='number of randomly sampled embedding vectors')
524 | parser.add_argument(
525 | '--sample-embedding-file', type=str, default='embeddings',
526 | help='the file name of randomly sampled embedding vectors')
527 | parser.add_argument(
528 | '--sample-embedding-file-type', type=str, default='hdf5')
529 |
530 | # Inner loop
531 | parser.add_argument('--first-order', action='store_true',
532 | help='use the first-order approximation of MAML')
533 | parser.add_argument('--fast-lr', type=float, default=0.05,
534 | help='learning rate for the 1-step gradient update of MAML')
535 | parser.add_argument('--inner-loop-grad-clip', type=float, default=20.0,
536 |                     help='clip inner-loop gradients to this magnitude (0 disables clipping)')
537 | parser.add_argument('--num-updates', type=int, default=5,
538 | help='how many update steps in the inner loop')
539 |
540 | # Optimization
541 | parser.add_argument('--num-batches', type=int, default=1920000,
542 |                     help='total number of tasks to sample during meta-training')
543 | parser.add_argument('--meta-batch-size', type=int, default=10,
544 | help='number of tasks per batch')
545 | parser.add_argument('--slow-lr', type=float, default=0.001,
546 | help='learning rate for the global update of MAML')
547 |
548 | # Miscellaneous
549 | parser.add_argument('--output-folder', type=str, default='maml',
550 | help='name of the output folder')
551 | parser.add_argument('--device', type=str, default='cuda',
552 | help='set the device (cpu or cuda)')
553 | parser.add_argument('--num-workers', type=int, default=4,
554 | help='how many DataLoader workers to use')
555 | parser.add_argument('--log-interval', type=int, default=100,
556 | help='number of batches between tensorboard writes')
557 | parser.add_argument('--save-interval', type=int, default=1000,
558 | help='number of batches between model saves')
559 | parser.add_argument('--eval', action='store_true', default=False,
560 | help='evaluate model')
561 | parser.add_argument('--checkpoint', type=str, default='',
562 | help='path to saved parameters.')
563 |
564 | # Dataset
565 | parser.add_argument('--dataset', type=str, default='multimodal_few_shot',
566 | help='which dataset to use')
567 | parser.add_argument('--data-root', type=str, default='data',
568 | help='path to store datasets')
569 | parser.add_argument('--num-train-classes', type=int, default=1100,
570 | help='how many classes for training')
571 | parser.add_argument('--num-classes-per-batch', type=int, default=5,
572 | help='how many classes per task')
573 | parser.add_argument('--num-samples-per-class', type=int, default=1,
574 | help='how many samples per class for training')
575 | parser.add_argument('--num-val-samples', type=int, default=15,
576 | help='how many samples per class for validation')
577 | parser.add_argument('--img-side-len', type=int, default=28,
578 | help='width and height of the input images')
579 | parser.add_argument('--input-range', type=float, default=[-5.0, 5.0],
580 | nargs='+', help='input range of simple functions')
581 | parser.add_argument('--phase-range', type=float, default=[0, np.pi],
582 | nargs='+', help='phase range of sinusoids')
583 | parser.add_argument('--amp-range', type=float, default=[0.1, 5.0],
584 |                     nargs='+', help='amplitude range of sinusoids')
585 | parser.add_argument('--slope-range', type=float, default=[-3.0, 3.0],
586 | nargs='+', help='slope range of linear functions')
587 | parser.add_argument('--intersect-range', type=float, default=[-3.0, 3.0],
588 |                     nargs='+', help='intercept range of linear functions')
589 | parser.add_argument('--noise-std', type=float, default=0.0,
590 |                     help='std of Gaussian noise added to the mixed functions')
591 | parser.add_argument('--oracle', action='store_true',
592 | help='concatenate phase and amp to sinusoid inputs')
593 | parser.add_argument('--task-oracle', action='store_true',
594 | help='uses task id for prediction in some models')
595 |
596 | # Combine few-shot learning datasets
597 | parser.add_argument('--multimodal_few_shot', type=str,
598 | default=['omniglot', 'cifar', 'miniimagenet', 'doublemnist', 'triplemnist'],
599 | choices=['omniglot', 'cifar', 'miniimagenet', 'doublemnist', 'triplemnist',
600 | 'bird', 'aircraft'],
601 | nargs='+')
602 | parser.add_argument('--common-img-side-len', type=int, default=84)
603 | parser.add_argument('--common-img-channel', type=int, default=3,
604 | help='3 for RGB and 1 for grayscale')
605 | parser.add_argument('--mix-meta-batch', type=str2bool, default=True)
606 | parser.add_argument('--mix-mini-batch', type=str2bool, default=False)
607 |
608 | parser.add_argument('--alternating', action='store_true',
609 |                     help='alternate updates between the classifier and the embedding model')
610 | parser.add_argument('--classifier-schedule', type=int, default=10,
611 |                     help='number of consecutive classifier updates when alternating')
612 | parser.add_argument('--embedding-schedule', type=int, default=10,
613 |                     help='number of consecutive embedding-model updates when alternating')
614 | parser.add_argument('--conv-embedding', type=str2bool, default=True,
615 |                     help='use a convolutional embedding model')
616 | parser.add_argument('--conv-embedding-batch-norm', type=str2bool, default=True,
617 |                     help='use batch norm in the convolutional embedding model')
618 | parser.add_argument('--conv-embedding-avgpool-after-conv', type=str2bool, default=True,
619 |                     help='average-pool spatial dimensions after the embedding conv layers')
620 | parser.add_argument('--num-conv-embedding-layer', type=int, default=4,
621 |                     help='number of conv layers in the embedding model')
622 | parser.add_argument('--no-rnn-aggregation', type=str2bool, default=True,
623 |                     help='aggregate sample embeddings by pooling instead of a recurrent model')
624 | parser.add_argument('--embedding-pooling', type=str,
625 |                     choices=['avg', 'max'], default='avg', help='pooling over sample embeddings')
626 | parser.add_argument('--linear-before-rnn', action='store_true',
627 |                     help='apply a linear layer with ReLU before the RNN aggregation')
628 | parser.add_argument('--embedding-grad-clip', type=float, default=0.0,
629 |                     help='clip embedding-model gradient norm to this value (0 disables)')
630 | parser.add_argument('--verbose', type=str2bool, default=False,
631 |                     help='print verbose model and data-sampling information')
632 |
633 | args = parser.parse_args()
634 |
635 |     # Create the train_dir folder if it doesn't exist
636 | if not os.path.exists('./train_dir'):
637 | os.makedirs('./train_dir')
638 |
639 |     # Cap num_sample_embedding at the number of sampled batches
640 | args.num_sample_embedding = min(args.num_sample_embedding, args.num_batches)
641 |
642 |     # compute embedding dims
643 | num_gated_conv_layers = 4
644 | if args.embedding_dims == 0:
645 | args.embedding_dims = []
646 | for i in range(num_gated_conv_layers):
647 | embedding_dim = args.num_channels*2**i
648 | if args.condition_type == 'affine':
649 | embedding_dim *= 2
650 | args.embedding_dims.append(embedding_dim)
651 |
652 | assert not (args.mmaml_model and args.maml_model)
653 |
654 | # mmaml model: gated conv + convGRU
655 | if args.mmaml_model is True:
656 | print('Use MMAML')
657 | args.model_type = 'gatedconv'
658 | args.embedding_type = 'ConvGRU'
659 |
660 | # maml model: conv
661 | if args.maml_model is True:
662 | print('Use vanilla MAML')
663 | args.model_type = 'conv'
664 | args.embedding_type = ''
665 |
666 | # Device
667 | args.device = torch.device(args.device
668 | if torch.cuda.is_available() else 'cpu')
669 |
670 | # print args
671 | if args.verbose:
672 | print('='*10 + ' ARGS ' + '='*10)
673 | for k, v in sorted(vars(args).items()):
674 | print('{}: {}'.format(k, v))
675 | print('='*26)
676 |
677 | main(args)
678 |
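Note: for reference, a minimal standalone sketch of what the embedding-dims computation above produces with the defaults (num_channels=32, condition_type='affine', four gated conv layers); the resulting values are derived from the code, not separately documented:

    # Reproduces the default --embedding-dims computation from main.py.
    num_channels = 32
    condition_type = 'affine'
    embedding_dims = []
    for i in range(4):                       # one embedding per gated conv layer
        dim = num_channels * 2 ** i          # 32, 64, 128, 256 channels
        if condition_type == 'affine':
            dim *= 2                         # affine conditioning needs scale and shift
        embedding_dims.append(dim)
    print(embedding_dims)                    # [64, 128, 256, 512]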
--------------------------------------------------------------------------------
/maml/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/maml/__init__.py
--------------------------------------------------------------------------------
/maml/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/maml/datasets/__init__.py
--------------------------------------------------------------------------------
/maml/datasets/aircraft.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | from collections import defaultdict
4 |
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import DataLoader
8 | from torchvision import transforms
9 |
10 | from maml.sampler import ClassBalancedSampler
11 | from maml.datasets.metadataset import Task
12 |
13 |
14 | class AircraftMAMLSplit():
15 | def __init__(self, root, train=True, num_train_classes=70,
16 | transform=None, target_transform=None, **kwargs):
17 | self.transform = transform
18 | self.target_transform = target_transform
19 | self.root = root + '/aircraft'
20 |
21 | self._train = train
22 |
23 | if self._train:
24 | all_character_dirs = glob.glob(self.root + '/train/**')
25 | self._characters = all_character_dirs
26 | else:
27 | all_character_dirs = glob.glob(self.root + '/test/**')
28 | self._characters = all_character_dirs
29 |
30 | self._character_images = []
31 | for i, char_path in enumerate(self._characters):
32 | img_list = [(cp, i) for cp in glob.glob(char_path + '/*')]
33 | self._character_images.append(img_list)
34 |
35 | self._flat_character_images = sum(self._character_images, [])
36 |
37 | def __getitem__(self, index):
38 | """
39 | Args:
40 | index (int): Index
41 | Returns:
42 | tuple: (image, target) where target is index of the target
43 | character class.
44 | """
45 | image_path, character_class = self._flat_character_images[index]
46 | image = Image.open(image_path, mode='r')
47 |
48 | if self.transform:
49 | image = self.transform(image)
50 |
51 | if self.target_transform:
52 | character_class = self.target_transform(character_class)
53 |
54 | return image, character_class
55 |
56 |
57 | class AircraftMetaDataset(object):
58 | """
59 | TODO: Check if the data loader is fast enough.
60 | Args:
61 | root: path to aircraft dataset
62 | img_side_len: images are scaled to this size
63 | num_classes_per_batch: number of classes to sample for each batch
64 | num_samples_per_class: number of samples to sample for each class
65 | for each batch. For K shot learning this should be K + number
66 | of validation samples
67 | num_total_batches: total number of tasks to generate
68 |         train: whether to load the training split or the test split
69 | """
70 | def __init__(self, name='aircraft', root='data',
71 | img_side_len=84, img_channel=3,
72 | num_classes_per_batch=5, num_samples_per_class=6,
73 | num_total_batches=200000,
74 | num_val_samples=1, meta_batch_size=40, train=True,
75 | num_train_classes=1100, num_workers=0, device='cpu'):
76 | self.name = name
77 | self._root = root
78 | self._img_side_len = img_side_len
79 | self._img_channel = img_channel
80 | self._num_classes_per_batch = num_classes_per_batch
81 | self._num_samples_per_class = num_samples_per_class
82 | self._num_total_batches = num_total_batches
83 | self._num_val_samples = num_val_samples
84 | self._meta_batch_size = meta_batch_size
85 | self._num_train_classes = num_train_classes
86 | self._train = train
87 | self._num_workers = num_workers
88 | self._device = device
89 |
90 | self._total_samples_per_class = (
91 | num_samples_per_class + num_val_samples)
92 | self._dataloader = self._get_aircraft_data_loader()
93 |
94 | self.input_size = (img_channel, img_side_len, img_side_len)
95 | self.output_size = self._num_classes_per_batch
96 |
97 | def _get_aircraft_data_loader(self):
98 | assert self._img_channel == 1 or self._img_channel == 3
99 | resize = transforms.Resize(
100 | (self._img_side_len, self._img_side_len), Image.LANCZOS)
101 | if self._img_channel == 1:
102 | img_transform = transforms.Compose(
103 | [resize, transforms.Grayscale(num_output_channels=1),
104 | transforms.ToTensor()])
105 | else:
106 | img_transform = transforms.Compose(
107 | [resize, transforms.ToTensor()])
108 | dset = AircraftMAMLSplit(
109 | self._root, transform=img_transform, train=self._train,
110 | download=True, num_train_classes=self._num_train_classes)
111 | _, labels = zip(*dset._flat_character_images)
112 | sampler = ClassBalancedSampler(labels, self._num_classes_per_batch,
113 | self._total_samples_per_class,
114 | self._num_total_batches, self._train)
115 |
116 | batch_size = (self._num_classes_per_batch *
117 | self._total_samples_per_class *
118 | self._meta_batch_size)
119 | loader = DataLoader(dset, batch_size=batch_size, sampler=sampler,
120 | num_workers=self._num_workers, pin_memory=True)
121 | return loader
122 |
123 | def _make_single_batch(self, imgs, labels):
124 | """Split imgs and labels into train and validation set.
125 | TODO: check if this might become the bottleneck"""
126 | # relabel classes randomly
127 | new_labels = list(range(self._num_classes_per_batch))
128 | random.shuffle(new_labels)
129 | labels = labels.tolist()
130 | label_set = set(labels)
131 | label_map = {label: new_labels[i] for i, label in enumerate(label_set)}
132 | labels = [label_map[l] for l in labels]
133 |
134 | label_indices = defaultdict(list)
135 | for i, label in enumerate(labels):
136 | label_indices[label].append(i)
137 |
138 | # assign samples to train and validation sets
139 | val_indices = []
140 | train_indices = []
141 | for label, indices in label_indices.items():
142 | val_indices.extend(indices[:self._num_val_samples])
143 | train_indices.extend(indices[self._num_val_samples:])
144 | label_tensor = torch.tensor(labels, device=self._device)
145 | imgs = imgs.to(self._device)
146 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name)
147 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name)
148 |
149 | return train_task, val_task
150 |
151 | def _make_meta_batch(self, imgs, labels):
152 | batches = []
153 | inner_batch_size = (
154 | self._total_samples_per_class * self._num_classes_per_batch)
155 | for i in range(0, len(imgs) - 1, inner_batch_size):
156 | batch_imgs = imgs[i:i+inner_batch_size]
157 | batch_labels = labels[i:i+inner_batch_size]
158 | batch = self._make_single_batch(batch_imgs, batch_labels)
159 | batches.append(batch)
160 |
161 | train_tasks, val_tasks = zip(*batches)
162 |
163 | return train_tasks, val_tasks
164 |
165 | def __iter__(self):
166 | for imgs, labels in iter(self._dataloader):
167 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels)
168 | yield train_tasks, val_tasks
169 |
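Note: the DataLoader batch above packs meta_batch_size tasks into one fetch; each task consumes num_classes_per_batch * (num_samples_per_class + num_val_samples) images, and _make_meta_batch slices the fetch back into per-task chunks. A small sketch of that arithmetic with illustrative 5-way 1-shot settings:

    # Episode-size arithmetic used by AircraftMetaDataset (values illustrative).
    num_classes_per_batch = 5                 # N-way
    num_samples_per_class = 1                 # K-shot support samples per class
    num_val_samples = 15                      # query samples per class
    meta_batch_size = 10                      # tasks per meta-batch

    total_samples_per_class = num_samples_per_class + num_val_samples     # 16
    inner_batch_size = num_classes_per_batch * total_samples_per_class    # 80 images per task
    loader_batch_size = inner_batch_size * meta_batch_size                # 800 images per fetch
    print(inner_batch_size, loader_batch_size)                            # 80 800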
--------------------------------------------------------------------------------
/maml/datasets/bird.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | from collections import defaultdict
4 |
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import DataLoader
8 | from torchvision import transforms
9 |
10 | from maml.sampler import ClassBalancedSampler
11 | from maml.datasets.metadataset import Task
12 |
13 |
14 | class BirdMAMLSplit():
15 | def __init__(self, root, train=True, num_train_classes=140,
16 | transform=None, target_transform=None, **kwargs):
17 | self.transform = transform
18 | self.target_transform = target_transform
19 | self.root = root + '/bird'
20 |
21 | self._train = train
22 |
23 | if self._train:
24 | all_character_dirs = glob.glob(self.root + '/train/**')
25 | self._characters = all_character_dirs
26 | else:
27 | all_character_dirs = glob.glob(self.root + '/test/**')
28 | self._characters = all_character_dirs
29 |
30 | self._character_images = []
31 | for i, char_path in enumerate(self._characters):
32 | img_list = [(cp, i) for cp in glob.glob(char_path + '/*')]
33 | self._character_images.append(img_list)
34 |
35 | self._flat_character_images = sum(self._character_images, [])
36 |
37 | def __getitem__(self, index):
38 | """
39 | Args:
40 | index (int): Index
41 | Returns:
42 | tuple: (image, target) where target is index of the target
43 | character class.
44 | """
45 | image_path, character_class = self._flat_character_images[index]
46 | image = Image.open(image_path, mode='r')
47 |
48 | if self.transform:
49 | image = self.transform(image)
50 |
51 | if self.target_transform:
52 | character_class = self.target_transform(character_class)
53 |
54 | return image, character_class
55 |
56 |
57 | class BirdMetaDataset(object):
58 | """
59 | TODO: Check if the data loader is fast enough.
60 | Args:
61 | root: path to bird dataset
62 | img_side_len: images are scaled to this size
63 | num_classes_per_batch: number of classes to sample for each batch
64 | num_samples_per_class: number of samples to sample for each class
65 | for each batch. For K shot learning this should be K + number
66 | of validation samples
67 | num_total_batches: total number of tasks to generate
68 |         train: whether to load the training split or the test split
69 | """
70 | def __init__(self, name='Bird', root='data',
71 | img_side_len=84, img_channel=3,
72 | num_classes_per_batch=5, num_samples_per_class=6,
73 | num_total_batches=200000,
74 | num_val_samples=1, meta_batch_size=40, train=True,
75 | num_train_classes=1100, num_workers=0, device='cpu'):
76 | self.name = name
77 | self._root = root
78 | self._img_side_len = img_side_len
79 | self._img_channel = img_channel
80 | self._num_classes_per_batch = num_classes_per_batch
81 | self._num_samples_per_class = num_samples_per_class
82 | self._num_total_batches = num_total_batches
83 | self._num_val_samples = num_val_samples
84 | self._meta_batch_size = meta_batch_size
85 | self._num_train_classes = num_train_classes
86 | self._train = train
87 | self._num_workers = num_workers
88 | self._device = device
89 |
90 | self._total_samples_per_class = (
91 | num_samples_per_class + num_val_samples)
92 | self._dataloader = self._get_bird_data_loader()
93 |
94 | self.input_size = (img_channel, img_side_len, img_side_len)
95 | self.output_size = self._num_classes_per_batch
96 |
97 | def _get_bird_data_loader(self):
98 | assert self._img_channel == 1 or self._img_channel == 3
99 | resize = transforms.Resize(
100 | (self._img_side_len, self._img_side_len), Image.LANCZOS)
101 | if self._img_channel == 1:
102 | img_transform = transforms.Compose(
103 | [resize, transforms.Grayscale(num_output_channels=1),
104 | transforms.ToTensor()])
105 | else:
106 | img_transform = transforms.Compose(
107 | [resize, transforms.ToTensor()])
108 | dset = BirdMAMLSplit(
109 | self._root, transform=img_transform, train=self._train,
110 | download=True, num_train_classes=self._num_train_classes)
111 | _, labels = zip(*dset._flat_character_images)
112 | sampler = ClassBalancedSampler(labels, self._num_classes_per_batch,
113 | self._total_samples_per_class,
114 | self._num_total_batches, self._train)
115 |
116 | batch_size = (self._num_classes_per_batch *
117 | self._total_samples_per_class *
118 | self._meta_batch_size)
119 | loader = DataLoader(dset, batch_size=batch_size, sampler=sampler,
120 | num_workers=self._num_workers, pin_memory=True)
121 | return loader
122 |
123 | def _make_single_batch(self, imgs, labels):
124 | """Split imgs and labels into train and validation set.
125 | TODO: check if this might become the bottleneck"""
126 | # relabel classes randomly
127 | new_labels = list(range(self._num_classes_per_batch))
128 | random.shuffle(new_labels)
129 | labels = labels.tolist()
130 | label_set = set(labels)
131 | label_map = {label: new_labels[i] for i, label in enumerate(label_set)}
132 | labels = [label_map[l] for l in labels]
133 |
134 | label_indices = defaultdict(list)
135 | for i, label in enumerate(labels):
136 | label_indices[label].append(i)
137 |
138 | # assign samples to train and validation sets
139 | val_indices = []
140 | train_indices = []
141 | for label, indices in label_indices.items():
142 | val_indices.extend(indices[:self._num_val_samples])
143 | train_indices.extend(indices[self._num_val_samples:])
144 | label_tensor = torch.tensor(labels, device=self._device)
145 | imgs = imgs.to(self._device)
146 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name)
147 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name)
148 |
149 | return train_task, val_task
150 |
151 | def _make_meta_batch(self, imgs, labels):
152 | batches = []
153 | inner_batch_size = (
154 | self._total_samples_per_class * self._num_classes_per_batch)
155 | for i in range(0, len(imgs) - 1, inner_batch_size):
156 | batch_imgs = imgs[i:i+inner_batch_size]
157 | batch_labels = labels[i:i+inner_batch_size]
158 | batch = self._make_single_batch(batch_imgs, batch_labels)
159 | batches.append(batch)
160 |
161 | train_tasks, val_tasks = zip(*batches)
162 |
163 | return train_tasks, val_tasks
164 |
165 | def __iter__(self):
166 | for imgs, labels in iter(self._dataloader):
167 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels)
168 | yield train_tasks, val_tasks
169 |
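Note: the relabel-and-split step in _make_single_batch (shared verbatim across these dataset classes) can be exercised in isolation; a self-contained sketch with toy labels (all values illustrative):

    import random
    from collections import defaultdict

    raw_labels = [7, 7, 7, 42, 42, 42]        # two classes, three samples each
    num_classes_per_batch, num_val_samples = 2, 1

    # map the original class ids onto shuffled task-local labels 0..N-1
    new_labels = list(range(num_classes_per_batch))
    random.shuffle(new_labels)
    label_map = {label: new_labels[i] for i, label in enumerate(set(raw_labels))}
    labels = [label_map[l] for l in raw_labels]

    label_indices = defaultdict(list)
    for i, label in enumerate(labels):
        label_indices[label].append(i)

    # the first num_val_samples indices per class form the validation (query) set
    val_indices, train_indices = [], []
    for label, indices in label_indices.items():
        val_indices.extend(indices[:num_val_samples])
        train_indices.extend(indices[num_val_samples:])
    print(val_indices, train_indices)         # e.g. [0, 3] [1, 2, 4, 5]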
--------------------------------------------------------------------------------
/maml/datasets/cifar100.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | from collections import defaultdict
5 |
6 | import torch
7 | import numpy as np
8 | from PIL import Image
9 | from torch.utils.data import DataLoader
10 | from torchvision import transforms
11 | from torchvision.datasets.utils import list_files
12 |
13 | from maml.sampler import ClassBalancedSampler
14 | from maml.datasets.metadataset import Task
15 |
16 | class Cifar100MAMLSplit():
17 | def __init__(self, root, train=True, num_train_classes=100,
18 | transform=None, target_transform=None, **kwargs):
19 | self.transform = transform
20 | self.target_transform = target_transform
21 | self.root = os.path.join(root, 'cifar100.pth')
22 |
23 | self._train = train
24 | self._num_train_classes = num_train_classes
25 |
26 | self._dataset = torch.load(self.root)
27 | if self._train:
28 | self._images = torch.FloatTensor(self._dataset['data']['train'].reshape([-1, 3, 32, 32]))
29 | self._labels = torch.LongTensor(self._dataset['label']['train'])
30 | else:
31 | self._images = torch.FloatTensor(self._dataset['data']['test'].reshape([-1, 3, 32, 32]))
32 | self._labels = torch.LongTensor(self._dataset['label']['test'])
33 |
34 | def __getitem__(self, index):
35 | image = self._images[index]
36 |
37 | if self.transform:
38 | image = self.transform(self._images[index])
39 |
40 | return image, self._labels[index]
41 |
42 | class Cifar100MetaDataset(object):
43 | def __init__(self, name='FC100', root='data',
44 | img_side_len=32, img_channel=3,
45 | num_classes_per_batch=5, num_samples_per_class=6,
46 | num_total_batches=200000,
47 | num_val_samples=1, meta_batch_size=32, train=True,
48 | num_train_classes=100, num_workers=0, device='cpu'):
49 | self.name = name
50 | self._root = root
51 | self._img_side_len = img_side_len
52 | self._img_channel = img_channel
53 | self._num_classes_per_batch = num_classes_per_batch
54 | self._num_samples_per_class = num_samples_per_class
55 | self._num_total_batches = num_total_batches
56 | self._num_val_samples = num_val_samples
57 | self._meta_batch_size = meta_batch_size
58 | self._num_train_classes = num_train_classes
59 | self._train = train
60 | self._num_workers = num_workers
61 | self._device = device
62 |
63 | self._total_samples_per_class = (num_samples_per_class + num_val_samples)
64 | self._dataloader = self._get_cifar100_data_loader()
65 |
66 | self.input_size = (img_channel, img_side_len, img_side_len)
67 | self.output_size = self._num_classes_per_batch
68 |
69 | def _get_cifar100_data_loader(self):
70 | assert self._img_channel == 1 or self._img_channel == 3
71 |         to_image = transforms.ToPILImage()
72 |         resize = transforms.Resize(self._img_side_len, Image.LANCZOS)
73 |         if self._img_channel == 1:
74 |             img_transform = transforms.Compose(
75 |                 [to_image, resize,
76 |                  transforms.Grayscale(num_output_channels=1),
77 |                  transforms.ToTensor()])
78 |         else:
79 |             img_transform = transforms.Compose(
80 |                 [to_image, resize, transforms.ToTensor()])
81 | dset = Cifar100MAMLSplit(self._root, transform=img_transform,
82 | train=self._train, download=True,
83 | num_train_classes=self._num_train_classes)
84 | labels = dset._labels.numpy().tolist()
85 | sampler = ClassBalancedSampler(labels, self._num_classes_per_batch,
86 | self._total_samples_per_class,
87 | self._num_total_batches, self._train)
88 |
89 | batch_size = (self._num_classes_per_batch
90 | * self._total_samples_per_class
91 | * self._meta_batch_size)
92 | loader = DataLoader(dset, batch_size=batch_size, sampler=sampler,
93 | num_workers=self._num_workers, pin_memory=True)
94 | return loader
95 |
96 | def _make_single_batch(self, imgs, labels):
97 | """Split imgs and labels into train and validation set.
98 | TODO: check if this might become the bottleneck"""
99 | # relabel classes randomly
100 | new_labels = list(range(self._num_classes_per_batch))
101 | random.shuffle(new_labels)
102 | labels = labels.tolist()
103 | label_set = set(labels)
104 | label_map = {label: new_labels[i] for i, label in enumerate(label_set)}
105 | labels = [label_map[l] for l in labels]
106 |
107 | label_indices = defaultdict(list)
108 | for i, label in enumerate(labels):
109 | label_indices[label].append(i)
110 |
111 | # assign samples to train and validation sets
112 | val_indices = []
113 | train_indices = []
114 | for label, indices in label_indices.items():
115 | val_indices.extend(indices[:self._num_val_samples])
116 | train_indices.extend(indices[self._num_val_samples:])
117 | label_tensor = torch.tensor(labels, device=self._device)
118 | imgs = imgs.to(self._device)
119 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name)
120 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name)
121 |
122 | return train_task, val_task
123 |
124 | def _make_meta_batch(self, imgs, labels):
125 | batches = []
126 | inner_batch_size = (self._total_samples_per_class
127 | * self._num_classes_per_batch)
128 | for i in range(0, len(imgs) - 1, inner_batch_size):
129 | batch_imgs = imgs[i:i+inner_batch_size]
130 | batch_labels = labels[i:i+inner_batch_size]
131 | batch = self._make_single_batch(batch_imgs, batch_labels)
132 | batches.append(batch)
133 |
134 | train_tasks, val_tasks = zip(*batches)
135 |
136 | return train_tasks, val_tasks
137 |
138 | def __iter__(self):
139 | for imgs, labels in iter(self._dataloader):
140 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels)
141 | yield train_tasks, val_tasks
142 |
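Note: Cifar100MAMLSplit reads a preprocessed tensor file at <root>/cifar100.pth. Judging from the loading code above, it is a nested dict with 'data' and 'label' entries split into 'train' and 'test'; the sketch below writes a dummy file in that assumed layout for smoke-testing (the shapes and counts are placeholders, not verified against the preprocessing scripts):

    import os
    import numpy as np
    import torch

    dummy = {
        'data': {'train': np.zeros((100, 3 * 32 * 32), dtype=np.float32),
                 'test': np.zeros((20, 3 * 32 * 32), dtype=np.float32)},
        'label': {'train': np.zeros(100, dtype=np.int64),
                  'test': np.zeros(20, dtype=np.int64)},
    }
    os.makedirs('data', exist_ok=True)
    torch.save(dummy, 'data/cifar100.pth')    # hypothetical local path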
--------------------------------------------------------------------------------
/maml/datasets/metadataset.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | Task = namedtuple('Task', ['x', 'y', 'task_info'])
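Note: Task is a plain named triple of inputs, targets, and a task identifier; a short usage sketch mirroring the definition above:

    import torch
    from maml.datasets.metadataset import Task

    t = Task(x=torch.zeros(4, 1, 28, 28), y=torch.zeros(4, dtype=torch.long),
             task_info='Omniglot')
    print(t.x.shape, t.task_info)             # torch.Size([4, 1, 28, 28]) Omniglot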
--------------------------------------------------------------------------------
/maml/datasets/miniimagenet.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | from collections import defaultdict
4 |
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import DataLoader
8 | from torchvision import transforms
9 |
10 | from maml.sampler import ClassBalancedSampler
11 | from maml.datasets.metadataset import Task
12 |
13 |
14 | class MiniimagenetMAMLSplit():
15 | def __init__(self, root, train=True, num_train_classes=1200,
16 | transform=None, target_transform=None, **kwargs):
17 | self.transform = transform
18 | self.target_transform = target_transform
19 | self.root = root + '/miniimagenet'
20 |
21 | self._train = train
22 |
23 | if self._train:
24 | all_character_dirs = glob.glob(self.root + '/train/**')
25 | self._characters = all_character_dirs
26 | else:
27 | all_character_dirs = glob.glob(self.root + '/test/**')
28 | self._characters = all_character_dirs
29 |
30 | self._character_images = []
31 | for i, char_path in enumerate(self._characters):
32 | img_list = [(cp, i) for cp in glob.glob(char_path + '/*')]
33 | self._character_images.append(img_list)
34 |
35 | self._flat_character_images = sum(self._character_images, [])
36 |
37 | def __getitem__(self, index):
38 | """
39 | Args:
40 | index (int): Index
41 | Returns:
42 | tuple: (image, target) where target is index of the target
43 | character class.
44 | """
45 | image_path, character_class = self._flat_character_images[index]
46 | image = Image.open(image_path, mode='r')
47 |
48 | if self.transform:
49 | image = self.transform(image)
50 |
51 | if self.target_transform:
52 | character_class = self.target_transform(character_class)
53 |
54 | return image, character_class
55 |
56 |
57 | class MiniimagenetMetaDataset(object):
58 | """
59 | TODO: Check if the data loader is fast enough.
60 | Args:
61 | root: path to miniimagenet dataset
62 | img_side_len: images are scaled to this size
63 | num_classes_per_batch: number of classes to sample for each batch
64 | num_samples_per_class: number of samples to sample for each class
65 | for each batch. For K shot learning this should be K + number
66 | of validation samples
67 | num_total_batches: total number of tasks to generate
68 |         train: whether to load the training split or the test split
69 | """
70 | def __init__(self, name='MiniImageNet', root='data',
71 | img_side_len=84, img_channel=3,
72 | num_classes_per_batch=5, num_samples_per_class=6,
73 | num_total_batches=200000,
74 | num_val_samples=1, meta_batch_size=40, train=True,
75 | num_train_classes=1100, num_workers=0, device='cpu'):
76 | self.name = name
77 | self._root = root
78 | self._img_side_len = img_side_len
79 | self._img_channel = img_channel
80 | self._num_classes_per_batch = num_classes_per_batch
81 | self._num_samples_per_class = num_samples_per_class
82 | self._num_total_batches = num_total_batches
83 | self._num_val_samples = num_val_samples
84 | self._meta_batch_size = meta_batch_size
85 | self._num_train_classes = num_train_classes
86 | self._train = train
87 | self._num_workers = num_workers
88 | self._device = device
89 |
90 | self._total_samples_per_class = (
91 | num_samples_per_class + num_val_samples)
92 | self._dataloader = self._get_miniimagenet_data_loader()
93 |
94 | self.input_size = (img_channel, img_side_len, img_side_len)
95 | self.output_size = self._num_classes_per_batch
96 |
97 | def _get_miniimagenet_data_loader(self):
98 | assert self._img_channel == 1 or self._img_channel == 3
99 | resize = transforms.Resize(
100 | (self._img_side_len, self._img_side_len), Image.LANCZOS)
101 | if self._img_channel == 1:
102 | img_transform = transforms.Compose(
103 | [resize, transforms.Grayscale(num_output_channels=1),
104 | transforms.ToTensor()])
105 | else:
106 | img_transform = transforms.Compose(
107 | [resize, transforms.ToTensor()])
108 | dset = MiniimagenetMAMLSplit(
109 | self._root, transform=img_transform, train=self._train,
110 | download=True, num_train_classes=self._num_train_classes)
111 | _, labels = zip(*dset._flat_character_images)
112 | sampler = ClassBalancedSampler(labels, self._num_classes_per_batch,
113 | self._total_samples_per_class,
114 | self._num_total_batches, self._train)
115 |
116 | batch_size = (self._num_classes_per_batch *
117 | self._total_samples_per_class *
118 | self._meta_batch_size)
119 | loader = DataLoader(dset, batch_size=batch_size, sampler=sampler,
120 | num_workers=self._num_workers, pin_memory=True)
121 | return loader
122 |
123 | def _make_single_batch(self, imgs, labels):
124 | """Split imgs and labels into train and validation set.
125 | TODO: check if this might become the bottleneck"""
126 | # relabel classes randomly
127 | new_labels = list(range(self._num_classes_per_batch))
128 | random.shuffle(new_labels)
129 | labels = labels.tolist()
130 | label_set = set(labels)
131 | label_map = {label: new_labels[i] for i, label in enumerate(label_set)}
132 | labels = [label_map[l] for l in labels]
133 |
134 | label_indices = defaultdict(list)
135 | for i, label in enumerate(labels):
136 | label_indices[label].append(i)
137 |
138 | # assign samples to train and validation sets
139 | val_indices = []
140 | train_indices = []
141 | for label, indices in label_indices.items():
142 | val_indices.extend(indices[:self._num_val_samples])
143 | train_indices.extend(indices[self._num_val_samples:])
144 | label_tensor = torch.tensor(labels, device=self._device)
145 | imgs = imgs.to(self._device)
146 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name)
147 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name)
148 |
149 | return train_task, val_task
150 |
151 | def _make_meta_batch(self, imgs, labels):
152 | batches = []
153 | inner_batch_size = (
154 | self._total_samples_per_class * self._num_classes_per_batch)
155 | for i in range(0, len(imgs) - 1, inner_batch_size):
156 | batch_imgs = imgs[i:i+inner_batch_size]
157 | batch_labels = labels[i:i+inner_batch_size]
158 | batch = self._make_single_batch(batch_imgs, batch_labels)
159 | batches.append(batch)
160 |
161 | train_tasks, val_tasks = zip(*batches)
162 |
163 | return train_tasks, val_tasks
164 |
165 | def __iter__(self):
166 | for imgs, labels in iter(self._dataloader):
167 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels)
168 | yield train_tasks, val_tasks
169 |
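Note: a hedged usage sketch for these meta-dataset classes; it assumes the miniimagenet images have already been prepared under data/miniimagenet by the scripts in get_dataset_script:

    from maml.datasets.miniimagenet import MiniimagenetMetaDataset

    dataset = MiniimagenetMetaDataset(
        root='data', img_side_len=84, num_classes_per_batch=5,
        num_samples_per_class=1, num_val_samples=15,
        meta_batch_size=2, train=True, device='cpu')

    train_tasks, val_tasks = next(iter(dataset))
    print(len(train_tasks))                   # 2 tasks (= meta_batch_size)
    print(train_tasks[0].x.shape)             # torch.Size([5, 3, 84, 84]) for 5-way 1-shot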
--------------------------------------------------------------------------------
/maml/datasets/multimodal_few_shot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from itertools import chain
3 | from maml.datasets.metadataset import Task
4 |
5 |
6 | class MultimodalFewShotDataset(object):
7 |
8 | def __init__(self, datasets, num_total_batches,
9 | name='MultimodalFewShot',
10 | mix_meta_batch=True, mix_mini_batch=False,
11 | train=True, verbose=False, txt_file=None):
12 | self._datasets = datasets
13 | self._num_total_batches = num_total_batches
14 | self.name = name
15 | self.num_dataset = len(datasets)
16 | self.dataset_names = [dataset.name for dataset in self._datasets]
17 | self._meta_batch_size = datasets[0]._meta_batch_size
18 | self._mix_meta_batch = mix_meta_batch
19 | self._mix_mini_batch = mix_mini_batch
20 | self._train = train
21 | self._verbose = verbose
22 |         self._txt_file = open(txt_file, 'w') if txt_file is not None else None
23 |
24 | # make sure all input/output sizes match
25 | input_size_list = [dataset.input_size for dataset in self._datasets]
26 | assert input_size_list.count(input_size_list[0]) == len(input_size_list)
27 | output_size_list = [dataset.output_size for dataset in self._datasets]
28 | assert output_size_list.count(output_size_list[0]) == len(output_size_list)
29 | self.input_size = datasets[0].input_size
30 | self.output_size = datasets[0].output_size
31 |
32 | # build iterators
33 | self._datasets_iter = [iter(dataset) for dataset in self._datasets]
34 | self._iter_index = 0
35 |
36 | # print info
37 | print('Multimodal Few Shot Datasets: {}'.format(' '.join(self.dataset_names)))
38 | print('mix meta batch: {}'.format(mix_meta_batch))
39 | print('mix mini batch: {}'.format(mix_mini_batch))
40 |
41 | def __next__(self):
42 |         if self.n < self._num_total_batches:
43 |             self.n += 1  # count served batches so StopIteration is raised eventually
44 |             if not self._mix_meta_batch and not self._mix_mini_batch:
45 |                 dataset_index = np.random.randint(len(self._datasets))
46 |                 if self._verbose:
47 |                     print('Sample from: {}'.format(self._datasets[dataset_index].name))
48 |                 train_tasks, val_tasks = next(self._datasets_iter[dataset_index])
49 |                 return train_tasks, val_tasks
50 |             else:
51 |                 # get all tasks
52 |                 all_train_tasks = []
53 |                 all_val_tasks = []
54 |                 for dataset_iter in self._datasets_iter:
55 |                     train_tasks, val_tasks = next(dataset_iter)
56 |                     all_train_tasks.extend(train_tasks)
57 |                     all_val_tasks.extend(val_tasks)
58 | 
59 |                 if not self._mix_mini_batch:
60 |                     # mix them to obtain a meta batch
61 |                     """
62 |                     # randomly sample task
63 |                     dataset_indexes = np.random.choice(
64 |                         len(all_train_tasks), size=self._meta_batch_size, replace=False)
65 |                     """
66 |                 # sample evenly across all datasets
67 | dataset_indexes = []
68 | if self._train:
69 | dataset_start_idx = np.random.randint(0, self.num_dataset)
70 | else:
71 | dataset_start_idx = (self._iter_index + self._meta_batch_size) % self.num_dataset
72 | self._iter_index += self._meta_batch_size
73 | self._iter_index = self._iter_index % self.num_dataset
74 |
75 | for i in range(self._meta_batch_size):
76 | dataset_indexes.append(
77 | np.random.randint(0, self._meta_batch_size)+
78 | ((i+dataset_start_idx)%self.num_dataset)*self._meta_batch_size)
79 |
80 | train_tasks = []
81 | val_tasks = []
82 | dataset_names = []
83 | for dataset_index in dataset_indexes:
84 | train_tasks.append(all_train_tasks[dataset_index])
85 | val_tasks.append(all_val_tasks[dataset_index])
86 | dataset_names.append(self._datasets[dataset_index//self._meta_batch_size].name)
87 | if self._verbose:
88 | print('Sample from: {} (indexes: {})'.format(
89 | [name for name in dataset_names], dataset_indexes))
90 | if self._txt_file is not None:
91 | for name in dataset_names:
92 | self._txt_file.write(name+'\n')
93 | return train_tasks, val_tasks
94 | else:
95 | # mix them to obtain a mini batch and make a meta batch
96 | raise NotImplementedError
97 | else:
98 | raise StopIteration
99 |
100 | def __iter__(self):
101 | self.n = 0
102 | return self
103 |
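Note: the balanced-sampling index arithmetic in __next__ is compact; since all_train_tasks concatenates meta_batch_size tasks per dataset, tasks of dataset d occupy the slice [d * meta_batch_size, (d + 1) * meta_batch_size). A standalone sketch of the same index construction (values illustrative):

    import numpy as np

    num_dataset, meta_batch_size = 3, 4
    dataset_start_idx = np.random.randint(0, num_dataset)

    dataset_indexes = []
    for i in range(meta_batch_size):
        block = (i + dataset_start_idx) % num_dataset     # cycle through the datasets
        offset = np.random.randint(0, meta_batch_size)    # random task within that block
        dataset_indexes.append(block * meta_batch_size + offset)

    # integer division recovers each index's source dataset
    print([idx // meta_batch_size for idx in dataset_indexes])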
--------------------------------------------------------------------------------
/maml/datasets/omniglot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | from collections import defaultdict
5 |
6 | import torch
7 | import numpy as np
8 | from PIL import Image
9 | from torch.utils.data import DataLoader
10 | from torchvision import transforms
11 | from torchvision.datasets import Omniglot
12 | from torchvision.datasets.utils import list_files
13 |
14 | from maml.sampler import ClassBalancedSampler
15 | from maml.datasets.metadataset import Task
16 |
17 |
18 | class OmniglotMAMLSplit(Omniglot):
19 | """Implements similar train / test split for Omniglot as
20 | https://github.com/cbfinn/maml/blob/master/data_generator.py
21 |
22 | Uses torchvision.datasets.Omniglot for downloading and checking
23 | dataset integrity.
24 | """
25 | def __init__(self, root, train=True, num_train_classes=1100, **kwargs):
26 | super(OmniglotMAMLSplit, self).__init__(root, download=True,
27 | background=True, **kwargs)
28 |
29 | self._train = train
30 | self._num_train_classes = num_train_classes
31 |
32 | # download testing data and test integrity
33 | self.background = False
34 | self.download()
35 | if not self._check_integrity():
36 | raise RuntimeError('Dataset not found or corrupted')
37 |
38 | all_character_dirs = glob.glob(self.root + '/**/**/**')
39 | if self._train:
40 | self._characters = all_character_dirs[:self._num_train_classes]
41 | else:
42 | self._characters = all_character_dirs[self._num_train_classes:]
43 |
44 | self._character_images = []
45 | for i, char_path in enumerate(self._characters):
46 | img_list = [(cp, i) for cp in glob.glob(char_path + '/*')]
47 | self._character_images.append(img_list)
48 |
49 | self._flat_character_images = sum(self._character_images, [])
50 |
51 | def __getitem__(self, index):
52 | """
53 | Args:
54 | index (int): Index
55 | Returns:
56 | tuple: (image, target) where target is index of the target
57 | character class.
58 | """
59 | image_path, character_class = self._flat_character_images[index]
60 | image = Image.open(image_path, mode='r').convert('L')
61 |
62 | if self.transform:
63 | image = self.transform(image)
64 |
65 | if self.target_transform:
66 | character_class = self.target_transform(character_class)
67 |
68 | return image, character_class
69 |
70 |
71 | class OmniglotMetaDataset(object):
72 | """
73 | TODO: Check if the data loader is fast enough.
74 | Args:
75 | root: path to omniglot dataset
76 | img_side_len: images are scaled to this size
77 | num_classes_per_batch: number of classes to sample for each batch
78 | num_samples_per_class: number of samples to sample for each class
79 | for each batch. For K shot learning this should be K + number
80 | of validation samples
81 | num_total_batches: total number of tasks to generate
82 |         train: whether to load the training split or the test split
83 | """
84 | def __init__(self, name='Omniglot', root='data',
85 | img_side_len=28, img_channel=1,
86 | num_classes_per_batch=5, num_samples_per_class=6,
87 | num_total_batches=200000,
88 | num_val_samples=1, meta_batch_size=40, train=True,
89 | num_train_classes=1100, num_workers=0, device='cpu'):
90 | self.name = name
91 | self._root = root
92 | self._img_side_len = img_side_len
93 | self._img_channel = img_channel
94 | self._num_classes_per_batch = num_classes_per_batch
95 | self._num_samples_per_class = num_samples_per_class
96 | self._num_total_batches = num_total_batches
97 | self._num_val_samples = num_val_samples
98 | self._meta_batch_size = meta_batch_size
99 | self._num_train_classes = num_train_classes
100 | self._train = train
101 | self._num_workers = num_workers
102 | self._device = device
103 |
104 | self._total_samples_per_class = (
105 | num_samples_per_class + num_val_samples)
106 | self._dataloader = self._get_omniglot_data_loader()
107 |
108 | self.input_size = (img_channel, img_side_len, img_side_len)
109 | self.output_size = self._num_classes_per_batch
110 |
111 | def _get_omniglot_data_loader(self):
112 | assert self._img_channel == 1 or self._img_channel == 3
113 | resize = transforms.Resize(self._img_side_len, Image.LANCZOS)
114 | invert = transforms.Lambda(lambda x: 1.0 - x)
115 | if self._img_channel > 1:
116 | # tile the image
117 | tile = transforms.Lambda(lambda x: x.repeat(self._img_channel, 1, 1))
118 | img_transform = transforms.Compose(
119 | [resize, transforms.ToTensor(), invert, tile])
120 | else:
121 | img_transform = transforms.Compose(
122 | [resize, transforms.ToTensor(), invert])
123 | dset = OmniglotMAMLSplit(self._root, transform=img_transform,
124 | train=self._train,
125 | num_train_classes=self._num_train_classes)
126 | _, labels = zip(*dset._flat_character_images)
127 | sampler = ClassBalancedSampler(labels, self._num_classes_per_batch,
128 | self._total_samples_per_class,
129 | self._num_total_batches, self._train)
130 |
131 | batch_size = (self._num_classes_per_batch
132 | * self._total_samples_per_class
133 | * self._meta_batch_size)
134 | loader = DataLoader(dset, batch_size=batch_size, sampler=sampler,
135 | num_workers=self._num_workers, pin_memory=True)
136 | return loader
137 |
138 | def _make_single_batch(self, imgs, labels):
139 | """Split imgs and labels into train and validation set.
140 | TODO: check if this might become the bottleneck"""
141 | # relabel classes randomly
142 | new_labels = list(range(self._num_classes_per_batch))
143 | random.shuffle(new_labels)
144 | labels = labels.tolist()
145 | label_set = set(labels)
146 |         label_map = {label: new_label for new_label, label in zip(new_labels, label_set)}
147 | labels = [label_map[l] for l in labels]
148 |
149 | label_indices = defaultdict(list)
150 | for i, label in enumerate(labels):
151 | label_indices[label].append(i)
152 |
153 | # rotate randomly to create new classes
154 | # TODO: move this to torch once supported.
155 | for label, indices in label_indices.items():
156 | rotation = np.random.randint(4)
157 | for i in range(len(indices)):
158 | img = imgs[indices[i]].numpy()
159 | # copy here for contiguity
160 | img = np.copy(np.rot90(img, k=rotation, axes=(1,2)))
161 | imgs[indices[i]] = torch.from_numpy(img)
162 |
163 | # assign samples to train and validation sets
164 | val_indices = []
165 | train_indices = []
166 | for label, indices in label_indices.items():
167 | val_indices.extend(indices[:self._num_val_samples])
168 | train_indices.extend(indices[self._num_val_samples:])
169 | label_tensor = torch.tensor(labels, device=self._device)
170 | imgs = imgs.to(self._device)
171 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name)
172 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name)
173 |
174 | return train_task, val_task
175 |
176 | def _make_meta_batch(self, imgs, labels):
177 | batches = []
178 | inner_batch_size = (self._total_samples_per_class
179 | * self._num_classes_per_batch)
180 | for i in range(0, len(imgs) - 1, inner_batch_size):
181 | batch_imgs = imgs[i:i+inner_batch_size]
182 | batch_labels = labels[i:i+inner_batch_size]
183 | batch = self._make_single_batch(batch_imgs, batch_labels)
184 | batches.append(batch)
185 |
186 | train_tasks, val_tasks = zip(*batches)
187 |
188 | return train_tasks, val_tasks
189 |
190 | def __iter__(self):
191 | for imgs, labels in iter(self._dataloader):
192 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels)
193 | yield train_tasks, val_tasks
194 |
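Note: the per-class rotation in _make_single_batch follows the common Omniglot protocol of treating each 90-degree rotation as a new class; a minimal standalone sketch of the same tensor round-trip:

    import numpy as np
    import torch

    img = torch.rand(1, 28, 28)                            # CHW image tensor
    rotation = np.random.randint(4)                        # 0, 90, 180, or 270 degrees
    rotated = np.copy(np.rot90(img.numpy(), k=rotation, axes=(1, 2)))  # copy => contiguous
    img = torch.from_numpy(rotated)
    print(rotation, img.shape)                             # e.g. 2 torch.Size([1, 28, 28])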
--------------------------------------------------------------------------------
/maml/metalearner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from torch.nn.utils.clip_grad import clip_grad_norm_
5 | from maml.utils import accuracy
6 |
7 | def get_grad_norm(parameters, norm_type=2):
8 | if isinstance(parameters, torch.Tensor):
9 | parameters = [parameters]
10 | parameters = list(filter(lambda p: p.grad is not None, parameters))
11 | norm_type = float(norm_type)
12 | total_norm = 0
13 | for p in parameters:
14 | param_norm = p.grad.data.norm(norm_type)
15 | total_norm += param_norm.item() ** norm_type
16 | total_norm = total_norm ** (1. / norm_type)
17 |
18 | return total_norm
19 |
20 | class MetaLearner(object):
21 | def __init__(self, model, embedding_model, optimizers, fast_lr, loss_func,
22 | first_order, num_updates, inner_loop_grad_clip,
23 | collect_accuracies, device, alternating=False,
24 | embedding_schedule=10, classifier_schedule=10,
25 | embedding_grad_clip=0):
26 | self._model = model
27 | self._embedding_model = embedding_model
28 | self._fast_lr = fast_lr
29 | self._optimizers = optimizers
30 | self._loss_func = loss_func
31 | self._first_order = first_order
32 | self._num_updates = num_updates
33 | self._inner_loop_grad_clip = inner_loop_grad_clip
34 | self._collect_accuracies = collect_accuracies
35 | self._device = device
36 | self._alternating = alternating
37 | self._alternating_schedules = (classifier_schedule, embedding_schedule)
38 | self._alternating_count = 0
39 | self._alternating_index = 1
40 | self._embedding_grad_clip = embedding_grad_clip
41 | self._grads_mean = []
42 |
43 | self.to(device)
44 |
45 | self._reset_measurements()
46 |
47 | def _reset_measurements(self):
48 | self._count_iters = 0.0
49 | self._cum_loss = 0.0
50 | self._cum_accuracy = 0.0
51 |
52 | def _update_measurements(self, task, loss, preds):
53 | self._count_iters += 1.0
54 | self._cum_loss += loss.data.cpu().numpy()
55 | if self._collect_accuracies:
56 | self._cum_accuracy += accuracy(
57 | preds, task.y).data.cpu().numpy()
58 |
59 | def _pop_measurements(self):
60 | measurements = {}
61 | loss = self._cum_loss / self._count_iters
62 | measurements['loss'] = loss
63 | if self._collect_accuracies:
64 | accuracy = self._cum_accuracy / self._count_iters
65 | measurements['accuracy'] = accuracy
66 | self._reset_measurements()
67 | return measurements
68 |
69 | def measure(self, tasks, train_tasks=None, adapted_params_list=None,
70 | embeddings_list=None):
71 | """Measures performance on tasks. Either train_tasks has to be a list
72 |         of training tasks for computing embeddings, or adapted_params_list and
73 | embeddings_list have to contain adapted_params and embeddings"""
74 | if adapted_params_list is None:
75 | adapted_params_list = [None] * len(tasks)
76 | if embeddings_list is None:
77 | embeddings_list = [None] * len(tasks)
78 | for i in range(len(tasks)):
79 | params = adapted_params_list[i]
80 | if params is None:
81 | params = self._model.param_dict
82 | embeddings = embeddings_list[i]
83 | task = tasks[i]
84 | preds = self._model(task, params=params, embeddings=embeddings)
85 | loss = self._loss_func(preds, task.y)
86 | self._update_measurements(task, loss, preds)
87 |
88 | measurements = self._pop_measurements()
89 | return measurements
90 |
91 |     def measure_each(self, tasks, train_tasks=None, adapted_params_list=None,
92 |                      embeddings_list=None):
93 |         """Measures performance on tasks and returns a list of per-task
94 |         accuracies. Either train_tasks has to be a list of training tasks
95 |         for computing embeddings, or adapted_params_list and
96 |         embeddings_list have to contain adapted_params and embeddings."""
97 | if adapted_params_list is None:
98 | adapted_params_list = [None] * len(tasks)
99 | if embeddings_list is None:
100 | embeddings_list = [None] * len(tasks)
101 | accuracies = []
102 | for i in range(len(tasks)):
103 | params = adapted_params_list[i]
104 | if params is None:
105 | params = self._model.param_dict
106 | embeddings = embeddings_list[i]
107 | task = tasks[i]
108 | preds = self._model(task, params=params, embeddings=embeddings)
109 |             # fraction of correctly predicted query samples for this task
110 |             pred_y = np.argmax(preds.data.cpu().numpy(), axis=-1)
111 |             accuracy = np.mean(
112 |                 task.y.data.cpu().numpy() == pred_y)
113 | accuracies.append(accuracy)
114 |
115 | return accuracies
116 |
117 | def update_params(self, loss, params):
118 | """Apply one step of gradient descent on the loss function `loss`,
119 | with step-size `self._fast_lr`, and returns the updated parameters.
120 | """
121 | create_graph = not self._first_order
122 | grads = torch.autograd.grad(loss, params.values(),
123 | create_graph=create_graph, allow_unused=True)
124 | for (name, param), grad in zip(params.items(), grads):
125 | if self._inner_loop_grad_clip > 0 and grad is not None:
126 | grad = grad.clamp(min=-self._inner_loop_grad_clip,
127 | max=self._inner_loop_grad_clip)
128 | if grad is not None:
129 | params[name] = param - self._fast_lr * grad
130 |
131 | return params
132 |
133 | def adapt(self, train_tasks):
134 | adapted_params = []
135 | embeddings_list = []
136 |
137 | for task in train_tasks:
138 | params = self._model.param_dict
139 | embeddings = None
140 | if self._embedding_model:
141 | embeddings = self._embedding_model(task)
142 | for i in range(self._num_updates):
143 | preds = self._model(task, params=params, embeddings=embeddings)
144 | loss = self._loss_func(preds, task.y)
145 | params = self.update_params(loss, params=params)
146 | if i == 0:
147 | self._update_measurements(task, loss, preds)
148 | adapted_params.append(params)
149 | embeddings_list.append(embeddings)
150 |
151 | measurements = self._pop_measurements()
152 | return measurements, adapted_params, embeddings_list
153 |
154 | def step(self, adapted_params_list, embeddings_list, val_tasks,
155 | is_training):
156 | for optimizer in self._optimizers:
157 | optimizer.zero_grad()
158 | post_update_losses = []
159 |
160 | for adapted_params, embeddings, task in zip(
161 | adapted_params_list, embeddings_list, val_tasks):
162 | preds = self._model(task, params=adapted_params,
163 | embeddings=embeddings)
164 | loss = self._loss_func(preds, task.y)
165 | post_update_losses.append(loss)
166 | self._update_measurements(task, loss, preds)
167 |
168 | mean_loss = torch.mean(torch.stack(post_update_losses))
169 | if is_training:
170 | mean_loss.backward()
171 | if self._alternating:
172 | self._optimizers[self._alternating_index].step()
173 | self._alternating_count += 1
174 | if self._alternating_count % self._alternating_schedules[self._alternating_index] == 0:
175 | self._alternating_index = (1 - self._alternating_index)
176 | self._alternating_count = 0
177 | else:
178 | self._optimizers[0].step()
179 | if len(self._optimizers) > 1:
180 | if self._embedding_grad_clip > 0:
181 | _grad_norm = clip_grad_norm_(self._embedding_model.parameters(), self._embedding_grad_clip)
182 | else:
183 | _grad_norm = get_grad_norm(self._embedding_model.parameters())
184 | # grad_norm
185 | self._grads_mean.append(_grad_norm)
186 | self._optimizers[1].step()
187 |
188 | measurements = self._pop_measurements()
189 | return measurements
190 |
191 | def to(self, device, **kwargs):
192 | self._device = device
193 | self._model.to(device, **kwargs)
194 | if self._embedding_model:
195 | self._embedding_model.to(device, **kwargs)
196 |
197 | def state_dict(self):
198 | state = {
199 | 'model_state_dict': self._model.state_dict(),
200 | 'optimizers': [ optimizer.state_dict() for optimizer in self._optimizers ]
201 | }
202 | if self._embedding_model:
203 | state.update(
204 | {'embedding_model_state_dict':
205 | self._embedding_model.state_dict()})
206 | return state
207 |
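Note: update_params above is a functional gradient step: it differentiates the loss with respect to a dict of parameters and builds updated parameters out-of-place, keeping the graph when second-order MAML is enabled. A toy sketch of the same pattern on a single parameter (illustrative, not the repository's API):

    import torch

    fast_lr, clip = 0.05, 20.0
    w = torch.tensor([2.0, -1.0], requires_grad=True)
    params = {'w': w}

    loss = ((params['w'] - 1.0) ** 2).sum()
    grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
    for (name, p), g in zip(params.items(), grads):
        g = g.clamp(min=-clip, max=clip)                   # inner-loop gradient clipping
        params[name] = p - fast_lr * g                     # new tensor; graph is retained

    # params['w'] is a function of w, so an outer-loop loss evaluated with
    # params['w'] can backpropagate through this inner step (second-order MAML).
    print(params['w'])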
--------------------------------------------------------------------------------
/maml/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/maml/models/__init__.py
--------------------------------------------------------------------------------
/maml/models/conv_embedding_model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | class ConvEmbeddingModel(torch.nn.Module):
8 | def __init__(self, input_size, output_size, embedding_dims,
9 | hidden_size=128, num_layers=1,
10 | convolutional=False, num_conv=4, num_channels=32, num_channels_max=256,
11 | rnn_aggregation=False, linear_before_rnn=False,
12 | embedding_pooling='max', batch_norm=True, avgpool_after_conv=True,
13 | num_sample_embedding=0, sample_embedding_file='embedding.hdf5',
14 | img_size=(1, 28, 28), verbose=False):
15 |
16 | super(ConvEmbeddingModel, self).__init__()
17 | self._input_size = input_size
18 | self._output_size = output_size
19 | self._hidden_size = hidden_size
20 | self._num_layers = num_layers
21 | self._embedding_dims = embedding_dims
22 | self._bidirectional = True
23 | self._device = 'cpu'
24 | self._convolutional = convolutional
25 | self._num_conv = num_conv
26 | self._num_channels = num_channels
27 | self._num_channels_max = num_channels_max
28 | self._batch_norm = batch_norm
29 | self._img_size = img_size
30 | self._rnn_aggregation = rnn_aggregation
31 | self._embedding_pooling = embedding_pooling
32 | self._linear_before_rnn = linear_before_rnn
33 | self._embeddings_array = []
34 | self._num_sample_embedding = num_sample_embedding
35 | self._sample_embedding_file = sample_embedding_file
36 | self._avgpool_after_conv = avgpool_after_conv
37 | self._reuse = False
38 | self._verbose = verbose
39 |
40 | if self._convolutional:
41 | conv_list = OrderedDict([])
42 | num_ch = [self._img_size[0]] + [self._num_channels*2**i for i in range(self._num_conv)]
43 | num_ch = [min(num_channels_max, ch) for ch in num_ch]
44 | for i in range(self._num_conv):
45 | conv_list.update({
46 | 'conv{}'.format(i+1):
47 | torch.nn.Conv2d(num_ch[i], num_ch[i+1],
48 | (3, 3), stride=2, padding=1)})
49 | if self._batch_norm:
50 | conv_list.update({
51 | 'bn{}'.format(i+1):
52 | torch.nn.BatchNorm2d(num_ch[i+1], momentum=0.001)})
53 | conv_list.update({'relu{}'.format(i+1): torch.nn.ReLU(inplace=True)})
54 | self.conv = torch.nn.Sequential(conv_list)
55 | self._num_layer_per_conv = len(conv_list) // self._num_conv
56 |
57 | if self._linear_before_rnn:
58 | linear_input_size = self.compute_input_size(
59 | 1, 3, 2, self.conv[self._num_layer_per_conv*(self._num_conv-1)].out_channels)
60 | rnn_input_size = 128
61 | else:
62 | if self._avgpool_after_conv:
63 | rnn_input_size = self.conv[self._num_layer_per_conv*(self._num_conv-1)].out_channels
64 | else:
65 | rnn_input_size = self.compute_input_size(
66 | 1, 3, 2, self.conv[self._num_layer_per_conv*(self._num_conv-1)].out_channels)
67 | else:
68 | rnn_input_size = int(input_size)
69 |
70 | if self._rnn_aggregation:
71 | if self._linear_before_rnn:
72 | self.linear = torch.nn.Linear(linear_input_size, rnn_input_size)
73 | self.relu_after_linear = torch.nn.ReLU(inplace=True)
74 | self.rnn = torch.nn.GRU(rnn_input_size, hidden_size,
75 | num_layers, bidirectional=self._bidirectional)
76 | embedding_input_size = hidden_size*(2 if self._bidirectional else 1)
77 | else:
78 | self.rnn = None
79 | embedding_input_size = hidden_size
80 | self.linear = torch.nn.Linear(rnn_input_size, embedding_input_size)
81 | self.relu_after_linear = torch.nn.ReLU(inplace=True)
82 |
83 | self._embeddings = torch.nn.ModuleList()
84 | for dim in embedding_dims:
85 | self._embeddings.append(torch.nn.Linear(embedding_input_size, dim))
86 |
87 | def compute_input_size(self, p, k, s, ch):
88 | current_img_size = self._img_size[1]
89 | for _ in range(self._num_conv):
90 | current_img_size = (current_img_size+2*p-k)//s+1
91 | return ch * int(current_img_size) ** 2
92 |
93 | def forward(self, task, params=None):
94 | if not self._reuse and self._verbose: print('='*8 + ' Emb Model ' + '='*8)
95 | if params is None:
96 | params = OrderedDict(self.named_parameters())
97 |
98 | if self._convolutional:
99 | x = task.x
100 | if not self._reuse and self._verbose: print('input size: {}'.format(x.size()))
101 | for layer_name, layer in self.conv.named_children():
102 | weight = params.get('conv.' + layer_name + '.weight', None)
103 | bias = params.get('conv.' + layer_name + '.bias', None)
104 | if 'conv' in layer_name:
105 | x = F.conv2d(x, weight=weight, bias=bias, stride=2, padding=1)
106 | elif 'relu' in layer_name:
107 | x = F.relu(x)
108 | elif 'bn' in layer_name:
109 | x = F.batch_norm(x, weight=weight, bias=bias,
110 | running_mean=layer.running_mean,
111 | running_var=layer.running_var,
112 | training=True)
113 | if not self._reuse and self._verbose: print('{}: {}'.format(layer_name, x.size()))
114 | if self._avgpool_after_conv:
115 | x = x.view(x.size(0), x.size(1), -1)
116 | if not self._reuse and self._verbose: print('reshape to: {}'.format(x.size()))
117 | x = torch.mean(x, dim=2)
118 | if not self._reuse and self._verbose: print('reduce mean: {}'.format(x.size()))
119 |
120 | else:
121 | x = task.x.view(task.x.size(0), -1)
122 | if not self._reuse and self._verbose: print('flatten: {}'.format(x.size()))
123 | else:
124 | x = task.x.view(task.x.size(0), -1)
125 | if not self._reuse and self._verbose: print('flatten: {}'.format(x.size()))
126 |
127 | if self._rnn_aggregation:
128 |             # GRU input dimensions are seq_len, batch, input_size
129 | batch_size = 1
130 | h0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1),
131 | batch_size, self._hidden_size, device=self._device)
132 | if self._linear_before_rnn:
133 | x = F.relu(self.linear(x))
134 | inputs = x.view(x.size(0), 1, -1)
135 | output, hn = self.rnn(inputs, h0)
136 | if self._bidirectional:
137 | N, B, H = output.shape
138 | output = output.view(N, B, 2, H // 2)
139 | embedding_input = torch.cat([output[-1, :, 0], output[0, :, 1]], dim=1)
140 |
141 | else:
142 | inputs = F.relu(self.linear(x).view(1, x.size(0), -1).transpose(1, 2))
143 | if not self._reuse and self._verbose: print('fc: {}'.format(inputs.size()))
144 | if self._embedding_pooling == 'max':
145 | embedding_input = F.max_pool1d(inputs, x.size(0)).view(1, -1)
146 | elif self._embedding_pooling == 'avg':
147 | embedding_input = F.avg_pool1d(inputs, x.size(0)).view(1, -1)
148 | else:
149 | raise NotImplementedError
150 | if not self._reuse and self._verbose: print('reshape after {}pool: {}'.format(
151 | self._embedding_pooling, embedding_input.size()))
152 |
153 | # randomly sample embedding vectors
154 |         if self._num_sample_embedding != 0:
155 | self._embeddings_array.append(embedding_input.cpu().clone().detach().numpy())
156 | if len(self._embeddings_array) >= self._num_sample_embedding:
157 | if self._sample_embedding_file.split('.')[-1] == 'hdf5':
158 | import h5py
159 | f = h5py.File(self._sample_embedding_file, 'w')
160 | f['embedding'] = np.squeeze(np.stack(self._embeddings_array))
161 | f.close()
162 | elif self._sample_embedding_file.split('.')[-1] == 'pt':
163 | torch.save(np.squeeze(np.stack(self._embeddings_array)),
164 | self._sample_embedding_file)
165 | else:
166 | raise NotImplementedError
167 |
168 | out_embeddings = []
169 | for i, embedding in enumerate(self._embeddings):
170 | embedding_vec = embedding(embedding_input)
171 | out_embeddings.append(embedding_vec)
172 | if not self._reuse and self._verbose: print('emb vec {} size: {}'.format(
173 | i+1, embedding_vec.size()))
174 | if not self._reuse and self._verbose: print('='*27)
175 | self._reuse = True
176 | return out_embeddings
177 |
178 | def to(self, device, **kwargs):
179 | self._device = device
180 |         return super(ConvEmbeddingModel, self).to(device, **kwargs)
181 |
--------------------------------------------------------------------------------
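A quick numeric check of `compute_input_size` above, as a standalone sketch; the 84x84 input side and 64 channels are assumed example values, not settings read from this file:

    # Reproduces the stride-2 conv arithmetic in ConvEmbeddingModel.compute_input_size.
    def conv_output_side(side, num_conv, p=1, k=3, s=2):
        for _ in range(num_conv):
            side = (side + 2 * p - k) // s + 1  # output side of one stride-2 conv
        return side

    side = conv_output_side(84, num_conv=4)  # 84 -> 42 -> 21 -> 11 -> 6
    print(side, 64 * side ** 2)              # 6 2304 flattened features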
/maml/models/conv_net.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from maml.models.model import Model
7 |
8 |
9 | def weight_init(module):
10 | if (isinstance(module, torch.nn.Linear)
11 | or isinstance(module, torch.nn.Conv2d)):
12 | torch.nn.init.xavier_uniform_(module.weight)
13 | module.bias.data.zero_()
14 |
15 |
16 | class ConvModel(Model):
17 | """
18 | NOTE: difference to tf implementation: batch norm scaling is enabled here
19 | TODO: enable 'non-transductive' setting as per
20 | https://arxiv.org/abs/1803.02999
21 | """
22 | def __init__(self, input_channels, output_size, num_channels=64,
23 | kernel_size=3, padding=1, nonlinearity=F.relu,
24 | use_max_pool=False, img_side_len=28, verbose=False):
25 | super(ConvModel, self).__init__()
26 | self._input_channels = input_channels
27 | self._output_size = output_size
28 | self._num_channels = num_channels
29 | self._kernel_size = kernel_size
30 | self._nonlinearity = nonlinearity
31 | self._use_max_pool = use_max_pool
32 | self._padding = padding
33 | self._bn_affine = False
34 | self._reuse = False
35 | self._verbose = verbose
36 |
37 | if self._use_max_pool:
38 | self._conv_stride = 1
39 | self._features_size = 1
40 | self.features = torch.nn.Sequential(OrderedDict([
41 | ('layer1_conv', torch.nn.Conv2d(self._input_channels,
42 | self._num_channels,
43 | self._kernel_size,
44 | stride=self._conv_stride,
45 | padding=self._padding)),
46 | ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels,
47 | affine=self._bn_affine,
48 | momentum=0.001)),
49 | ('layer1_max_pool', torch.nn.MaxPool2d(kernel_size=2,
50 | stride=2)),
51 | ('layer1_relu', torch.nn.ReLU(inplace=True)),
52 | ('layer2_conv', torch.nn.Conv2d(self._num_channels,
53 | self._num_channels*2,
54 | self._kernel_size,
55 | stride=self._conv_stride,
56 | padding=self._padding)),
57 | ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2,
58 | affine=self._bn_affine,
59 | momentum=0.001)),
60 | ('layer2_max_pool', torch.nn.MaxPool2d(kernel_size=2,
61 | stride=2)),
62 | ('layer2_relu', torch.nn.ReLU(inplace=True)),
63 | ('layer3_conv', torch.nn.Conv2d(self._num_channels*2,
64 | self._num_channels*4,
65 | self._kernel_size,
66 | stride=self._conv_stride,
67 | padding=self._padding)),
68 | ('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4,
69 | affine=self._bn_affine,
70 | momentum=0.001)),
71 | ('layer3_max_pool', torch.nn.MaxPool2d(kernel_size=2,
72 | stride=2)),
73 | ('layer3_relu', torch.nn.ReLU(inplace=True)),
74 | ('layer4_conv', torch.nn.Conv2d(self._num_channels*4,
75 | self._num_channels*8,
76 | self._kernel_size,
77 | stride=self._conv_stride,
78 | padding=self._padding)),
79 | ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8,
80 | affine=self._bn_affine,
81 | momentum=0.001)),
82 | ('layer4_max_pool', torch.nn.MaxPool2d(kernel_size=2,
83 | stride=2)),
84 | ('layer4_relu', torch.nn.ReLU(inplace=True)),
85 | ]))
86 | else:
87 | self._conv_stride = 2
88 | self._features_size = (img_side_len // 14)**2
89 | self.features = torch.nn.Sequential(OrderedDict([
90 | ('layer1_conv', torch.nn.Conv2d(self._input_channels,
91 | self._num_channels,
92 | self._kernel_size,
93 | stride=self._conv_stride,
94 | padding=self._padding)),
95 | ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels,
96 | affine=self._bn_affine,
97 | momentum=0.001)),
98 | ('layer1_relu', torch.nn.ReLU(inplace=True)),
99 | ('layer2_conv', torch.nn.Conv2d(self._num_channels,
100 | self._num_channels*2,
101 | self._kernel_size,
102 | stride=self._conv_stride,
103 | padding=self._padding)),
104 | ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2,
105 | affine=self._bn_affine,
106 | momentum=0.001)),
107 | ('layer2_relu', torch.nn.ReLU(inplace=True)),
108 | ('layer3_conv', torch.nn.Conv2d(self._num_channels*2,
109 | self._num_channels*4,
110 | self._kernel_size,
111 | stride=self._conv_stride,
112 | padding=self._padding)),
113 | ('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4,
114 | affine=self._bn_affine,
115 | momentum=0.001)),
116 | ('layer3_relu', torch.nn.ReLU(inplace=True)),
117 | ('layer4_conv', torch.nn.Conv2d(self._num_channels*4,
118 | self._num_channels*8,
119 | self._kernel_size,
120 | stride=self._conv_stride,
121 | padding=self._padding)),
122 | ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8,
123 | affine=self._bn_affine,
124 | momentum=0.001)),
125 | ('layer4_relu', torch.nn.ReLU(inplace=True)),
126 | ]))
127 |
128 | self.classifier = torch.nn.Sequential(OrderedDict([
129 | ('fully_connected', torch.nn.Linear(self._num_channels*8,
130 | self._output_size))
131 | ]))
132 | self.apply(weight_init)
133 |
134 | def forward(self, task, params=None, embeddings=None):
135 | if not self._reuse and self._verbose: print('='*10 + ' Model ' + '='*10)
136 | if params is None:
137 | params = OrderedDict(self.named_parameters())
138 |
139 | x = task.x
140 | if not self._reuse and self._verbose: print('input size: {}'.format(x.size()))
141 | for layer_name, layer in self.features.named_children():
142 | weight = params.get('features.' + layer_name + '.weight', None)
143 | bias = params.get('features.' + layer_name + '.bias', None)
144 | if 'conv' in layer_name:
145 | x = F.conv2d(x, weight=weight, bias=bias,
146 | stride=self._conv_stride, padding=self._padding)
147 | elif 'bn' in layer_name:
148 | x = F.batch_norm(x, weight=weight, bias=bias,
149 | running_mean=layer.running_mean,
150 | running_var=layer.running_var,
151 | training=True)
152 | elif 'max_pool' in layer_name:
153 | x = F.max_pool2d(x, kernel_size=2, stride=2)
154 | elif 'relu' in layer_name:
155 | x = F.relu(x)
156 | elif 'fully_connected' in layer_name:
157 | break
158 | else:
159 | raise ValueError('Unrecognized layer {}'.format(layer_name))
160 | if not self._reuse and self._verbose: print('{}: {}'.format(layer_name, x.size()))
161 |
162 |         # in the MAML network the conv feature maps are average pooled
163 | x = x.view(x.size(0), self._num_channels*8, self._features_size)
164 | if not self._reuse and self._verbose: print('reshape to: {}'.format(x.size()))
165 | x = torch.mean(x, dim=2)
166 | if not self._reuse and self._verbose: print('reduce mean: {}'.format(x.size()))
167 | logits = F.linear(
168 | x, weight=params['classifier.fully_connected.weight'],
169 | bias=params['classifier.fully_connected.bias'])
170 | if not self._reuse and self._verbose: print('logits size: {}'.format(logits.size()))
171 | if not self._reuse and self._verbose: print('='*27)
172 | self._reuse = True
173 | return logits
174 |
--------------------------------------------------------------------------------
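The functional `forward` above, which takes an explicit `params` dict instead of using the module's own weights, is what lets MAML's inner loop re-run the same network with adapted fast weights while keeping the adaptation step differentiable. A minimal sketch of that pattern, assuming a toy linear model and an illustrative inner-loop learning rate of 0.01 (neither taken from this repo):

    from collections import OrderedDict

    import torch
    import torch.nn.functional as F

    def inner_update(loss, params, lr=0.01):
        # create_graph=True keeps the update inside the autograd graph,
        # so the outer meta-gradient can flow through the adaptation step
        grads = torch.autograd.grad(loss, params.values(), create_graph=True)
        return OrderedDict((name, p - lr * g)
                           for (name, p), g in zip(params.items(), grads))

    lin = torch.nn.Linear(3, 1)
    params = OrderedDict(lin.named_parameters())
    x, y = torch.randn(4, 3), torch.randn(4, 1)
    loss = F.mse_loss(F.linear(x, params['weight'], params['bias']), y)
    fast_params = inner_update(loss, params)  # feed back in as forward(task, params=fast_params)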
/maml/models/fully_connected.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from maml.models.model import Model
7 |
8 | def weight_init(module):
9 | if isinstance(module, torch.nn.Linear):
10 | torch.nn.init.normal_(module.weight, mean=0, std=0.01)
11 | module.bias.data.zero_()
12 |
13 |
14 | class FullyConnectedModel(Model):
15 | def __init__(self, input_size, output_size, hidden_sizes=(),
16 | nonlinearity=F.relu, disable_norm=False,
17 | bias_transformation_size=0):
18 | super(FullyConnectedModel, self).__init__()
19 | self.hidden_sizes = hidden_sizes
20 | self.nonlinearity = nonlinearity
21 | self.num_layers = len(hidden_sizes) + 1
22 | self.disable_norm = disable_norm
23 | self.bias_transformation_size = bias_transformation_size
24 |
25 | if bias_transformation_size > 0:
26 | input_size = input_size + bias_transformation_size
27 | self.bias_transformation = torch.nn.Parameter(
28 | torch.zeros(bias_transformation_size))
29 |
30 |         layer_sizes = [input_size] + list(hidden_sizes) + [output_size]
31 | for i in range(1, self.num_layers):
32 | self.add_module(
33 | 'layer{0}_linear'.format(i),
34 | torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
35 | if not self.disable_norm:
36 | self.add_module(
37 | 'layer{0}_bn'.format(i),
38 | torch.nn.BatchNorm1d(layer_sizes[i], momentum=0.001))
39 | self.add_module(
40 | 'output_linear',
41 | torch.nn.Linear(layer_sizes[self.num_layers - 1],
42 | layer_sizes[self.num_layers]))
43 | self.apply(weight_init)
44 |
45 | def forward(self, task, params=None, training=True, embeddings=None):
46 | if params is None:
47 | params = OrderedDict(self.named_parameters())
48 | x = task.x.view(task.x.size(0), -1)
49 |
50 | if self.bias_transformation_size > 0:
51 | x = torch.cat((x, params['bias_transformation'].expand(
52 | x.size(0), params['bias_transformation'].size(0))), dim=1)
53 |
54 | for key, module in self.named_modules():
55 | if 'linear' in key:
56 | x = F.linear(x, weight=params[key + '.weight'],
57 | bias=params[key + '.bias'])
58 | if self.disable_norm and 'output' not in key:
59 | x = self.nonlinearity(x)
60 | if 'bn' in key:
61 | x = F.batch_norm(x, weight=params[key + '.weight'],
62 | bias=params[key + '.bias'],
63 | running_mean=module.running_mean,
64 | running_var=module.running_var,
65 | training=training)
66 | x = self.nonlinearity(x)
67 | return x
68 |
69 | class MultiFullyConnectedModel(Model):
70 | def __init__(self, input_size, output_size, hidden_sizes=(),
71 | nonlinearity=F.relu, disable_norm=False, num_tasks=1,
72 | bias_transformation_size=0):
73 | super(MultiFullyConnectedModel, self).__init__()
74 | self.hidden_sizes = hidden_sizes
75 | self.nonlinearity = nonlinearity
76 | self.num_layers = len(hidden_sizes) + 1
77 | self.disable_norm = disable_norm
78 | self.bias_transformation_size = bias_transformation_size
79 | self.num_tasks = num_tasks
80 |
81 | if bias_transformation_size > 0:
82 | input_size = input_size + bias_transformation_size
83 | self.bias_transformation = torch.nn.Parameter(
84 | torch.zeros(bias_transformation_size))
85 |
86 |         layer_sizes = [input_size] + list(hidden_sizes) + [output_size]
87 | for j in range(0, self.num_tasks):
88 | for i in range(1, self.num_layers):
89 | self.add_module(
90 | 'task{0}_layer{1}_linear'.format(j, i),
91 | torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
92 | if not self.disable_norm:
93 | self.add_module(
94 | 'task{0}_layer{1}_bn'.format(j, i),
95 | torch.nn.BatchNorm1d(layer_sizes[i], momentum=0.001))
96 | self.add_module(
97 | 'task{0}_output_linear'.format(j),
98 | torch.nn.Linear(layer_sizes[self.num_layers - 1],
99 | layer_sizes[self.num_layers]))
100 | self.apply(weight_init)
101 |
102 | def forward(self, task, params=None, training=True, embeddings=None):
103 | if params is None:
104 | params = OrderedDict(self.named_parameters())
105 | x = task.x.view(task.x.size(0), -1)
106 | task_id = task.task_info['task_id']
107 |
108 | if self.bias_transformation_size > 0:
109 | x = torch.cat((x, params['bias_transformation'].expand(
110 | x.size(0), params['bias_transformation'].size(0))), dim=1)
111 |
112 | for key, module in self.named_modules():
113 | if 'task{0}'.format(task_id) in key:
114 | if 'linear' in key:
115 | x = F.linear(x, weight=params[key + '.weight'],
116 | bias=params[key + '.bias'])
117 | if self.disable_norm and 'output' not in key:
118 | x = self.nonlinearity(x)
119 | if 'bn' in key:
120 | x = F.batch_norm(x, weight=params[key + '.weight'],
121 | bias=params[key + '.bias'],
122 | running_mean=module.running_mean,
123 | running_var=module.running_var,
124 | training=training)
125 | x = self.nonlinearity(x)
126 | return x
127 |
--------------------------------------------------------------------------------
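The `bias_transformation` parameter above appends a learned vector to every input, giving the inner loop extra degrees of freedom that behave like a per-task bias. A minimal sketch with assumed sizes (batch 5, input size 8, transformation size 4):

    import torch

    x = torch.randn(5, 8)                    # batch of 5, input_size 8
    bias_transformation = torch.zeros(4, requires_grad=True)
    x_aug = torch.cat((x, bias_transformation.expand(x.size(0), 4)), dim=1)
    print(x_aug.shape)                       # torch.Size([5, 12])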
/maml/models/gated_conv_net.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from maml.models.model import Model
7 |
8 |
9 | def weight_init(module):
10 | if (isinstance(module, torch.nn.Linear)
11 | or isinstance(module, torch.nn.Conv2d)):
12 | torch.nn.init.xavier_uniform_(module.weight)
13 | module.bias.data.zero_()
14 |
15 |
16 | class GatedConvModel(Model):
17 | """
18 | NOTE: difference to tf implementation: batch norm scaling is enabled here
19 | TODO: enable 'non-transductive' setting as per
20 | https://arxiv.org/abs/1803.02999
21 | """
22 | def __init__(self, input_channels, output_size, num_channels=64,
23 | kernel_size=3, padding=1, nonlinearity=F.relu,
24 | use_max_pool=False, img_side_len=28,
25 | condition_type='affine', condition_order='low2high', verbose=False):
26 | super(GatedConvModel, self).__init__()
27 | self._input_channels = input_channels
28 | self._output_size = output_size
29 | self._num_channels = num_channels
30 | self._kernel_size = kernel_size
31 | self._nonlinearity = nonlinearity
32 | self._use_max_pool = use_max_pool
33 | self._padding = padding
34 | self._condition_type = condition_type
35 | self._condition_order = condition_order
36 | self._bn_affine = False
37 | self._reuse = False
38 | self._verbose = verbose
39 |
40 | if self._use_max_pool:
41 | self._conv_stride = 1
42 | self._features_size = 1
43 | self.features = torch.nn.Sequential(OrderedDict([
44 | ('layer1_conv', torch.nn.Conv2d(self._input_channels,
45 | self._num_channels,
46 | self._kernel_size,
47 | stride=self._conv_stride,
48 | padding=self._padding)),
49 | ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels,
50 | affine=self._bn_affine,
51 | momentum=0.001)),
52 |                 ('layer1_condition', torch.nn.ReLU(inplace=True)),
53 | ('layer1_max_pool', torch.nn.MaxPool2d(kernel_size=2,
54 | stride=2)),
55 | ('layer1_relu', torch.nn.ReLU(inplace=True)),
56 | ('layer2_conv', torch.nn.Conv2d(self._num_channels,
57 | self._num_channels*2,
58 | self._kernel_size,
59 | stride=self._conv_stride,
60 | padding=self._padding)),
61 | ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2,
62 | affine=self._bn_affine,
63 | momentum=0.001)),
64 |                 ('layer2_condition', torch.nn.ReLU(inplace=True)),
65 | ('layer2_max_pool', torch.nn.MaxPool2d(kernel_size=2,
66 | stride=2)),
67 | ('layer2_relu', torch.nn.ReLU(inplace=True)),
68 | ('layer3_conv', torch.nn.Conv2d(self._num_channels*2,
69 | self._num_channels*4,
70 | self._kernel_size,
71 | stride=self._conv_stride,
72 | padding=self._padding)),
73 | ('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4,
74 | affine=self._bn_affine,
75 | momentum=0.001)),
76 |                 ('layer3_condition', torch.nn.ReLU(inplace=True)),
77 | ('layer3_max_pool', torch.nn.MaxPool2d(kernel_size=2,
78 | stride=2)),
79 | ('layer3_relu', torch.nn.ReLU(inplace=True)),
80 | ('layer4_conv', torch.nn.Conv2d(self._num_channels*4,
81 | self._num_channels*8,
82 | self._kernel_size,
83 | stride=self._conv_stride,
84 | padding=self._padding)),
85 | ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8,
86 | affine=self._bn_affine,
87 | momentum=0.001)),
88 |                 ('layer4_condition', torch.nn.ReLU(inplace=True)),
89 | ('layer4_max_pool', torch.nn.MaxPool2d(kernel_size=2,
90 | stride=2)),
91 | ('layer4_relu', torch.nn.ReLU(inplace=True)),
92 | ]))
93 | else:
94 | self._conv_stride = 2
95 | self._features_size = (img_side_len // 14)**2
96 | self.features = torch.nn.Sequential(OrderedDict([
97 | ('layer1_conv', torch.nn.Conv2d(self._input_channels,
98 | self._num_channels,
99 | self._kernel_size,
100 | stride=self._conv_stride,
101 | padding=self._padding)),
102 | ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels,
103 | affine=self._bn_affine,
104 | momentum=0.001)),
105 | ('layer1_condition', torch.nn.ReLU(inplace=True)),
106 | ('layer1_relu', torch.nn.ReLU(inplace=True)),
107 | ('layer2_conv', torch.nn.Conv2d(self._num_channels,
108 | self._num_channels*2,
109 | self._kernel_size,
110 | stride=self._conv_stride,
111 | padding=self._padding)),
112 | ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2,
113 | affine=self._bn_affine,
114 | momentum=0.001)),
115 | ('layer2_condition', torch.nn.ReLU(inplace=True)),
116 | ('layer2_relu', torch.nn.ReLU(inplace=True)),
117 | ('layer3_conv', torch.nn.Conv2d(self._num_channels*2,
118 | self._num_channels*4,
119 | self._kernel_size,
120 | stride=self._conv_stride,
121 | padding=self._padding)),
122 | ('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4,
123 | affine=self._bn_affine,
124 | momentum=0.001)),
125 | ('layer3_condition', torch.nn.ReLU(inplace=True)),
126 | ('layer3_relu', torch.nn.ReLU(inplace=True)),
127 | ('layer4_conv', torch.nn.Conv2d(self._num_channels*4,
128 | self._num_channels*8,
129 | self._kernel_size,
130 | stride=self._conv_stride,
131 | padding=self._padding)),
132 | ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8,
133 | affine=self._bn_affine,
134 | momentum=0.001)),
135 | ('layer4_condition', torch.nn.ReLU(inplace=True)),
136 | ('layer4_relu', torch.nn.ReLU(inplace=True)),
137 | ]))
138 |
139 | self.classifier = torch.nn.Sequential(OrderedDict([
140 | ('fully_connected', torch.nn.Linear(self._num_channels*8,
141 | self._output_size))
142 | ]))
143 | self.apply(weight_init)
144 |
145 | def conditional_layer(self, x, embedding):
146 | if self._condition_type == 'sigmoid_gate':
147 |             x = x * torch.sigmoid(embedding).expand_as(x)
148 | elif self._condition_type == 'affine':
149 | gammas, betas = torch.split(embedding, x.size(1), dim=-1)
150 | gammas = gammas.view(1, -1, 1, 1).expand_as(x)
151 | betas = betas.view(1, -1, 1, 1).expand_as(x)
152 | gammas = gammas + torch.ones_like(gammas)
153 | x = x * gammas + betas
154 | elif self._condition_type == 'softmax':
155 |             x = x * F.softmax(embedding, dim=-1).view(1, -1, 1, 1).expand_as(x)
156 | else:
157 | raise ValueError('Unrecognized conditional layer type {}'.format(
158 | self._condition_type))
159 | return x
160 |
161 | def forward(self, task, params=None, embeddings=None):
162 | if not self._reuse and self._verbose: print('='*10 + ' Model ' + '='*10)
163 | if params is None:
164 | params = OrderedDict(self.named_parameters())
165 |
166 | if embeddings is not None:
167 | embeddings = {'layer{}_condition'.format(i): embedding
168 | for i, embedding in enumerate(embeddings, start=1)}
169 |
170 | x = task.x
171 | if not self._reuse and self._verbose: print('input size: {}'.format(x.size()))
172 | for layer_name, layer in self.features.named_children():
173 | weight = params.get('features.' + layer_name + '.weight', None)
174 | bias = params.get('features.' + layer_name + '.bias', None)
175 | if 'conv' in layer_name:
176 | x = F.conv2d(x, weight=weight, bias=bias,
177 | stride=self._conv_stride, padding=self._padding)
178 | elif 'condition' in layer_name:
179 | x = self.conditional_layer(x, embeddings[layer_name]) if embeddings is not None else x
180 | elif 'bn' in layer_name:
181 | x = F.batch_norm(x, weight=weight, bias=bias,
182 | running_mean=layer.running_mean,
183 | running_var=layer.running_var,
184 | training=True)
185 | elif 'max_pool' in layer_name:
186 | x = F.max_pool2d(x, kernel_size=2, stride=2)
187 | elif 'relu' in layer_name:
188 | x = F.relu(x)
189 | elif 'fully_connected' in layer_name:
190 | break
191 | else:
192 | raise ValueError('Unrecognized layer {}'.format(layer_name))
193 | if not self._reuse and self._verbose: print('{}: {}'.format(layer_name, x.size()))
194 |
195 |         # in the MAML network the conv feature maps are average pooled
196 | x = x.view(x.size(0), self._num_channels*8, self._features_size)
197 | if not self._reuse and self._verbose: print('reshape to: {}'.format(x.size()))
198 | x = torch.mean(x, dim=2)
199 | if not self._reuse and self._verbose: print('reduce mean: {}'.format(x.size()))
200 | logits = F.linear(
201 | x, weight=params['classifier.fully_connected.weight'],
202 | bias=params['classifier.fully_connected.bias'])
203 | if not self._reuse and self._verbose: print('logits size: {}'.format(logits.size()))
204 | if not self._reuse and self._verbose: print('='*27)
205 | self._reuse = True
206 | return logits
207 |
--------------------------------------------------------------------------------
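The `'affine'` branch of `conditional_layer` is FiLM-style feature modulation: the task embedding is split into a per-channel scale (gamma, offset by one so it starts at identity) and shift (beta). A sketch with assumed shapes (16 channels, 7x7 feature maps):

    import torch

    x = torch.randn(5, 16, 7, 7)          # N x C x H x W feature map
    embedding = torch.randn(2 * 16)       # one gamma and one beta per channel
    gammas, betas = torch.split(embedding, x.size(1), dim=-1)
    gammas = gammas.view(1, -1, 1, 1).expand_as(x) + 1.0  # identity at init
    betas = betas.view(1, -1, 1, 1).expand_as(x)
    out = x * gammas + betas              # torch.Size([5, 16, 7, 7])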
/maml/models/gated_net.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from maml.models.model import Model
7 |
8 |
9 | def weight_init(module):
10 | if isinstance(module, torch.nn.Linear):
11 | torch.nn.init.normal_(module.weight, mean=0, std=0.01)
12 | module.bias.data.zero_()
13 |
14 |
15 | class GatedNet(Model):
16 | def __init__(self, input_size, output_size, hidden_sizes=[40, 40],
17 | nonlinearity=F.relu, condition_type='sigmoid_gate', condition_order='low2high'):
18 | super(GatedNet, self).__init__()
19 | self._nonlinearity = nonlinearity
20 | self._condition_type = condition_type
21 | self._condition_order = condition_order
22 |
23 | self.num_layers = len(hidden_sizes) + 1
24 |
25 | layer_sizes = [input_size] + hidden_sizes + [output_size]
26 | for i in range(1, self.num_layers):
27 | self.add_module(
28 | 'layer{0}_linear'.format(i),
29 | torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
30 | self.add_module(
31 | 'output_linear',
32 | torch.nn.Linear(layer_sizes[self.num_layers - 1],
33 | layer_sizes[self.num_layers]))
34 | self.apply(weight_init)
35 |
36 | def conditional_layer(self, x, embedding):
37 | if self._condition_type == 'sigmoid_gate':
38 |             x = x * torch.sigmoid(embedding).expand_as(x)
39 | elif self._condition_type == 'affine':
40 | gammas, betas = torch.split(embedding, x.size(1), dim=-1)
41 | gammas = gammas + torch.ones_like(gammas)
42 | x = x * gammas + betas
43 | elif self._condition_type == 'softmax':
44 |             x = x * F.softmax(embedding, dim=-1).expand_as(x)
45 | else:
46 | raise ValueError('Unrecognized conditional layer type {}'.format(
47 | self._condition_type))
48 | return x
49 |
50 | def forward(self, task, params=None, embeddings=None, training=True):
51 | if params is None:
52 | params = OrderedDict(self.named_parameters())
53 |
54 | if embeddings is not None:
55 | if self._condition_order == 'high2low': ## High2Low
56 | embeddings = {'layer{}_linear'.format(len(params)-i): embedding
57 | for i, embedding in enumerate(embeddings[::-1])}
58 | elif self._condition_order == 'low2high': ## Low2High
59 | embeddings = {'layer{}_linear'.format(i): embedding
60 | for i, embedding in enumerate(embeddings[::-1], start=1)}
61 | else:
62 |                 raise NotImplementedError('Unsupported order for using conditional layers')
63 | x = task.x.view(task.x.size(0), -1)
64 |
65 | for key, module in self.named_modules():
66 | if 'linear' in key:
67 | x = F.linear(x, weight=params[key + '.weight'],
68 | bias=params[key + '.bias'])
69 | if 'output' not in key and embeddings is not None: # conditioning and nonlinearity
70 |                     if key in embeddings:
71 | x = self.conditional_layer(x, embeddings[key])
72 |
73 | x = self._nonlinearity(x)
74 |
75 | return x
76 |
--------------------------------------------------------------------------------
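The `condition_order` logic above only decides which embedding conditions which linear layer. A small sketch of the `'low2high'` mapping for three embeddings (the string names are illustrative stand-ins for embedding tensors):

    embeddings = ['emb_a', 'emb_b', 'emb_c']
    low2high = {'layer{}_linear'.format(i): e
                for i, e in enumerate(embeddings[::-1], start=1)}
    # {'layer1_linear': 'emb_c', 'layer2_linear': 'emb_b', 'layer3_linear': 'emb_a'}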
/maml/models/gru_embedding_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class GRUEmbeddingModel(torch.nn.Module):
4 | def __init__(self, input_size, output_size, embedding_dims,
5 | hidden_size=40, num_layers=2):
6 | super(GRUEmbeddingModel, self).__init__()
7 | self._input_size = input_size
8 | self._output_size = output_size
9 | self._hidden_size = hidden_size
10 | self._num_layers = num_layers
11 | self._embedding_dims = embedding_dims
12 | self._bidirectional = True
13 | self._device = 'cpu'
14 |
15 | rnn_input_size = int(input_size + output_size)
16 | self.rnn = torch.nn.GRU(rnn_input_size, hidden_size, num_layers, bidirectional=self._bidirectional)
17 |
18 | self._embeddings = torch.nn.ModuleList()
19 | for dim in embedding_dims:
20 | self._embeddings.append(torch.nn.Linear(hidden_size*(2 if self._bidirectional else 1), dim))
21 |
22 | def forward(self, task):
23 | batch_size = 1
24 | h0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1),
25 | batch_size, self._hidden_size, device=self._device)
26 |
27 | x = task.x.view(task.x.size(0), -1)
28 | y = task.y.view(task.y.size(0), -1)
29 |
30 |         # GRU input dimensions are seq_len, batch, input_size
31 | inputs = torch.cat((x, y), dim=1).view(x.size(0), 1, -1)
32 | output, _ = self.rnn(inputs, h0)
33 | if self._bidirectional:
34 | N, B, H = output.shape
35 | output = output.view(N, B, 2, H // 2)
36 | embedding_input = torch.cat([output[-1, :, 0], output[0, :, 1]], dim=1)
37 |
38 | out_embeddings = []
39 | for embedding in self._embeddings:
40 | out_embeddings.append(embedding(embedding_input))
41 | return out_embeddings
42 |
43 | def to(self, device, **kwargs):
44 | self._device = device
45 |         return super(GRUEmbeddingModel, self).to(device, **kwargs)
46 |
--------------------------------------------------------------------------------
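The readout above summarizes a bidirectional RNN by taking the forward direction at its last time step and the backward direction at its first, i.e. the two states that have each seen the whole task. A sketch with tiny assumed sizes:

    import torch

    rnn = torch.nn.GRU(input_size=4, hidden_size=3, num_layers=2, bidirectional=True)
    inputs = torch.randn(6, 1, 4)          # seq_len=6, batch=1
    h0 = torch.zeros(2 * 2, 1, 3)          # num_layers * num_directions
    output, _ = rnn(inputs, h0)            # (6, 1, 6): both directions stacked
    N, B, H = output.shape
    output = output.view(N, B, 2, H // 2)
    summary = torch.cat([output[-1, :, 0], output[0, :, 1]], dim=1)
    print(summary.shape)                   # torch.Size([1, 6])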
/maml/models/lstm_embedding_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class LSTMEmbeddingModel(torch.nn.Module):
4 | def __init__(self, input_size, output_size, embedding_dims,
5 | hidden_size=40, num_layers=2):
6 | super(LSTMEmbeddingModel, self).__init__()
7 | self._input_size = input_size
8 | self._output_size = output_size
9 | self._hidden_size = hidden_size
10 | self._num_layers = num_layers
11 | self._embedding_dims = embedding_dims
12 | self._bidirectional = True
13 | self._device = 'cpu'
14 |
15 | rnn_input_size = int(input_size + output_size)
16 | self.rnn = torch.nn.LSTM(rnn_input_size, hidden_size, num_layers, bidirectional=self._bidirectional)
17 |
18 | self._embeddings = torch.nn.ModuleList()
19 | for dim in embedding_dims:
20 | self._embeddings.append(torch.nn.Linear(hidden_size*(2 if self._bidirectional else 1), dim))
21 |
22 | def forward(self, task):
23 | batch_size = 1
24 | h0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1),
25 | batch_size, self._hidden_size, device=self._device)
26 | c0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1),
27 | batch_size, self._hidden_size, device=self._device)
28 |
29 | x = task.x.view(task.x.size(0), -1)
30 | y = task.y.view(task.y.size(0), -1)
31 |
32 | # LSTM input dimensions are seq_len, batch, input_size
33 | inputs = torch.cat((x, y), dim=1).view(x.size(0), 1, -1)
34 | output, (hn, cn) = self.rnn(inputs, (h0, c0))
35 | if self._bidirectional:
36 | N, B, H = output.shape
37 | output = output.view(N, B, 2, H // 2)
38 | embedding_input = torch.cat([output[-1, :, 0], output[0, :, 1]], dim=1)
39 |
40 | out_embeddings = []
41 | for embedding in self._embeddings:
42 | out_embeddings.append(embedding(embedding_input))
43 | return out_embeddings
44 |
45 | def to(self, device, **kwargs):
46 | self._device = device
47 |         return super(LSTMEmbeddingModel, self).to(device, **kwargs)
48 |
--------------------------------------------------------------------------------
/maml/models/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | class Model(torch.nn.Module):
7 | def __init__(self):
8 | super(Model, self).__init__()
9 |
10 | @property
11 | def param_dict(self):
12 | return OrderedDict(self.named_parameters())
--------------------------------------------------------------------------------
/maml/models/simple_embedding_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class SimpleEmbeddingModel(torch.nn.Module):
4 | def __init__(self, num_embeddings, embedding_dims):
5 | super(SimpleEmbeddingModel, self).__init__()
6 | self._embeddings = torch.nn.ModuleList()
7 | for dim in embedding_dims:
8 | self._embeddings.append(torch.nn.Embedding(num_embeddings, dim))
9 | self._device = 'cpu'
10 |
11 | def forward(self, task):
12 | task_id = torch.tensor(task.task_id, dtype=torch.long,
13 | device=self._device)
14 | out_embeddings = []
15 | for embedding in self._embeddings:
16 | out_embeddings.append(embedding(task_id))
17 | return out_embeddings
18 |
19 | def to(self, device, **kwargs):
20 | self._device = device
21 |         return super(SimpleEmbeddingModel, self).to(device, **kwargs)
--------------------------------------------------------------------------------
/maml/sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict, namedtuple
3 |
4 | from torch.utils.data.sampler import Sampler
5 |
6 |
7 | class ClassBalancedSampler(Sampler):
8 | """Generates indices for class balanced batch by sampling with replacement.
9 | """
10 | def __init__(self, dataset_labels, num_classes_per_batch,
11 | num_samples_per_class, num_total_batches, train):
12 | """
13 | Args:
14 | dataset_labels: list of dataset labels
15 | num_classes_per_batch: number of classes to sample for each batch
16 | num_samples_per_class: number of samples to sample for each class
17 | for each batch. For K shot learning this should be K + number
18 | of validation samples
19 | num_total_batches: total number of batches to generate
20 | """
21 | self._dataset_labels = dataset_labels
22 | self._classes = set(self._dataset_labels)
23 | self._class_to_samples = defaultdict(set)
24 | for i, c in enumerate(self._dataset_labels):
25 | self._class_to_samples[c].add(i)
26 |
27 | self._num_classes_per_batch = num_classes_per_batch
28 | self._num_samples_per_class = num_samples_per_class
29 | self._num_total_batches = num_total_batches
30 | self._train = train
31 |
32 | def __iter__(self):
33 | for i in range(self._num_total_batches):
34 | if len(self._class_to_samples.keys()) >= self._num_classes_per_batch:
35 | batch_classes = random.sample(
36 |                     list(self._class_to_samples.keys()), self._num_classes_per_batch)
37 | else:
38 | batch_classes = [random.choice(list(self._class_to_samples.keys()))
39 | for _ in range(self._num_classes_per_batch)]
40 | batch_samples = []
41 | for c in batch_classes:
42 | if len(self._class_to_samples[c]) >= self._num_samples_per_class:
43 | class_samples = random.sample(
44 |                         list(self._class_to_samples[c]), self._num_samples_per_class)
45 | else:
46 | class_samples = [random.choice(list(self._class_to_samples[c]))
47 | for _ in range(self._num_samples_per_class)]
48 |
49 | for sample in class_samples:
50 | batch_samples.append(sample)
51 | random.shuffle(batch_samples)
52 | for sample in batch_samples:
53 | yield sample
54 |
55 | def __len__(self):
56 | return self._num_total_batches
57 |
--------------------------------------------------------------------------------
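For reference, a hypothetical usage sketch of `ClassBalancedSampler` with a `DataLoader`; the toy tensor dataset, 20-class label list, and 5-way numbers below are all illustrative, not the repo's configuration. Each batch then holds 5 classes x 16 samples = 80 class-balanced indices:

    import torch
    from maml.sampler import ClassBalancedSampler

    labels = [i % 20 for i in range(2000)]          # toy 20-class label list
    dataset = torch.utils.data.TensorDataset(
        torch.randn(2000, 3, 84, 84), torch.tensor(labels))
    sampler = ClassBalancedSampler(labels, num_classes_per_batch=5,
                                   num_samples_per_class=16,  # e.g. K=1 shot + 15 query
                                   num_total_batches=100, train=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=5 * 16,
                                         sampler=sampler)
    batch_x, batch_y = next(iter(loader))           # one class-balanced batch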
/maml/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 |
8 | class Trainer(object):
9 | def __init__(self, meta_learner, meta_dataset, writer, log_interval,
10 | save_interval, model_type, save_folder, total_iter):
11 | self._meta_learner = meta_learner
12 | self._meta_dataset = meta_dataset
13 | self._writer = writer
14 | self._log_interval = log_interval
15 | self._save_interval = save_interval
16 | self._model_type = model_type
17 | self._save_folder = save_folder
18 | self._total_iter = total_iter
19 |
20 |
21 | def run(self, is_training):
22 | if not is_training:
23 | all_pre_val_measurements = defaultdict(list)
24 | all_pre_train_measurements = defaultdict(list)
25 | all_post_val_measurements = defaultdict(list)
26 | all_post_train_measurements = defaultdict(list)
27 |
28 | # compute running accuracies for all datasets
29 | if self._meta_dataset.name == 'MultimodalFewShot':
30 |             accuracies = [[] for _ in range(self._meta_dataset.num_dataset)]
31 |
32 | for i, (train_tasks, val_tasks) in enumerate(
33 | iter(self._meta_dataset), start=1):
34 |
35 | # Save model
36 | if (i % self._save_interval == 0 or i == 1) and is_training:
37 | save_name = 'maml_{0}_{1}.pt'.format(self._model_type, i)
38 | save_path = os.path.join(self._save_folder, save_name)
39 | with open(save_path, 'wb') as f:
40 | torch.save(self._meta_learner.state_dict(), f)
41 |
42 | (pre_train_measurements, adapted_params, embeddings
43 | ) = self._meta_learner.adapt(train_tasks)
44 | post_val_measurements = self._meta_learner.step(
45 | adapted_params, embeddings, val_tasks, is_training)
46 |
47 | # Tensorboard
48 | if (i % self._log_interval == 0 or i == 1):
49 | pre_val_measurements = self._meta_learner.measure(
50 | tasks=val_tasks, embeddings_list=embeddings)
51 | post_train_measurements = self._meta_learner.measure(
52 | tasks=train_tasks, adapted_params_list=adapted_params,
53 | embeddings_list=embeddings)
54 |
55 | _grads_mean = np.mean(self._meta_learner._grads_mean)
56 | self._meta_learner._grads_mean = []
57 |
58 | self.log_output(
59 | pre_val_measurements, pre_train_measurements,
60 | post_val_measurements, post_train_measurements,
61 | i, _grads_mean)
62 |
63 | if is_training:
64 | self.write_tensorboard(
65 | pre_val_measurements, pre_train_measurements,
66 | post_val_measurements, post_train_measurements,
67 | i, _grads_mean)
68 |
69 | if self._meta_dataset.name == 'MultimodalFewShot':
70 |
71 | post_val_accuracies = self._meta_learner.measure_each(
72 | tasks=val_tasks, adapted_params_list=adapted_params,
73 | embeddings_list=embeddings)
74 |
75 | if is_training:
76 |                         accuracies = [[] for _ in range(self._meta_dataset.num_dataset)]
77 |                     for j, accuracy in enumerate(post_val_accuracies):
78 |                         accuracies[self._meta_dataset.dataset_names.index(
79 |                             val_tasks[j].task_info)].append(accuracy)
80 |
81 | accuracy_str = []
82 |                     for j, accuracy in enumerate(accuracies):
83 |                         accuracy_str.append('{}: {}'.format(
84 |                             self._meta_dataset.dataset_names[j],
85 |                             'NaN' if len(accuracy) == 0
86 |                             else '{:.3f}%'.format(100*np.mean(accuracy))))
87 |
88 | print('Individual accuracies: {}'.format(' '.join(accuracy_str)))
89 | print('All accuracy: {:.3f}%'.format(100*np.mean(
90 | [item for accuracy in accuracies for item in accuracy])))
91 |
92 | # Collect evaluation statistics over full dataset
93 | if not is_training:
94 | for key, value in sorted(pre_val_measurements.items()):
95 | all_pre_val_measurements[key].append(value)
96 | for key, value in sorted(pre_train_measurements.items()):
97 | all_pre_train_measurements[key].append(value)
98 | for key, value in sorted(post_val_measurements.items()):
99 | all_post_val_measurements[key].append(value)
100 | for key, value in sorted(post_train_measurements.items()):
101 | all_post_train_measurements[key].append(value)
102 |
103 | # Compute evaluation statistics assuming all batches were the same size
104 | if not is_training:
105 | results = {'num_batches': i}
106 | for key, value in sorted(all_pre_val_measurements.items()):
107 | results['pre_val_' + key] = value
108 | for key, value in sorted(all_pre_train_measurements.items()):
109 | results['pre_train_' + key] = value
110 | for key, value in sorted(all_post_val_measurements.items()):
111 | results['post_val_' + key] = value
112 | for key, value in sorted(all_post_train_measurements.items()):
113 | results['post_train_' + key] = value
114 |
115 | print('Evaluation results:')
116 | for key, value in sorted(results.items()):
117 | if not isinstance(value, int):
118 | print('{}: {} +- {}'.format(
119 | key, np.mean(value), self.compute_confidence_interval(value)))
120 | else:
121 | print('{}: {}'.format(key, value))
122 |
123 | results_path = os.path.join(self._save_folder, 'results.json')
124 | with open(results_path, 'w') as f:
125 | json.dump(results, f)
126 |
127 | def compute_confidence_interval(self, value):
128 | """
129 |         Compute the half-width of the 95% confidence interval over tasks;
130 |         change 1.960 to 2.576 for a 99% confidence interval.
131 | """
132 | return np.std(value) * 1.960 / np.sqrt(len(value))
133 |
134 | def train(self):
135 | self.run(is_training=True)
136 |
137 | def eval(self):
138 | self.run(is_training=False)
139 |
140 | def write_tensorboard(self, pre_val_measurements, pre_train_measurements,
141 | post_val_measurements, post_train_measurements,
142 | iteration, embedding_grads_mean=None):
143 | for key, value in pre_val_measurements.items():
144 | self._writer.add_scalar(
145 | '{}/before_update/meta_val'.format(key), value, iteration)
146 | for key, value in pre_train_measurements.items():
147 | self._writer.add_scalar(
148 | '{}/before_update/meta_train'.format(key), value, iteration)
149 | for key, value in post_train_measurements.items():
150 | self._writer.add_scalar(
151 | '{}/after_update/meta_train'.format(key), value, iteration)
152 | for key, value in post_val_measurements.items():
153 | self._writer.add_scalar(
154 | '{}/after_update/meta_val'.format(key), value, iteration)
155 | if embedding_grads_mean is not None:
156 | self._writer.add_scalar(
157 | 'embedding_grads_mean', embedding_grads_mean, iteration)
158 |
159 | def log_output(self, pre_val_measurements, pre_train_measurements,
160 | post_val_measurements, post_train_measurements,
161 | iteration, embedding_grads_mean=None):
162 | log_str = 'Iteration: {}/{} '.format(iteration, self._total_iter)
163 | for key, value in sorted(pre_val_measurements.items()):
164 | log_str = (log_str + '{} meta_val before: {:.3f} '
165 | ''.format(key, value))
166 | for key, value in sorted(pre_train_measurements.items()):
167 | log_str = (log_str + '{} meta_train before: {:.3f} '
168 | ''.format(key, value))
169 | for key, value in sorted(post_train_measurements.items()):
170 | log_str = (log_str + '{} meta_train after: {:.3f} '
171 | ''.format(key, value))
172 | for key, value in sorted(post_val_measurements.items()):
173 | log_str = (log_str + '{} meta_val after: {:.3f} '
174 | ''.format(key, value))
175 | if embedding_grads_mean is not None:
176 | log_str = (log_str + 'embedding_grad_norm after: {:.3f} '
177 | ''.format(embedding_grads_mean))
178 | print(log_str)
179 |
--------------------------------------------------------------------------------
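A quick numeric sketch of `compute_confidence_interval` above, on made-up per-task accuracies (the half-width is 1.960 * std / sqrt(n)):

    import numpy as np

    accuracies = np.array([0.52, 0.48, 0.55, 0.50, 0.47])  # made-up values
    half_width = np.std(accuracies) * 1.960 / np.sqrt(len(accuracies))
    print('{:.3f} +- {:.3f}'.format(np.mean(accuracies), half_width))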
/maml/utils.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import torch
4 |
5 |
6 | def accuracy(preds, y):
7 | _, preds = torch.max(preds.data, 1)
8 | total = y.size(0)
9 | correct = (preds == y).sum().float()
10 | return correct / total
11 |
12 |
13 | def optimizer_to_device(optimizer, device):
14 | for state in optimizer.state.values():
15 | for k, v in state.items():
16 | if isinstance(v, torch.Tensor):
17 | state[k] = v.to(device)
18 |
19 |
20 | def get_git_revision_hash():
21 |     return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
--------------------------------------------------------------------------------
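A tiny sketch of `accuracy` above on hand-made logits (two samples, both classified correctly):

    import torch
    from maml.utils import accuracy

    preds = torch.tensor([[2.0, 0.1], [0.3, 0.9]])  # logits for 2 samples
    y = torch.tensor([0, 1])
    print(accuracy(preds, y))                        # tensor(1.)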
/miniimagenet-data/download_mini_imagenet.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | def download_file_from_google_drive(file_id, destination):
4 | URL = "https://drive.google.com/uc?export=download"
5 |
6 | session = requests.Session()
7 |
8 |     response = session.get(URL, params={'id': file_id}, stream=True)
9 | token = get_confirm_token(response)
10 |
11 | if token:
12 |         params = {'id': file_id, 'confirm': token}
13 |         response = session.get(URL, params=params, stream=True)
14 |
15 | save_response_content(response, destination)
16 |
17 | def get_confirm_token(response):
18 | for key, value in response.cookies.items():
19 | if key.startswith('download_warning'):
20 | return value
21 |
22 | return None
23 |
24 | def save_response_content(response, destination):
25 | CHUNK_SIZE = 32768
26 |
27 | with open(destination, "wb") as f:
28 | for chunk in response.iter_content(CHUNK_SIZE):
29 | if chunk: # filter out keep-alive new chunks
30 | f.write(chunk)
31 |
32 | if __name__ == "__main__":
33 | file_id = "0B3Irx3uQNoBMQ1FlNXJsZUdYWEE"
34 | destination = './images.zip'
35 | download_file_from_google_drive(file_id, destination)
36 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipython
2 | torch==1.2.0
3 | torchvision==0.4.0
4 | imageio
5 | scikit-image
6 | tqdm
7 | Pillow
8 | numpy
9 | scipy
10 | requests
11 | h5py
12 | tensorboardX
13 |
--------------------------------------------------------------------------------