├── .gitignore
├── LICENSE
├── README.md
├── asset
│   ├── classification_summary.png
│   ├── model.png
│   └── tb.png
├── download.py
├── get_dataset_script
│   ├── download_miniimagenet.py
│   ├── get_aircraft.py
│   ├── get_bird.py
│   ├── get_cifar.py
│   ├── get_miniimagenet.py
│   ├── proc_aircraft.py
│   ├── proc_bird.py
│   ├── proc_cifar.py
│   ├── proc_miniimagenet.py
│   └── resize_dataset.py
├── main.py
├── maml
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── aircraft.py
│   │   ├── bird.py
│   │   ├── cifar100.py
│   │   ├── metadataset.py
│   │   ├── miniimagenet.py
│   │   ├── multimodal_few_shot.py
│   │   └── omniglot.py
│   ├── metalearner.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── conv_embedding_model.py
│   │   ├── conv_net.py
│   │   ├── fully_connected.py
│   │   ├── gated_conv_net.py
│   │   ├── gated_net.py
│   │   ├── gru_embedding_model.py
│   │   ├── lstm_embedding_model.py
│   │   ├── model.py
│   │   └── simple_embedding_model.py
│   ├── sampler.py
│   ├── trainer.py
│   └── utils.py
├── miniimagenet-data
│   ├── download_mini_imagenet.py
│   ├── test.csv
│   ├── train.csv
│   └── val.csv
├── requirements.txt
└── visualization.ipynb
/.gitignore: --------------------------------------------------------------------------------
1 | # project specific
2 | data/
3 | saves/
4 | logs/
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 |
111 | *.txt
112 | *.sw[opn]
113 | *.pt
114 | *.hdf5
115 | train_dir/
116 | *tgz
117 | *tar.gz
118 |
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Tristan Deleu
4 | Copyright (c) 2018 Risto Vuorio
5 | Copyright (c) 2019 Hexiang Hu
6 | Copyright (c) 2019 Shao-Hua Sun
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the
rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Multimodal Model-Agnostic Meta-Learning for Few-shot Classification
2 |
3 | This project is an implementation of [**Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation**](https://arxiv.org/abs/1910.13616), which was published at [**NeurIPS 2019**](https://neurips.cc/Conferences/2019/). Please visit our [project page](https://vuoristo.github.io/MMAML/) for more information and contact [Shao-Hua Sun](http://shaohua0116.github.io/) with any questions.
4 |
5 | Model-agnostic meta-learners aim to acquire meta-prior parameters from a distribution of tasks and adapt to novel tasks with few gradient updates. Yet, seeking a common initialization shared across the entire task distribution substantially limits the diversity of the tasks that such models can learn from. We propose a multimodal MAML (MMAML) framework, which modulates its meta-learned prior according to the identified task mode, allowing more efficient fast adaptation. An illustration of the proposed framework is as follows.
6 |
7 |
![An illustration of the MMAML framework](asset/model.png)
8 |
9 |
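At a high level, a task encoder consumes the support set, produces an embedding that identifies the task mode, and that embedding generates affine (FiLM-style) parameters that modulate the meta-learned network before the usual MAML gradient adaptation. The following is a minimal, self-contained PyTorch sketch of this idea; `TaskEncoder`, `ModulatedNet`, and all dimensions are hypothetical stand-ins rather than the repo's actual classes (see `maml/models/conv_embedding_model.py` and `maml/models/gated_conv_net.py` for the real implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TaskEncoder(nn.Module):
    """Hypothetical stand-in for the ConvGRU task encoder: support set -> task embedding."""
    def __init__(self, in_dim, emb_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, emb_dim))

    def forward(self, support_x):
        # Aggregate per-example features into a single embedding that identifies the task mode.
        return self.net(support_x).mean(dim=0)

class ModulatedNet(nn.Module):
    """A tiny network whose hidden layer is modulated by task-specific affine parameters."""
    def __init__(self, in_dim, hidden_dim, n_way, emb_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_way)
        # The task embedding predicts one (gamma, beta) pair per hidden unit.
        self.film = nn.Linear(emb_dim, 2 * hidden_dim)

    def forward(self, x, task_emb):
        gamma, beta = self.film(task_emb).chunk(2, dim=-1)
        h = F.relu(gamma * self.fc1(x) + beta)  # task-aware modulation of the meta-learned prior
        return self.fc2(h)

# One (first-order) inner-loop step on a toy 5-way 1-shot task.
encoder, net = TaskEncoder(32, 8), ModulatedNet(32, 16, 5, 8)
support_x, support_y = torch.randn(5, 32), torch.arange(5)
task_emb = encoder(support_x)                       # identify the mode
loss = F.cross_entropy(net(support_x, task_emb), support_y)
grads = torch.autograd.grad(loss, list(net.parameters()))
fast_weights = [p - 0.05 * g for p, g in zip(net.parameters(), grads)]  # adapted parameters
```

Roughly speaking, in the actual model the modulated network is a gated convolutional classifier that is adapted for `--num-updates` inner-loop steps, while the task encoder is trained only through the outer-loop meta-update.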
10 |
11 | We evaluate our model and baselines (MAML and Multi-MAML) on multiple multimodal settings based on the following five datasets: (a) [Omniglot](https://www.omniglot.com/), (b) [Mini-ImageNet](https://openreview.net/forum?id=rJY0-Kcll), (c) [FC100](https://arxiv.org/abs/1805.10123) (a few-shot dataset derived from CIFAR-100), (d) [CUB-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), and (e) [FGVC-Aircraft](https://arxiv.org/abs/1306.5151).
12 |
13 |
![Examples from the five few-shot classification datasets](asset/classification_summary.png)
14 |
15 |
16 |
17 | ## Datasets
18 |
19 | Run the following command to download and preprocess the datasets:
20 |
21 | ```bash
22 | python download.py --dataset aircraft bird cifar miniimagenet
23 | ```
24 |
25 | ## Getting started
26 |
27 | Please first install the following prerequisites: `wget` and `unzip`.
28 |
29 | To avoid conflicts with your existing Python setup and to keep this project self-contained, we suggest working in a virtual environment with [`virtualenv`](http://docs.python-guide.org/en/latest/dev/virtualenvs/). To install `virtualenv`:
30 | ```
31 | pip install --upgrade virtualenv
32 | ```
33 | Create a virtual environment, activate it, and install the requirements in [`requirements.txt`](requirements.txt).
34 | ```
35 | virtualenv mmaml_venv
36 | source mmaml_venv/bin/activate
37 | pip install -r requirements.txt
38 | ```
39 |
40 | ## Usage
41 |
42 | After downloading the datasets, you can train models with the following commands.
43 |
44 | ### Training command
45 |
46 | ```bash
47 | $ python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s
48 | ```
49 | - Selected arguments (see `main.py` for the full list)
50 |   - --output-folder: a name for the training run; logs and checkpoints are written to `train_dir/<output-folder>`
51 |   - --dataset: choose among `omniglot`, `miniimagenet`, `cifar`, `bird` (CUB), and `aircraft`. You can also add your own datasets.
52 |   - Checkpoints: specify the path to a pre-trained checkpoint
53 |     - --checkpoint: load all the parameters (e.g. `train_dir/mmaml_5mode_5w1s/maml_gatedconv_60000.pt`).
54 |   - Hyperparameters
55 |     - --num-batches: total number of task batches (the number of training iterations is `num-batches / meta-batch-size`)
56 |     - --meta-batch-size: number of tasks per batch
57 |     - --slow-lr: learning rate for the global (outer-loop) update of MAML
58 |     - --fast-lr: learning rate for the adapted models
59 |     - --num-updates: how many update steps in the inner loop
60 |     - --num-classes-per-batch: how many classes per task (`N`-way)
61 |     - --num-samples-per-class: how many samples per class for training (`K`-shot)
62 |     - --num-val-samples: how many samples per class for validation
63 |     - --inner-loop-grad-clip: gradient clipping threshold for the inner-loop updates
64 |   - Logging
65 |     - --log-interval: number of batches between TensorBoard writes
66 |     - --save-interval: number of batches between model saves
67 |   - Model
68 |     - --maml-model: set to `True` to train a MAML model
69 |     - --mmaml-model: set to `True` to train an MMAML (our) model
70 |
71 | ### Interpreting TensorBoard
72 |
73 | Launch TensorBoard and go to the specified port; the accuracies and losses for the different datasets are shown in the **Scalars** tab.
74 |
75 |
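For example, with the default log location used by this repo (`./train_dir/<output-folder>`), TensorBoard can be launched as follows (the port is arbitrary):

```bash
tensorboard --logdir ./train_dir --port 6006
# then open http://localhost:6006 in a browser
```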
![TensorBoard Scalars tab showing per-dataset accuracies and losses](asset/tb.png)
76 |
77 |
78 |
79 | You can reproduce our results with the following training commands.
80 |
81 | ### 2 Modes (Omniglot and Mini-ImageNet)
82 |
83 | | Setup | Method | Command |
84 | | :---: | :----: | ---------------------------------------- |
85 | | 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --output-folder maml_2mode_5w1s``` |
86 | | 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --output-folder mmaml_2mode_5w1s``` |
87 | | 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_2mode_5w5s``` |
88 | | 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_2mode_5w5s``` |
89 | | 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_2mode_20w1s``` |
90 | | 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_2mode_20w1s``` |
91 |
92 | ### 3 Modes (Omniglot, Mini-ImageNet, and FC100)
93 |
94 | | Setup | Method | Command |
95 | | :---: | :----: | ---------------------------------------- |
96 | | 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --output-folder maml_3mode_5w1s``` |
97 | | 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --output-folder mmaml_3mode_5w1s``` |
98 | | 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_3mode_5w5s``` |
99 | | 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_3mode_5w5s``` |
100 | | 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_3mode_20w1s``` |
101 | | 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_3mode_20w1s``` |
102 |
103 | ### 5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB)
104 |
105 | | Setup | Method | Command |
106 | | :---: | :----: | ---------------------------------------- |
108 | | 5w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --output-folder maml_5mode_5w1s``` |
110 | | 5w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s``` |
112 | | 5w5s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_5mode_5w5s``` |
114 | | 5w5s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_5mode_5w5s``` |
116 | | 20w1s | MAML | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_5mode_20w1s``` |
118 | | 20w1s | Ours | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_5mode_20w1s``` |
119 |
120 | ### Multi-MAML
121 |
122 | | Setup | Dataset | Command |
123 | | :---: | :-----------: | ---------------------------------------- |
124 | | 5w1s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-updates 1 --num-batches 600000 --output-folder multi_omniglot_5w1s``` |
125 | | 5w1s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_miniimagenet_5w1s``` |
126 | | 5w1s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_cifar_5w1s``` |
127 | | 5w1s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_bird_5w1s``` |
128 | | 5w1s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_aircraft_5w1s``` |
129 | | 5w5s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-updates 1 --num-batches 600000 --num-samples-per-class 5 --output-folder multi_omniglot_5w5s``` |
130 | | 5w5s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_miniimagenet_5w5s``` |
131 | | 5w5s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_cifar_5w5s``` |
132 | | 5w5s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_bird_5w5s``` |
133 | | 5w5s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_aircraft_5w5s``` |
134 | | 20w1s | Omniglot | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.1 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_omniglot_20w1s``` |
135 | | 20w1s | Mini-ImageNet | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_miniimagenet_20w1s``` |
136 | | 20w1s | FC100 | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_cifar_20w1s``` |
137 | | 20w1s | Bird | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_bird_20w1s``` |
138 | | 20w1s | Aircraft | ```python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_aircraft_20w1s``` |
139 |
140 |
141 | ## Results
142 |
143 | ### 2 Modes (Omniglot and Mini-ImageNet)
144 |
145 | | Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot |
146 | | :----------: | :----------: | :----------: | :-----------: |
147 | | MAML | 66.80% | 77.79% | 44.69% |
148 | | Multi-MAML | 66.85% | 73.07% | 53.15% |
149 | | MMAML (Ours) | 69.93% | 78.73% | 47.80% |
150 |
151 | ### 3 Modes (Omniglot, Mini-ImageNet, and FC100)
152 |
153 | | Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot |
154 | | :----------: | :----------: | :----------: | :-----------: |
155 | | MAML | 54.55% | 67.97% | 28.22% |
156 | | Multi-MAML | 55.90% | 62.20% | 39.77% |
157 | | MMAML (Ours) | 57.47% | 70.15% | 36.27% |
158 |
159 | ### 5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB)
160 |
161 | | Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot |
162 | | :----------: | :----------: | :----------: | :-----------: |
163 | | MAML | 44.09% | 54.41% | 28.85% |
164 | | Multi-MAML | 45.46% | 55.92% | 33.78% |
165 | | MMAML (Ours) | 49.06% | 60.83% | 33.97% |
166 |
167 | Please check out [our paper](https://arxiv.org/abs/1910.13616) for more comprehensive results.
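To evaluate a trained model, you can combine the `--checkpoint` and `--eval` flags described above. For example (the checkpoint filename below is hypothetical; the actual name depends on your run and `--save-interval`):

```bash
python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --checkpoint train_dir/mmaml_5mode_5w1s/maml_gatedconv_60000.pt --eval
```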
168 |
169 | ## Related work
170 | - \[MAML\] [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400) in ICML 2017
171 | - [Probabilistic Model-Agnostic Meta-Learning](https://arxiv.org/abs/1806.02817) in NeurIPS 2018
172 | - [Bayesian Model-Agnostic Meta-Learning](https://arxiv.org/abs/1806.03836) in NeurIPS 2018
173 | - [Gradient-Based Meta-Learning with Learned Layerwise Metric and Subspace](https://arxiv.org/abs/1801.05558) in ICML 2018
174 | - [Reptile: A Scalable Meta-Learning Algorithm](https://openai.com/blog/reptile/)
175 | - [Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples](https://arxiv.org/abs/1903.03096) in the Meta-Learning Workshop at NeurIPS 2018
176 | - [TADAM: Task dependent adaptive metric for improved few-shot learning](https://arxiv.org/abs/1805.10123) in NeurIPS 2018
177 | - [ProMP: Proximal Meta-Policy Search](https://arxiv.org/abs/1810.06784) in ICLR 2019
178 |
179 | ## Cite the paper
180 | If you find this work useful, please cite:
181 | ```
182 | @inproceedings{vuorio2019multimodal,
183 |   title={Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation},
184 |   author={Vuorio, Risto and Sun, Shao-Hua and Hu, Hexiang and Lim, Joseph J.},
185 |   booktitle={Neural Information Processing Systems},
186 |   year={2019},
187 | }
188 | ```
189 |
190 | ## Authors
191 | [Shao-Hua Sun](http://shaohua0116.github.io/), [Risto Vuorio](https://vuoristo.github.io/), [Hexiang Hu](http://hexianghu.com/)
192 |
-------------------------------------------------------------------------------- /asset/classification_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/asset/classification_summary.png
-------------------------------------------------------------------------------- /asset/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/asset/model.png
-------------------------------------------------------------------------------- /asset/tb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/asset/tb.png
-------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | 4 | 5 | parser = argparse.ArgumentParser(description='Download datasets for MMAML.') 6 | parser.add_argument('--dataset', metavar='N', type=str, nargs='+', 7 | choices=['aircraft', 'bird', 'cifar', 'miniimagenet']) 8 | 9 | 10 | def download(dataset): 11 | cmd = ['python', 'get_dataset_script/get_{}.py'.format(dataset)] 12 | print(' '.join(cmd)) 13 | subprocess.call(cmd) 14 | return 15 | 16 | 17 | if __name__ == '__main__': 18 | args = parser.parse_args() 19 | if args.dataset: # args.dataset is None when --dataset is omitted 20 | for dataset in args.dataset: 21 | download(dataset) 22 | -------------------------------------------------------------------------------- /get_dataset_script/download_miniimagenet.py: -------------------------------------------------------------------------------- 1 | # https://drive.google.com/drive/folders/0B-r7apOz1BHAWXYwT1lGb3J1Yjg 2 | # Download files on Google Drive 3 | import
requests 4 | 5 | def download_file_from_google_drive(id, destination): 6 | URL = "https://drive.google.com/uc?export=download" 7 | 8 | session = requests.Session() 9 | 10 | response = session.get(URL, params = { 'id' : id }, stream = True) 11 | token = get_confirm_token(response) 12 | 13 | if token: 14 | params = { 'id' : id, 'confirm' : token } 15 | response = session.get(URL, params = params, stream = True) 16 | 17 | save_response_content(response, destination) 18 | 19 | def get_confirm_token(response): 20 | for key, value in response.cookies.items(): 21 | if key.startswith('download_warning'): 22 | return value 23 | 24 | return None 25 | 26 | def save_response_content(response, destination): 27 | CHUNK_SIZE = 32768 28 | 29 | with open(destination, "wb") as f: 30 | for chunk in response.iter_content(CHUNK_SIZE): 31 | if chunk: # filter out keep-alive new chunks 32 | f.write(chunk) 33 | 34 | if __name__ == "__main__": 35 | file_id = "1HkgrkAwukzEZA0TpO7010PkAOREb2Nuk" 36 | destination = './mini-imagenet.zip' 37 | download_file_from_google_drive(file_id, destination) 38 | -------------------------------------------------------------------------------- /get_dataset_script/get_aircraft.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | cmds = [] 5 | cmds.append(['wget', 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz']) 6 | cmds.append(['tar', 'xvzf', 'fgvc-aircraft-2013b.tar.gz']) 7 | cmds.append(['python', 'get_dataset_script/proc_aircraft.py']) 8 | cmds.append(['rm', '-rf', 'fgvc-aircraft-2013b.tar.gz', 'fgvc-aircraft-2013b']) 9 | 10 | for cmd in cmds: 11 | print(' '.join(cmd)) 12 | subprocess.call(cmd) 13 | -------------------------------------------------------------------------------- /get_dataset_script/get_bird.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | 4 | 5 | cmds = [] 6 | cmds.append(['wget', 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz']) 7 | cmds.append(['tar', 'xvzf', 'CUB_200_2011.tgz']) 8 | cmds.append(['python', 'get_dataset_script/proc_bird.py']) 9 | cmds.append(['rm', '-rf', 'CUB_200_2011', 'CUB_200_2011.tgz']) 10 | 11 | for cmd in cmds: 12 | print(' '.join(cmd)) 13 | subprocess.call(cmd) 14 | -------------------------------------------------------------------------------- /get_dataset_script/get_cifar.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | cmds = [] 5 | cmds.append(['wget', 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz']) 6 | cmds.append(['tar', '-xzvf', 'cifar-100-python.tar.gz']) 7 | cmds.append(['mv', 'cifar-100-python', './data']) 8 | cmds.append(['python3', 'get_dataset_script/proc_cifar.py']) 9 | cmds.append(['rm', '-rf', 'cifar-100-python.tar.gz']) 10 | 11 | for cmd in cmds: 12 | print(' '.join(cmd)) 13 | subprocess.call(cmd) 14 | -------------------------------------------------------------------------------- /get_dataset_script/get_miniimagenet.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | cmds = [] 5 | cmds.append(['python', 'get_dataset_script/download_miniimagenet.py']) 6 | cmds.append(['unzip', 'mini-imagenet.zip']) 7 | cmds.append(['rm', '-rf', 'mini-imagenet.zip']) 8 | cmds.append(['mkdir', 'miniimagenet']) 9 | cmds.append(['mv', 'images', 'miniimagenet']) 10 | 
cmds.append(['python', 'get_dataset_script/proc_miniimagenet.py']) 11 | cmds.append(['mv', 'miniimagenet', './data']) 12 | 13 | for cmd in cmds: 14 | print(' '.join(cmd)) 15 | subprocess.call(cmd) 16 | -------------------------------------------------------------------------------- /get_dataset_script/proc_aircraft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import subprocess 4 | 5 | 6 | source_dir = './fgvc-aircraft-2013b/data' 7 | target_dir = './data/aircraft' 8 | 9 | percentage_train_class = 70 10 | percentage_val_class = 15 11 | percentage_test_class = 15 12 | train_val_test_ratio = [ 13 | percentage_train_class, percentage_val_class, percentage_test_class] 14 | 15 | with open(os.path.join(source_dir, 'variants.txt')) as f: 16 | lines = f.readlines() 17 | classes = [line.strip() for line in lines] 18 | 19 | rs = np.random.RandomState(123) 20 | rs.shuffle(classes) 21 | num_train, num_val, num_test = [ 22 | int(float(ratio)/np.sum(train_val_test_ratio)*len(classes)) 23 | for ratio in train_val_test_ratio] 24 | 25 | classes = { 26 | 'train': classes[:num_train], 27 | 'val': classes[num_train:num_train+num_val], 28 | 'test': classes[num_train+num_val:] 29 | } 30 | 31 | if not os.path.exists(target_dir): 32 | os.makedirs(target_dir) 33 | for k in classes.keys(): 34 | target_dir_k = os.path.join(target_dir, k) 35 | if not os.path.exists(target_dir_k): 36 | os.makedirs(target_dir_k) 37 | for c in classes[k]: 38 | c = c.replace('/', '-') 39 | target_dir_k_c = os.path.join(target_dir_k, c) 40 | if not os.path.exists(target_dir_k_c): 41 | os.makedirs(target_dir_k_c) 42 | 43 | lines = [] 44 | with open(os.path.join(source_dir, 'images_variant_trainval.txt')) as f: 45 | lines += f.readlines() 46 | with open(os.path.join(source_dir, 'images_variant_test.txt')) as f: 47 | lines += f.readlines() 48 | lines = [line.strip() for line in lines] 49 | 50 | for i, line in enumerate(lines): 51 | image_num, image_class = line.split(' ', 1) 52 | image_class = image_class.replace('/', '-') 53 | image_k = list(classes.keys())[np.argmax([image_class in classes[k] for k in list(classes.keys())])] 54 | image_source_path = os.path.join(source_dir, 'images', '{}.jpg'.format(image_num)) 55 | image_target_path = os.path.join(target_dir, image_k, image_class) 56 | cmd = ['mv', image_source_path, image_target_path] 57 | subprocess.call(cmd) 58 | print('{}/{} {}'.format(i, len(lines), ' '.join(cmd))) 59 | 60 | # resize images 61 | cmd = ['python', 'get_dataset_script/resize_dataset.py', './data/aircraft'] 62 | subprocess.call(cmd) 63 | -------------------------------------------------------------------------------- /get_dataset_script/proc_bird.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import subprocess 4 | from imageio import imread, imwrite 5 | 6 | 7 | source_dir = './CUB_200_2011/images' 8 | target_dir = './data/bird' 9 | 10 | percentage_train_class = 70 11 | percentage_val_class = 15 12 | percentage_test_class = 15 13 | train_val_test_ratio = [ 14 | percentage_train_class, percentage_val_class, percentage_test_class] 15 | 16 | classes = os.listdir(source_dir) 17 | 18 | rs = np.random.RandomState(123) 19 | rs.shuffle(classes) 20 | num_train, num_val, num_test = [ 21 | int(float(ratio)/np.sum(train_val_test_ratio)*len(classes)) 22 | for ratio in train_val_test_ratio] 23 | 24 | classes = { 25 | 'train': classes[:num_train], 26 | 'val': 
classes[num_train:num_train+num_val], 27 | 'test': classes[num_train+num_val:] 28 | } 29 | 30 | if not os.path.exists(target_dir): 31 | os.makedirs(target_dir) 32 | for k in classes.keys(): 33 | target_dir_k = os.path.join(target_dir, k) 34 | if not os.path.exists(target_dir_k): 35 | os.makedirs(target_dir_k) 36 | cmd = ['mv'] + [os.path.join(source_dir, c) for c in classes[k]] + [target_dir_k] 37 | subprocess.call(cmd) 38 | 39 | _ids = [] 40 | 41 | for root, dirnames, filenames in os.walk(target_dir): 42 | for filename in filenames: 43 | if filename.endswith(('.jpg', '.webp', '.JPEG', '.png', 'jpeg')): 44 | _ids.append(os.path.join(root, filename)) 45 | 46 | for path in _ids: 47 | try: 48 | img = imread(path) 49 | except Exception: 50 | print(path); continue # report unreadable images and skip them (originally printed `img`, which is unbound here) 51 | if len(img.shape) < 3: 52 | print(path) 53 | img = np.tile(np.expand_dims(img, axis=-1), [1, 1, 3]) 54 | imwrite(path, img) 55 | else: 56 | if img.shape[-1] == 1: 57 | print(path) 58 | img = np.tile(img, [1, 1, 3]) 59 | imwrite(path, img) 60 | 61 | # resize images 62 | cmd = ['python', 'get_dataset_script/resize_dataset.py', './data/bird'] 63 | subprocess.call(cmd) 64 | -------------------------------------------------------------------------------- /get_dataset_script/proc_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | import os 5 | 6 | img_size=32 7 | classes={ 8 | 'train': [1, 2, 3, 4, 5, 6, 9, 10, 15, 17, 18, 19], 9 | 'val': [8, 11, 13, 16], 10 | 'test': [0, 7, 12, 14] 11 | } 12 | 13 | def _get_file_path(filename=""): 14 | return os.path.join('./data', "cifar-100-python/", filename) 15 | 16 | def _unpickle(filename): 17 | """ 18 | Unpickle the given file and return the data. 19 | Note that the appropriate dir-name is prepended to the filename. 20 | """ 21 | 22 | # Create full path for the file. 23 | file_path = _get_file_path(filename) 24 | 25 | print("Loading data: " + file_path) 26 | 27 | with open(file_path, mode='rb') as file: 28 | # In Python 3.X it is important to set the encoding, 29 | # otherwise an exception is raised here. 30 | data = pickle.load(file, encoding='latin1') 31 | 32 | return data 33 | 34 | 35 | 36 | meta = _unpickle('meta') 37 | train = _unpickle('train') 38 | test = _unpickle('test') 39 | 40 | data = np.concatenate([train['data'], test['data']]) 41 | labels = np.array(train['fine_labels'] + test['fine_labels']) 42 | filts = np.array(train['coarse_labels'] + test['coarse_labels']) 43 | 44 | cifar_data = {} 45 | cifar_label = {} 46 | for k, v in classes.items(): 47 | data_filter = np.zeros_like(filts) 48 | for i in v: data_filter += ( filts == i ) 49 | assert data_filter.max() == 1 50 | 51 | cifar_data[k] = data[data_filter == 1] 52 | cifar_label[k] = labels[data_filter == 1] 53 | 54 | torch.save({'data': cifar_data, 'label': cifar_label}, './data/cifar100.pth') 55 | -------------------------------------------------------------------------------- /get_dataset_script/proc_miniimagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for converting the csv datafiles into one directory of images per class (which is how the MAML code loads the data) 3 | 4 | Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the 5 | csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'.
6 | This script is normally invoked from the repository root by get_dataset_script/get_miniimagenet.py. 7 | To run it manually, execute it from the repository root: 8 | python get_dataset_script/proc_miniimagenet.py 9 | """ 10 | 11 | from __future__ import print_function 12 | import csv 13 | import glob 14 | import os 15 | from tqdm import tqdm 16 | 17 | from PIL import Image 18 | 19 | path_to_images = 'miniimagenet/images/' 20 | 21 | all_images = glob.glob(path_to_images + '*') 22 | 23 | # Resize images 24 | for i, image_file in enumerate(tqdm(all_images)): 25 | im = Image.open(image_file) 26 | if not im.size == (84, 84): 27 | im = im.resize((84, 84), resample=Image.LANCZOS) 28 | im.save(image_file) 29 | 30 | # Put in correct directory 31 | for datatype in ['train', 'val', 'test']: 32 | print('Processing {} data'.format(datatype)) 33 | os.system('mkdir miniimagenet/' + datatype) 34 | 35 | with open(os.path.join('miniimagenet-data', datatype + '.csv'), 'r') as f: 36 | reader = csv.reader(f, delimiter=',') 37 | last_label = '' 38 | for i, row in enumerate(reader): 39 | if i == 0: # skip the headers 40 | continue 41 | label = row[1] 42 | image_name = row[0] 43 | if label != last_label: 44 | cur_dir = 'miniimagenet/' + datatype + '/' + label + '/' 45 | os.system('mkdir ' + cur_dir) 46 | last_label = label 47 | os.system('mv miniimagenet/images/' + image_name + ' ' + cur_dir) 48 | -------------------------------------------------------------------------------- /get_dataset_script/resize_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from imageio import imread, imwrite 4 | from skimage.transform import resize 5 | 6 | target_dir = './data/aircraft' if len(sys.argv) < 2 else sys.argv[1] 7 | img_size = [84, 84] 8 | 9 | _ids = [] 10 | 11 | for root, dirnames, filenames in os.walk(target_dir): 12 | for filename in filenames: 13 | if filename.endswith(('.jpg', '.webp', '.JPEG', '.png', 'jpeg')): 14 | _ids.append(os.path.join(root, filename)) 15 | 16 | for i, path in enumerate(_ids): 17 | img = imread(path) 18 | print('{}/{} size: {}'.format(i, len(_ids), img.shape)) 19 | imwrite(path, (resize(img, img_size) * 255).astype('uint8')) # resize() returns floats in [0, 1]; convert back to uint8 before writing 20 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | import torch 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | 9 | from maml.datasets.omniglot import OmniglotMetaDataset 10 | from maml.datasets.miniimagenet import MiniimagenetMetaDataset 11 | from maml.datasets.cifar100 import Cifar100MetaDataset 12 | from maml.datasets.bird import BirdMetaDataset 13 | from maml.datasets.aircraft import AircraftMetaDataset 14 | from maml.datasets.multimodal_few_shot import MultimodalFewShotDataset 15 | from maml.models.fully_connected import FullyConnectedModel, MultiFullyConnectedModel 16 | from maml.models.conv_net import ConvModel 17 | from maml.models.gated_conv_net import GatedConvModel 18 | from maml.models.gated_net import GatedNet 19 | from maml.models.simple_embedding_model import SimpleEmbeddingModel 20 | from maml.models.lstm_embedding_model import LSTMEmbeddingModel 21 | from maml.models.gru_embedding_model import GRUEmbeddingModel 22 | from maml.models.conv_embedding_model import ConvEmbeddingModel 23 | from maml.metalearner import MetaLearner 24 | from maml.trainer import Trainer 25 | from maml.utils import optimizer_to_device, get_git_revision_hash 26 | 27 | 28 | def main(args): 29 | is_training = not args.eval 30 |
run_name = 'train' if is_training else 'eval' 31 | 32 | if is_training: 33 | writer = SummaryWriter('./train_dir/{0}/{1}'.format( 34 | args.output_folder, run_name)) 35 | with open('./train_dir/{}/config.txt'.format( 36 | args.output_folder), 'w') as config_txt: 37 | for k, v in sorted(vars(args).items()): 38 | config_txt.write('{}: {}\n'.format(k, v)) 39 | else: 40 | writer = None 41 | 42 | save_folder = './train_dir/{0}'.format(args.output_folder) 43 | if not os.path.exists(save_folder): 44 | os.makedirs(save_folder) 45 | 46 | config_name = '{0}_config.json'.format(run_name) 47 | with open(os.path.join(save_folder, config_name), 'w') as f: 48 | config = {k: v for (k, v) in vars(args).items() if k != 'device'} 49 | config.update(device=args.device.type) 50 | try: 51 | config.update({'git_hash': get_git_revision_hash()}) 52 | except: 53 | pass 54 | json.dump(config, f, indent=2) 55 | 56 | _num_tasks = 1 57 | if args.dataset == 'omniglot': 58 | dataset = OmniglotMetaDataset( 59 | root='data', 60 | img_side_len=28, # args.img_side_len, 61 | num_classes_per_batch=args.num_classes_per_batch, 62 | num_samples_per_class=args.num_samples_per_class, 63 | num_total_batches=args.num_batches, 64 | num_val_samples=args.num_val_samples, 65 | meta_batch_size=args.meta_batch_size, 66 | train=is_training, 67 | num_train_classes=args.num_train_classes, 68 | num_workers=args.num_workers, 69 | device=args.device) 70 | loss_func = torch.nn.CrossEntropyLoss() 71 | collect_accuracies = True 72 | elif args.dataset == 'cifar': 73 | dataset = Cifar100MetaDataset( 74 | root='data', 75 | img_side_len=32, 76 | num_classes_per_batch=args.num_classes_per_batch, 77 | num_samples_per_class=args.num_samples_per_class, 78 | num_total_batches=args.num_batches, 79 | num_val_samples=args.num_val_samples, 80 | meta_batch_size=args.meta_batch_size, 81 | train=is_training, 82 | num_train_classes=args.num_train_classes, 83 | num_workers=args.num_workers, 84 | device=args.device) 85 | loss_func = torch.nn.CrossEntropyLoss() 86 | collect_accuracies = True 87 | elif args.dataset == 'multimodal_few_shot': 88 | dataset_list = [] 89 | if 'omniglot' in args.multimodal_few_shot: 90 | dataset_list.append(OmniglotMetaDataset( 91 | root='data', 92 | img_side_len=args.common_img_side_len, 93 | img_channel=args.common_img_channel, 94 | num_classes_per_batch=args.num_classes_per_batch, 95 | num_samples_per_class=args.num_samples_per_class, 96 | num_total_batches=args.num_batches, 97 | num_val_samples=args.num_val_samples, 98 | meta_batch_size=args.meta_batch_size, 99 | train=is_training, 100 | num_train_classes=args.num_train_classes, 101 | num_workers=args.num_workers, 102 | device=args.device) 103 | ) 104 | if 'miniimagenet' in args.multimodal_few_shot: 105 | dataset_list.append( MiniimagenetMetaDataset( 106 | root='data', 107 | img_side_len=args.common_img_side_len, 108 | img_channel=args.common_img_channel, 109 | num_classes_per_batch=args.num_classes_per_batch, 110 | num_samples_per_class=args.num_samples_per_class, 111 | num_total_batches=args.num_batches, 112 | num_val_samples=args.num_val_samples, 113 | meta_batch_size=args.meta_batch_size, 114 | train=is_training, 115 | num_train_classes=args.num_train_classes, 116 | num_workers=args.num_workers, 117 | device=args.device) 118 | ) 119 | if 'cifar' in args.multimodal_few_shot: 120 | dataset_list.append(Cifar100MetaDataset( 121 | root='data', 122 | img_side_len=args.common_img_side_len, 123 | img_channel=args.common_img_channel, 124 | num_classes_per_batch=args.num_classes_per_batch, 125 
| num_samples_per_class=args.num_samples_per_class, 126 | num_total_batches=args.num_batches, 127 | num_val_samples=args.num_val_samples, 128 | meta_batch_size=args.meta_batch_size, 129 | train=is_training, 130 | num_train_classes=args.num_train_classes, 131 | num_workers=args.num_workers, 132 | device=args.device) 133 | ) 134 | if 'doublemnist' in args.multimodal_few_shot: 135 | dataset_list.append( DoubleMNISTMetaDataset( 136 | root='data', 137 | img_side_len=args.common_img_side_len, 138 | img_channel=args.common_img_channel, 139 | num_classes_per_batch=args.num_classes_per_batch, 140 | num_samples_per_class=args.num_samples_per_class, 141 | num_total_batches=args.num_batches, 142 | num_val_samples=args.num_val_samples, 143 | meta_batch_size=args.meta_batch_size, 144 | train=is_training, 145 | num_train_classes=args.num_train_classes, 146 | num_workers=args.num_workers, 147 | device=args.device) 148 | ) 149 | if 'triplemnist' in args.multimodal_few_shot: 150 | dataset_list.append( TripleMNISTMetaDataset( 151 | root='data', 152 | img_side_len=args.common_img_side_len, 153 | img_channel=args.common_img_channel, 154 | num_classes_per_batch=args.num_classes_per_batch, 155 | num_samples_per_class=args.num_samples_per_class, 156 | num_total_batches=args.num_batches, 157 | num_val_samples=args.num_val_samples, 158 | meta_batch_size=args.meta_batch_size, 159 | train=is_training, 160 | num_train_classes=args.num_train_classes, 161 | num_workers=args.num_workers, 162 | device=args.device) 163 | ) 164 | if 'bird' in args.multimodal_few_shot: 165 | dataset_list.append( BirdMetaDataset( 166 | root='data', 167 | img_side_len=args.common_img_side_len, 168 | img_channel=args.common_img_channel, 169 | num_classes_per_batch=args.num_classes_per_batch, 170 | num_samples_per_class=args.num_samples_per_class, 171 | num_total_batches=args.num_batches, 172 | num_val_samples=args.num_val_samples, 173 | meta_batch_size=args.meta_batch_size, 174 | train=is_training, 175 | num_train_classes=args.num_train_classes, 176 | num_workers=args.num_workers, 177 | device=args.device) 178 | ) 179 | if 'aircraft' in args.multimodal_few_shot: 180 | dataset_list.append( AircraftMetaDataset( 181 | root='data', 182 | img_side_len=args.common_img_side_len, 183 | img_channel=args.common_img_channel, 184 | num_classes_per_batch=args.num_classes_per_batch, 185 | num_samples_per_class=args.num_samples_per_class, 186 | num_total_batches=args.num_batches, 187 | num_val_samples=args.num_val_samples, 188 | meta_batch_size=args.meta_batch_size, 189 | train=is_training, 190 | num_train_classes=args.num_train_classes, 191 | num_workers=args.num_workers, 192 | device=args.device) 193 | ) 194 | assert len(dataset_list) > 0 195 | print('Multimodal Few Shot Datasets: {}'.format( 196 | ' '.join([dataset.name for dataset in dataset_list]))) 197 | dataset = MultimodalFewShotDataset( 198 | dataset_list, 199 | num_total_batches=args.num_batches, 200 | mix_meta_batch=args.mix_meta_batch, 201 | mix_mini_batch=args.mix_mini_batch, 202 | txt_file=args.sample_embedding_file+'.txt' if args.num_sample_embedding > 0 else None, 203 | train=is_training, 204 | ) 205 | loss_func = torch.nn.CrossEntropyLoss() 206 | collect_accuracies = True 207 | elif args.dataset == 'doublemnist': 208 | dataset = DoubleMNISTMetaDataset( 209 | root='data', 210 | img_side_len=64, 211 | num_classes_per_batch=args.num_classes_per_batch, 212 | num_samples_per_class=args.num_samples_per_class, 213 | num_total_batches=args.num_batches, 214 | num_val_samples=args.num_val_samples, 215 
| meta_batch_size=args.meta_batch_size, 216 | train=is_training, 217 | num_train_classes=args.num_train_classes, 218 | num_workers=args.num_workers, 219 | device=args.device) 220 | loss_func = torch.nn.CrossEntropyLoss() 221 | collect_accuracies = True 222 | elif args.dataset == 'triplemnist': 223 | dataset = TripleMNISTMetaDataset( 224 | root='data', 225 | img_side_len=84, 226 | num_classes_per_batch=args.num_classes_per_batch, 227 | num_samples_per_class=args.num_samples_per_class, 228 | num_total_batches=args.num_batches, 229 | num_val_samples=args.num_val_samples, 230 | meta_batch_size=args.meta_batch_size, 231 | train=is_training, 232 | num_train_classes=args.num_train_classes, 233 | num_workers=args.num_workers, 234 | device=args.device) 235 | loss_func = torch.nn.CrossEntropyLoss() 236 | collect_accuracies = True 237 | elif args.dataset == 'miniimagenet': 238 | dataset = MiniimagenetMetaDataset( 239 | root='data', 240 | img_side_len=84, 241 | num_classes_per_batch=args.num_classes_per_batch, 242 | num_samples_per_class=args.num_samples_per_class, 243 | num_total_batches=args.num_batches, 244 | num_val_samples=args.num_val_samples, 245 | meta_batch_size=args.meta_batch_size, 246 | train=is_training, 247 | num_train_classes=args.num_train_classes, 248 | num_workers=args.num_workers, 249 | device=args.device) 250 | loss_func = torch.nn.CrossEntropyLoss() 251 | collect_accuracies = True 252 | elif args.dataset == 'sinusoid': 253 | dataset = SinusoidMetaDataset( 254 | num_total_batches=args.num_batches, 255 | num_samples_per_function=args.num_samples_per_class, 256 | num_val_samples=args.num_val_samples, 257 | meta_batch_size=args.meta_batch_size, 258 | amp_range=args.amp_range, 259 | phase_range=args.phase_range, 260 | input_range=args.input_range, 261 | oracle=args.oracle, 262 | train=is_training, 263 | device=args.device) 264 | loss_func = torch.nn.MSELoss() 265 | collect_accuracies = False 266 | elif args.dataset == 'linear': 267 | dataset = LinearMetaDataset( 268 | num_total_batches=args.num_batches, 269 | num_samples_per_function=args.num_samples_per_class, 270 | num_val_samples=args.num_val_samples, 271 | meta_batch_size=args.meta_batch_size, 272 | slope_range=args.slope_range, 273 | intersect_range=args.intersect_range, 274 | input_range=args.input_range, 275 | oracle=args.oracle, 276 | train=is_training, 277 | device=args.device) 278 | loss_func = torch.nn.MSELoss() 279 | collect_accuracies = False 280 | elif args.dataset == 'mixed': 281 | dataset = MixedFunctionsMetaDataset( 282 | num_total_batches=args.num_batches, 283 | num_samples_per_function=args.num_samples_per_class, 284 | num_val_samples=args.num_val_samples, 285 | meta_batch_size=args.meta_batch_size, 286 | amp_range=args.amp_range, 287 | phase_range=args.phase_range, 288 | slope_range=args.slope_range, 289 | intersect_range=args.intersect_range, 290 | input_range=args.input_range, 291 | noise_std=args.noise_std, 292 | oracle=args.oracle, 293 | task_oracle=args.task_oracle, 294 | train=is_training, 295 | device=args.device) 296 | loss_func = torch.nn.MSELoss() 297 | collect_accuracies = False 298 | _num_tasks=2 299 | elif args.dataset == 'many': 300 | dataset = ManyFunctionsMetaDataset( 301 | num_total_batches=args.num_batches, 302 | num_samples_per_function=args.num_samples_per_class, 303 | num_val_samples=args.num_val_samples, 304 | meta_batch_size=args.meta_batch_size, 305 | amp_range=args.amp_range, 306 | phase_range=args.phase_range, 307 | slope_range=args.slope_range, 308 | intersect_range=args.intersect_range, 
309 | input_range=args.input_range, 310 | noise_std=args.noise_std, 311 | oracle=args.oracle, 312 | task_oracle=args.task_oracle, 313 | train=is_training, 314 | device=args.device) 315 | loss_func = torch.nn.MSELoss() 316 | collect_accuracies = False 317 | _num_tasks=3 318 | elif args.dataset == 'multisinusoids': 319 | dataset = MultiSinusoidsMetaDataset( 320 | num_total_batches=args.num_batches, 321 | num_samples_per_function=args.num_samples_per_class, 322 | num_val_samples=args.num_val_samples, 323 | meta_batch_size=args.meta_batch_size, 324 | amp_range=args.amp_range, 325 | phase_range=args.phase_range, 326 | slope_range=args.slope_range, 327 | intersect_range=args.intersect_range, 328 | input_range=args.input_range, 329 | noise_std=args.noise_std, 330 | oracle=args.oracle, 331 | task_oracle=args.task_oracle, 332 | train=is_training, 333 | device=args.device) 334 | loss_func = torch.nn.MSELoss() 335 | collect_accuracies = False 336 | else: 337 | raise ValueError('Unrecognized dataset {}'.format(args.dataset)) 338 | 339 | embedding_model = None 340 | 341 | if args.model_type == 'fc': 342 | model = FullyConnectedModel( 343 | input_size=np.prod(dataset.input_size), 344 | output_size=dataset.output_size, 345 | hidden_sizes=args.hidden_sizes, 346 | disable_norm=args.disable_norm, 347 | bias_transformation_size=args.bias_transformation_size) 348 | elif args.model_type == 'multi': 349 | model = MultiFullyConnectedModel( 350 | input_size=np.prod(dataset.input_size), 351 | output_size=dataset.output_size, 352 | hidden_sizes=args.hidden_sizes, 353 | disable_norm=args.disable_norm, 354 | num_tasks=_num_tasks, 355 | bias_transformation_size=args.bias_transformation_size) 356 | elif args.model_type == 'conv': 357 | model = ConvModel( 358 | input_channels=dataset.input_size[0], 359 | output_size=dataset.output_size, 360 | num_channels=args.num_channels, 361 | img_side_len=dataset.input_size[1], 362 | use_max_pool=args.use_max_pool, 363 | verbose=args.verbose) 364 | elif args.model_type == 'gatedconv': 365 | model = GatedConvModel( 366 | input_channels=dataset.input_size[0], 367 | output_size=dataset.output_size, 368 | use_max_pool=args.use_max_pool, 369 | num_channels=args.num_channels, 370 | img_side_len=dataset.input_size[1], 371 | condition_type=args.condition_type, 372 | condition_order=args.condition_order, 373 | verbose=args.verbose) 374 | elif args.model_type == 'gated': 375 | model = GatedNet( 376 | input_size=np.prod(dataset.input_size), 377 | output_size=dataset.output_size, 378 | hidden_sizes=args.hidden_sizes, 379 | condition_type=args.condition_type, 380 | condition_order=args.condition_order) 381 | else: 382 | raise ValueError('Unrecognized model type {}'.format(args.model_type)) 383 | model_parameters = list(model.parameters()) 384 | 385 | if args.embedding_type == '': 386 | embedding_model = None 387 | elif args.embedding_type == 'simple': 388 | embedding_model = SimpleEmbeddingModel( 389 | num_embeddings=dataset.num_tasks, 390 | embedding_dims=args.embedding_dims) 391 | embedding_parameters = list(embedding_model.parameters()) 392 | elif args.embedding_type == 'GRU': 393 | embedding_model = GRUEmbeddingModel( 394 | input_size=np.prod(dataset.input_size), 395 | output_size=dataset.output_size, 396 | embedding_dims=args.embedding_dims, 397 | hidden_size=args.embedding_hidden_size, 398 | num_layers=args.embedding_num_layers) 399 | embedding_parameters = list(embedding_model.parameters()) 400 | elif args.embedding_type == 'LSTM': 401 | embedding_model = LSTMEmbeddingModel( 402 | 
input_size=np.prod(dataset.input_size), 403 | output_size=dataset.output_size, 404 | embedding_dims=args.embedding_dims, 405 | hidden_size=args.embedding_hidden_size, 406 | num_layers=args.embedding_num_layers) 407 | embedding_parameters = list(embedding_model.parameters()) 408 | elif args.embedding_type == 'ConvGRU': 409 | embedding_model = ConvEmbeddingModel( 410 | input_size=np.prod(dataset.input_size), 411 | output_size=dataset.output_size, 412 | embedding_dims=args.embedding_dims, 413 | hidden_size=args.embedding_hidden_size, 414 | num_layers=args.embedding_num_layers, 415 | convolutional=args.conv_embedding, 416 | num_conv=args.num_conv_embedding_layer, 417 | num_channels=args.num_channels, 418 | rnn_aggregation=(not args.no_rnn_aggregation), 419 | embedding_pooling=args.embedding_pooling, 420 | batch_norm=args.conv_embedding_batch_norm, 421 | avgpool_after_conv=args.conv_embedding_avgpool_after_conv, 422 | linear_before_rnn=args.linear_before_rnn, 423 | num_sample_embedding=args.num_sample_embedding, 424 | sample_embedding_file=args.sample_embedding_file+'.'+args.sample_embedding_file_type, 425 | img_size=dataset.input_size, 426 | verbose=args.verbose) 427 | embedding_parameters = list(embedding_model.parameters()) 428 | else: 429 | raise ValueError('Unrecognized embedding type {}'.format( 430 | args.embedding_type)) 431 | 432 | optimizers = None 433 | if embedding_model: 434 | optimizers = ( torch.optim.Adam(model_parameters, lr=args.slow_lr), 435 | torch.optim.Adam(embedding_parameters, lr=args.slow_lr) ) 436 | else: 437 | optimizers = ( torch.optim.Adam(model_parameters, lr=args.slow_lr), ) 438 | 439 | if args.checkpoint != '': 440 | checkpoint = torch.load(args.checkpoint) 441 | model.load_state_dict(checkpoint['model_state_dict']) 442 | model.to(args.device) 443 | if 'optimizer' in checkpoint: 444 | pass 445 | else: 446 | optimizers[0].load_state_dict(checkpoint['optimizers'][0]) 447 | optimizer_to_device(optimizers[0], args.device) 448 | 449 | if embedding_model: 450 | embedding_model.load_state_dict( 451 | checkpoint['embedding_model_state_dict']) 452 | optimizers[1].load_state_dict(checkpoint['optimizers'][1]) 453 | optimizer_to_device(optimizers[1], args.device) 454 | 455 | meta_learner = MetaLearner( 456 | model, embedding_model, optimizers, fast_lr=args.fast_lr, 457 | loss_func=loss_func, first_order=args.first_order, 458 | num_updates=args.num_updates, 459 | inner_loop_grad_clip=args.inner_loop_grad_clip, 460 | collect_accuracies=collect_accuracies, device=args.device, 461 | alternating=args.alternating, embedding_schedule=args.embedding_schedule, 462 | classifier_schedule=args.classifier_schedule, embedding_grad_clip=args.embedding_grad_clip) 463 | 464 | trainer = Trainer( 465 | meta_learner=meta_learner, meta_dataset=dataset, writer=writer, 466 | log_interval=args.log_interval, save_interval=args.save_interval, 467 | model_type=args.model_type, save_folder=save_folder, 468 | total_iter=args.num_batches//args.meta_batch_size 469 | ) 470 | 471 | if is_training: 472 | trainer.train() 473 | else: 474 | trainer.eval() 475 | 476 | 477 | if __name__ == '__main__': 478 | 479 | def str2bool(arg): 480 | return arg.lower() == 'true' 481 | 482 | parser = argparse.ArgumentParser( 483 | description='Model-Agnostic Meta-Learning (MAML)') 484 | 485 | parser.add_argument('--mmaml-model', type=str2bool, default=False, 486 | help='gated_conv + ConvGRU') 487 | parser.add_argument('--maml-model', type=str2bool, default=False, 488 | help='conv') 489 | 490 | # Model 491 | 
parser.add_argument('--hidden-sizes', type=int, 492 | default=[256, 128, 64, 64], nargs='+', 493 | help='number of hidden units per layer') 494 | parser.add_argument('--model-type', type=str, default='gatedconv', 495 | help='type of the model') 496 | parser.add_argument('--condition-type', type=str, default='affine', 497 | choices=['affine', 'sigmoid', 'softmax'], 498 | help='type of the conditional layers') 499 | parser.add_argument('--condition-order', type=str, default='low2high', 500 | help='order of the conditional layers to be used') 501 | parser.add_argument('--use-max-pool', type=str2bool, default=False, 502 | help='choose whether to use max pooling with convolutional model') 503 | parser.add_argument('--num-channels', type=int, default=32, 504 | help='number of channels in convolutional layers') 505 | parser.add_argument('--disable-norm', action='store_true', 506 | help='disable batchnorm after linear layers in a fully connected model') 507 | parser.add_argument('--bias-transformation-size', type=int, default=0, 508 | help='size of bias transformation vector that is concatenated with ' 509 | 'input') 510 | 511 | # Embedding 512 | parser.add_argument('--embedding-type', type=str, default='', 513 | help='type of the embedding') 514 | parser.add_argument('--embedding-hidden-size', type=int, default=128, 515 | help='number of hidden units per layer in recurrent embedding model') 516 | parser.add_argument('--embedding-num-layers', type=int, default=2, 517 | help='number of layers in recurrent embedding model') 518 | parser.add_argument('--embedding-dims', type=int, nargs='+', default=0, 519 | help='dimensions of the embeddings') 520 | 521 | # Randomly sampled embedding vectors 522 | parser.add_argument('--num-sample-embedding', type=int, default=0, 523 | help='number of randomly sampled embedding vectors') 524 | parser.add_argument( 525 | '--sample-embedding-file', type=str, default='embeddings', 526 | help='the file name of randomly sampled embedding vectors') 527 | parser.add_argument( 528 | '--sample-embedding-file-type', type=str, default='hdf5') 529 | 530 | # Inner loop 531 | parser.add_argument('--first-order', action='store_true', 532 | help='use the first-order approximation of MAML') 533 | parser.add_argument('--fast-lr', type=float, default=0.05, 534 | help='learning rate for the 1-step gradient update of MAML') 535 | parser.add_argument('--inner-loop-grad-clip', type=float, default=20.0, 536 | help='enable gradient clipping in the inner loop') 537 | parser.add_argument('--num-updates', type=int, default=5, 538 | help='how many update steps in the inner loop') 539 | 540 | # Optimization 541 | parser.add_argument('--num-batches', type=int, default=1920000, 542 | help='number of batches') 543 | parser.add_argument('--meta-batch-size', type=int, default=10, 544 | help='number of tasks per batch') 545 | parser.add_argument('--slow-lr', type=float, default=0.001, 546 | help='learning rate for the global update of MAML') 547 | 548 | # Miscellaneous 549 | parser.add_argument('--output-folder', type=str, default='maml', 550 | help='name of the output folder') 551 | parser.add_argument('--device', type=str, default='cuda', 552 | help='set the device (cpu or cuda)') 553 | parser.add_argument('--num-workers', type=int, default=4, 554 | help='how many DataLoader workers to use') 555 | parser.add_argument('--log-interval', type=int, default=100, 556 | help='number of batches between tensorboard writes') 557 | parser.add_argument('--save-interval', type=int, default=1000, 558 | 
help='number of batches between model saves')
559 |     parser.add_argument('--eval', action='store_true', default=False,
560 |                         help='evaluate model')
561 |     parser.add_argument('--checkpoint', type=str, default='',
562 |                         help='path to saved parameters')
563 | 
564 |     # Dataset
565 |     parser.add_argument('--dataset', type=str, default='multimodal_few_shot',
566 |                         help='which dataset to use')
567 |     parser.add_argument('--data-root', type=str, default='data',
568 |                         help='path to store datasets')
569 |     parser.add_argument('--num-train-classes', type=int, default=1100,
570 |                         help='how many classes for training')
571 |     parser.add_argument('--num-classes-per-batch', type=int, default=5,
572 |                         help='how many classes per task')
573 |     parser.add_argument('--num-samples-per-class', type=int, default=1,
574 |                         help='how many samples per class for training')
575 |     parser.add_argument('--num-val-samples', type=int, default=15,
576 |                         help='how many samples per class for validation')
577 |     parser.add_argument('--img-side-len', type=int, default=28,
578 |                         help='width and height of the input images')
579 |     parser.add_argument('--input-range', type=float, default=[-5.0, 5.0],
580 |                         nargs='+', help='input range of simple functions')
581 |     parser.add_argument('--phase-range', type=float, default=[0, np.pi],
582 |                         nargs='+', help='phase range of sinusoids')
583 |     parser.add_argument('--amp-range', type=float, default=[0.1, 5.0],
584 |                         nargs='+', help='amplitude range of sinusoids')
585 |     parser.add_argument('--slope-range', type=float, default=[-3.0, 3.0],
586 |                         nargs='+', help='slope range of linear functions')
587 |     parser.add_argument('--intersect-range', type=float, default=[-3.0, 3.0],
588 |                         nargs='+', help='intercept range of linear functions')
589 |     parser.add_argument('--noise-std', type=float, default=0.0,
590 |                         help='add gaussian noise to mixed functions')
591 |     parser.add_argument('--oracle', action='store_true',
592 |                         help='concatenate phase and amp to sinusoid inputs')
593 |     parser.add_argument('--task-oracle', action='store_true',
594 |                         help='uses task id for prediction in some models')
595 | 
596 |     # Combine few-shot learning datasets
597 |     parser.add_argument('--multimodal_few_shot', type=str,
598 |                         default=['omniglot', 'cifar', 'miniimagenet', 'doublemnist', 'triplemnist'],
599 |                         choices=['omniglot', 'cifar', 'miniimagenet', 'doublemnist', 'triplemnist',
600 |                                  'bird', 'aircraft'],
601 |                         nargs='+', help='datasets to combine in the multimodal meta-dataset')
602 |     parser.add_argument('--common-img-side-len', type=int, default=84, help='image side length shared by the combined datasets')
603 |     parser.add_argument('--common-img-channel', type=int, default=3,
604 |                         help='3 for RGB and 1 for grayscale')
605 |     parser.add_argument('--mix-meta-batch', type=str2bool, default=True, help='mix tasks from all datasets within a meta-batch')
606 |     parser.add_argument('--mix-mini-batch', type=str2bool, default=False, help='mix samples from all datasets within a task (not implemented)')
607 | 
608 |     parser.add_argument('--alternating', action='store_true',
609 |                         help='alternate meta-updates between the classifier and the embedding model')
610 |     parser.add_argument('--classifier-schedule', type=int, default=10,
611 |                         help='number of consecutive classifier updates in alternating mode')
612 |     parser.add_argument('--embedding-schedule', type=int, default=10,
613 |                         help='number of consecutive embedding-model updates in alternating mode')
614 |     parser.add_argument('--conv-embedding', type=str2bool, default=True,
615 |                         help='use a convolutional feature extractor in the embedding model')
616 |     parser.add_argument('--conv-embedding-batch-norm', type=str2bool, default=True,
617 |                         help='use batch norm in the convolutional embedding layers')
618 |     parser.add_argument('--conv-embedding-avgpool-after-conv', type=str2bool, default=True,
619 |                         help='average-pool convolutional features before aggregation')
620 |     parser.add_argument('--num-conv-embedding-layer', type=int, default=4,
621 |                         help='number of convolutional layers in the embedding model')
622 |     parser.add_argument('--no-rnn-aggregation', type=str2bool, default=True,
623 |                         help='aggregate sample embeddings by pooling instead of an RNN')
624 | 
parser.add_argument('--embedding-pooling', type=str,
625 |                         choices=['avg', 'max'], default='avg', help='pooling used to aggregate sample embeddings')
626 |     parser.add_argument('--linear-before-rnn', action='store_true',
627 |                         help='apply a linear layer and ReLU before the RNN aggregator')
628 |     parser.add_argument('--embedding-grad-clip', type=float, default=0.0,
629 |                         help='max norm for clipping embedding-model gradients (0 disables)')
630 |     parser.add_argument('--verbose', type=str2bool, default=False,
631 |                         help='print verbose model and argument information')
632 | 
633 |     args = parser.parse_args()
634 | 
635 |     # Create the training output folder if it doesn't exist
636 |     if not os.path.exists('./train_dir'):
637 |         os.makedirs('./train_dir')
638 | 
639 |     # Don't sample more embedding vectors than there are batches
640 |     args.num_sample_embedding = min(args.num_sample_embedding, args.num_batches)
641 | 
642 |     # compute embedding dims
643 |     num_gated_conv_layers = 4
644 |     if args.embedding_dims == 0:
645 |         args.embedding_dims = []
646 |         for i in range(num_gated_conv_layers):
647 |             embedding_dim = args.num_channels*2**i
648 |             if args.condition_type == 'affine':
649 |                 embedding_dim *= 2
650 |             args.embedding_dims.append(embedding_dim)
651 | 
652 |     assert not (args.mmaml_model and args.maml_model)
653 | 
654 |     # mmaml model: gated conv + convGRU
655 |     if args.mmaml_model:
656 |         print('Use MMAML')
657 |         args.model_type = 'gatedconv'
658 |         args.embedding_type = 'ConvGRU'
659 | 
660 |     # maml model: conv
661 |     if args.maml_model:
662 |         print('Use vanilla MAML')
663 |         args.model_type = 'conv'
664 |         args.embedding_type = ''
665 | 
666 |     # Device
667 |     args.device = torch.device(args.device
668 |                                if torch.cuda.is_available() else 'cpu')
669 | 
670 |     # print args
671 |     if args.verbose:
672 |         print('='*10 + ' ARGS ' + '='*10)
673 |         for k, v in sorted(vars(args).items()):
674 |             print('{}: {}'.format(k, v))
675 |         print('='*26)
676 | 
677 |     main(args)
678 | 
--------------------------------------------------------------------------------
/maml/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/maml/__init__.py
--------------------------------------------------------------------------------
/maml/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/maml/datasets/__init__.py
--------------------------------------------------------------------------------
/maml/datasets/aircraft.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | from collections import defaultdict
4 | 
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import DataLoader
8 | from torchvision import transforms
9 | 
10 | from maml.sampler import ClassBalancedSampler
11 | from maml.datasets.metadataset import Task
12 | 
13 | 
14 | class AircraftMAMLSplit():
15 |     def __init__(self, root, train=True, num_train_classes=70,
16 |                  transform=None, target_transform=None, **kwargs):
17 |         self.transform = transform
18 |         self.target_transform = target_transform
19 |         self.root = root + '/aircraft'
20 | 
21 |         self._train = train
22 | 
23 |         if self._train:
24 |             all_character_dirs = glob.glob(self.root + '/train/**')
25 |             self._characters = all_character_dirs
26 |         else:
27 |             all_character_dirs = glob.glob(self.root + '/test/**')
28 |             self._characters = all_character_dirs
29 | 
30 |         self._character_images = []
31 |         for i, char_path in enumerate(self._characters):
32 |             img_list = 
[(cp, i) for cp in glob.glob(char_path + '/*')] 33 | self._character_images.append(img_list) 34 | 35 | self._flat_character_images = sum(self._character_images, []) 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | Returns: 42 | tuple: (image, target) where target is index of the target 43 | character class. 44 | """ 45 | image_path, character_class = self._flat_character_images[index] 46 | image = Image.open(image_path, mode='r') 47 | 48 | if self.transform: 49 | image = self.transform(image) 50 | 51 | if self.target_transform: 52 | character_class = self.target_transform(character_class) 53 | 54 | return image, character_class 55 | 56 | 57 | class AircraftMetaDataset(object): 58 | """ 59 | TODO: Check if the data loader is fast enough. 60 | Args: 61 | root: path to aircraft dataset 62 | img_side_len: images are scaled to this size 63 | num_classes_per_batch: number of classes to sample for each batch 64 | num_samples_per_class: number of samples to sample for each class 65 | for each batch. For K shot learning this should be K + number 66 | of validation samples 67 | num_total_batches: total number of tasks to generate 68 | train: whether to create data loader from the test or validation data 69 | """ 70 | def __init__(self, name='aircraft', root='data', 71 | img_side_len=84, img_channel=3, 72 | num_classes_per_batch=5, num_samples_per_class=6, 73 | num_total_batches=200000, 74 | num_val_samples=1, meta_batch_size=40, train=True, 75 | num_train_classes=1100, num_workers=0, device='cpu'): 76 | self.name = name 77 | self._root = root 78 | self._img_side_len = img_side_len 79 | self._img_channel = img_channel 80 | self._num_classes_per_batch = num_classes_per_batch 81 | self._num_samples_per_class = num_samples_per_class 82 | self._num_total_batches = num_total_batches 83 | self._num_val_samples = num_val_samples 84 | self._meta_batch_size = meta_batch_size 85 | self._num_train_classes = num_train_classes 86 | self._train = train 87 | self._num_workers = num_workers 88 | self._device = device 89 | 90 | self._total_samples_per_class = ( 91 | num_samples_per_class + num_val_samples) 92 | self._dataloader = self._get_aircraft_data_loader() 93 | 94 | self.input_size = (img_channel, img_side_len, img_side_len) 95 | self.output_size = self._num_classes_per_batch 96 | 97 | def _get_aircraft_data_loader(self): 98 | assert self._img_channel == 1 or self._img_channel == 3 99 | resize = transforms.Resize( 100 | (self._img_side_len, self._img_side_len), Image.LANCZOS) 101 | if self._img_channel == 1: 102 | img_transform = transforms.Compose( 103 | [resize, transforms.Grayscale(num_output_channels=1), 104 | transforms.ToTensor()]) 105 | else: 106 | img_transform = transforms.Compose( 107 | [resize, transforms.ToTensor()]) 108 | dset = AircraftMAMLSplit( 109 | self._root, transform=img_transform, train=self._train, 110 | download=True, num_train_classes=self._num_train_classes) 111 | _, labels = zip(*dset._flat_character_images) 112 | sampler = ClassBalancedSampler(labels, self._num_classes_per_batch, 113 | self._total_samples_per_class, 114 | self._num_total_batches, self._train) 115 | 116 | batch_size = (self._num_classes_per_batch * 117 | self._total_samples_per_class * 118 | self._meta_batch_size) 119 | loader = DataLoader(dset, batch_size=batch_size, sampler=sampler, 120 | num_workers=self._num_workers, pin_memory=True) 121 | return loader 122 | 123 | def _make_single_batch(self, imgs, labels): 124 | """Split imgs and labels into train and validation set. 
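        Each class is first remapped to a random permutation of
        0..num_classes_per_batch-1; the first num_val_samples images of each
        class then form the validation (query) task and the remaining images
        form the training (support) task.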
125 | TODO: check if this might become the bottleneck""" 126 | # relabel classes randomly 127 | new_labels = list(range(self._num_classes_per_batch)) 128 | random.shuffle(new_labels) 129 | labels = labels.tolist() 130 | label_set = set(labels) 131 | label_map = {label: new_labels[i] for i, label in enumerate(label_set)} 132 | labels = [label_map[l] for l in labels] 133 | 134 | label_indices = defaultdict(list) 135 | for i, label in enumerate(labels): 136 | label_indices[label].append(i) 137 | 138 | # assign samples to train and validation sets 139 | val_indices = [] 140 | train_indices = [] 141 | for label, indices in label_indices.items(): 142 | val_indices.extend(indices[:self._num_val_samples]) 143 | train_indices.extend(indices[self._num_val_samples:]) 144 | label_tensor = torch.tensor(labels, device=self._device) 145 | imgs = imgs.to(self._device) 146 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name) 147 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name) 148 | 149 | return train_task, val_task 150 | 151 | def _make_meta_batch(self, imgs, labels): 152 | batches = [] 153 | inner_batch_size = ( 154 | self._total_samples_per_class * self._num_classes_per_batch) 155 | for i in range(0, len(imgs) - 1, inner_batch_size): 156 | batch_imgs = imgs[i:i+inner_batch_size] 157 | batch_labels = labels[i:i+inner_batch_size] 158 | batch = self._make_single_batch(batch_imgs, batch_labels) 159 | batches.append(batch) 160 | 161 | train_tasks, val_tasks = zip(*batches) 162 | 163 | return train_tasks, val_tasks 164 | 165 | def __iter__(self): 166 | for imgs, labels in iter(self._dataloader): 167 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels) 168 | yield train_tasks, val_tasks 169 | -------------------------------------------------------------------------------- /maml/datasets/bird.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | from collections import defaultdict 4 | 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | 10 | from maml.sampler import ClassBalancedSampler 11 | from maml.datasets.metadataset import Task 12 | 13 | 14 | class BirdMAMLSplit(): 15 | def __init__(self, root, train=True, num_train_classes=140, 16 | transform=None, target_transform=None, **kwargs): 17 | self.transform = transform 18 | self.target_transform = target_transform 19 | self.root = root + '/bird' 20 | 21 | self._train = train 22 | 23 | if self._train: 24 | all_character_dirs = glob.glob(self.root + '/train/**') 25 | self._characters = all_character_dirs 26 | else: 27 | all_character_dirs = glob.glob(self.root + '/test/**') 28 | self._characters = all_character_dirs 29 | 30 | self._character_images = [] 31 | for i, char_path in enumerate(self._characters): 32 | img_list = [(cp, i) for cp in glob.glob(char_path + '/*')] 33 | self._character_images.append(img_list) 34 | 35 | self._flat_character_images = sum(self._character_images, []) 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | Returns: 42 | tuple: (image, target) where target is index of the target 43 | character class. 
43 |             character class.
44 | """ 45 | image_path, character_class = self._flat_character_images[index] 46 | image = Image.open(image_path, mode='r') 47 | 48 | if self.transform: 49 | image = self.transform(image) 50 | 51 | if self.target_transform: 52 | character_class = self.target_transform(character_class) 53 | 54 | return image, character_class 55 | 56 | 57 | class BirdMetaDataset(object): 58 | """ 59 | TODO: Check if the data loader is fast enough. 60 | Args: 61 | root: path to bird dataset 62 | img_side_len: images are scaled to this size 63 | num_classes_per_batch: number of classes to sample for each batch 64 | num_samples_per_class: number of samples to sample for each class 65 | for each batch. For K shot learning this should be K + number 66 | of validation samples 67 | num_total_batches: total number of tasks to generate 68 | train: whether to create data loader from the test or validation data 69 | """ 70 | def __init__(self, name='Bird', root='data', 71 | img_side_len=84, img_channel=3, 72 | num_classes_per_batch=5, num_samples_per_class=6, 73 | num_total_batches=200000, 74 | num_val_samples=1, meta_batch_size=40, train=True, 75 | num_train_classes=1100, num_workers=0, device='cpu'): 76 | self.name = name 77 | self._root = root 78 | self._img_side_len = img_side_len 79 | self._img_channel = img_channel 80 | self._num_classes_per_batch = num_classes_per_batch 81 | self._num_samples_per_class = num_samples_per_class 82 | self._num_total_batches = num_total_batches 83 | self._num_val_samples = num_val_samples 84 | self._meta_batch_size = meta_batch_size 85 | self._num_train_classes = num_train_classes 86 | self._train = train 87 | self._num_workers = num_workers 88 | self._device = device 89 | 90 | self._total_samples_per_class = ( 91 | num_samples_per_class + num_val_samples) 92 | self._dataloader = self._get_bird_data_loader() 93 | 94 | self.input_size = (img_channel, img_side_len, img_side_len) 95 | self.output_size = self._num_classes_per_batch 96 | 97 | def _get_bird_data_loader(self): 98 | assert self._img_channel == 1 or self._img_channel == 3 99 | resize = transforms.Resize( 100 | (self._img_side_len, self._img_side_len), Image.LANCZOS) 101 | if self._img_channel == 1: 102 | img_transform = transforms.Compose( 103 | [resize, transforms.Grayscale(num_output_channels=1), 104 | transforms.ToTensor()]) 105 | else: 106 | img_transform = transforms.Compose( 107 | [resize, transforms.ToTensor()]) 108 | dset = BirdMAMLSplit( 109 | self._root, transform=img_transform, train=self._train, 110 | download=True, num_train_classes=self._num_train_classes) 111 | _, labels = zip(*dset._flat_character_images) 112 | sampler = ClassBalancedSampler(labels, self._num_classes_per_batch, 113 | self._total_samples_per_class, 114 | self._num_total_batches, self._train) 115 | 116 | batch_size = (self._num_classes_per_batch * 117 | self._total_samples_per_class * 118 | self._meta_batch_size) 119 | loader = DataLoader(dset, batch_size=batch_size, sampler=sampler, 120 | num_workers=self._num_workers, pin_memory=True) 121 | return loader 122 | 123 | def _make_single_batch(self, imgs, labels): 124 | """Split imgs and labels into train and validation set. 
125 | TODO: check if this might become the bottleneck""" 126 | # relabel classes randomly 127 | new_labels = list(range(self._num_classes_per_batch)) 128 | random.shuffle(new_labels) 129 | labels = labels.tolist() 130 | label_set = set(labels) 131 | label_map = {label: new_labels[i] for i, label in enumerate(label_set)} 132 | labels = [label_map[l] for l in labels] 133 | 134 | label_indices = defaultdict(list) 135 | for i, label in enumerate(labels): 136 | label_indices[label].append(i) 137 | 138 | # assign samples to train and validation sets 139 | val_indices = [] 140 | train_indices = [] 141 | for label, indices in label_indices.items(): 142 | val_indices.extend(indices[:self._num_val_samples]) 143 | train_indices.extend(indices[self._num_val_samples:]) 144 | label_tensor = torch.tensor(labels, device=self._device) 145 | imgs = imgs.to(self._device) 146 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name) 147 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name) 148 | 149 | return train_task, val_task 150 | 151 | def _make_meta_batch(self, imgs, labels): 152 | batches = [] 153 | inner_batch_size = ( 154 | self._total_samples_per_class * self._num_classes_per_batch) 155 | for i in range(0, len(imgs) - 1, inner_batch_size): 156 | batch_imgs = imgs[i:i+inner_batch_size] 157 | batch_labels = labels[i:i+inner_batch_size] 158 | batch = self._make_single_batch(batch_imgs, batch_labels) 159 | batches.append(batch) 160 | 161 | train_tasks, val_tasks = zip(*batches) 162 | 163 | return train_tasks, val_tasks 164 | 165 | def __iter__(self): 166 | for imgs, labels in iter(self._dataloader): 167 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels) 168 | yield train_tasks, val_tasks 169 | -------------------------------------------------------------------------------- /maml/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | from collections import defaultdict 5 | 6 | import torch 7 | import numpy as np 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | from torchvision.datasets.utils import list_files 12 | 13 | from maml.sampler import ClassBalancedSampler 14 | from maml.datasets.metadataset import Task 15 | 16 | class Cifar100MAMLSplit(): 17 | def __init__(self, root, train=True, num_train_classes=100, 18 | transform=None, target_transform=None, **kwargs): 19 | self.transform = transform 20 | self.target_transform = target_transform 21 | self.root = os.path.join(root, 'cifar100.pth') 22 | 23 | self._train = train 24 | self._num_train_classes = num_train_classes 25 | 26 | self._dataset = torch.load(self.root) 27 | if self._train: 28 | self._images = torch.FloatTensor(self._dataset['data']['train'].reshape([-1, 3, 32, 32])) 29 | self._labels = torch.LongTensor(self._dataset['label']['train']) 30 | else: 31 | self._images = torch.FloatTensor(self._dataset['data']['test'].reshape([-1, 3, 32, 32])) 32 | self._labels = torch.LongTensor(self._dataset['label']['test']) 33 | 34 | def __getitem__(self, index): 35 | image = self._images[index] 36 | 37 | if self.transform: 38 | image = self.transform(self._images[index]) 39 | 40 | return image, self._labels[index] 41 | 42 | class Cifar100MetaDataset(object): 43 | def __init__(self, name='FC100', root='data', 44 | img_side_len=32, img_channel=3, 45 | num_classes_per_batch=5, num_samples_per_class=6, 46 | num_total_batches=200000, 47 | 
num_val_samples=1, meta_batch_size=32, train=True,
48 |                  num_train_classes=100, num_workers=0, device='cpu'):
49 |         self.name = name
50 |         self._root = root
51 |         self._img_side_len = img_side_len
52 |         self._img_channel = img_channel
53 |         self._num_classes_per_batch = num_classes_per_batch
54 |         self._num_samples_per_class = num_samples_per_class
55 |         self._num_total_batches = num_total_batches
56 |         self._num_val_samples = num_val_samples
57 |         self._meta_batch_size = meta_batch_size
58 |         self._num_train_classes = num_train_classes
59 |         self._train = train
60 |         self._num_workers = num_workers
61 |         self._device = device
62 | 
63 |         self._total_samples_per_class = (num_samples_per_class + num_val_samples)
64 |         self._dataloader = self._get_cifar100_data_loader()
65 | 
66 |         self.input_size = (img_channel, img_side_len, img_side_len)
67 |         self.output_size = self._num_classes_per_batch
68 | 
69 |     def _get_cifar100_data_loader(self):
70 |         assert self._img_channel == 1 or self._img_channel == 3
71 |         to_image = transforms.ToPILImage()
72 |         resize = transforms.Resize(self._img_side_len, Image.LANCZOS)
73 |         if self._img_channel == 1:
74 |             img_transform = transforms.Compose(
75 |                 [to_image, resize,
76 |                  transforms.Grayscale(num_output_channels=1),
77 |                  transforms.ToTensor()])
78 |         else:
79 |             img_transform = transforms.Compose(
80 |                 [to_image, resize, transforms.ToTensor()])
81 |         dset = Cifar100MAMLSplit(self._root, transform=img_transform,
82 |                                  train=self._train, download=True,
83 |                                  num_train_classes=self._num_train_classes)
84 |         labels = dset._labels.numpy().tolist()
85 |         sampler = ClassBalancedSampler(labels, self._num_classes_per_batch,
86 |                                        self._total_samples_per_class,
87 |                                        self._num_total_batches, self._train)
88 | 
89 |         batch_size = (self._num_classes_per_batch
90 |                       * self._total_samples_per_class
91 |                       * self._meta_batch_size)
92 |         loader = DataLoader(dset, batch_size=batch_size, sampler=sampler,
93 |                             num_workers=self._num_workers, pin_memory=True)
94 |         return loader
95 | 
96 |     def _make_single_batch(self, imgs, labels):
97 |         """Split imgs and labels into train and validation set.
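        Each class is first remapped to a random permutation of
        0..num_classes_per_batch-1; the first num_val_samples images of each
        class then form the validation (query) task and the remaining images
        form the training (support) task.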
98 | TODO: check if this might become the bottleneck""" 99 | # relabel classes randomly 100 | new_labels = list(range(self._num_classes_per_batch)) 101 | random.shuffle(new_labels) 102 | labels = labels.tolist() 103 | label_set = set(labels) 104 | label_map = {label: new_labels[i] for i, label in enumerate(label_set)} 105 | labels = [label_map[l] for l in labels] 106 | 107 | label_indices = defaultdict(list) 108 | for i, label in enumerate(labels): 109 | label_indices[label].append(i) 110 | 111 | # assign samples to train and validation sets 112 | val_indices = [] 113 | train_indices = [] 114 | for label, indices in label_indices.items(): 115 | val_indices.extend(indices[:self._num_val_samples]) 116 | train_indices.extend(indices[self._num_val_samples:]) 117 | label_tensor = torch.tensor(labels, device=self._device) 118 | imgs = imgs.to(self._device) 119 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name) 120 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name) 121 | 122 | return train_task, val_task 123 | 124 | def _make_meta_batch(self, imgs, labels): 125 | batches = [] 126 | inner_batch_size = (self._total_samples_per_class 127 | * self._num_classes_per_batch) 128 | for i in range(0, len(imgs) - 1, inner_batch_size): 129 | batch_imgs = imgs[i:i+inner_batch_size] 130 | batch_labels = labels[i:i+inner_batch_size] 131 | batch = self._make_single_batch(batch_imgs, batch_labels) 132 | batches.append(batch) 133 | 134 | train_tasks, val_tasks = zip(*batches) 135 | 136 | return train_tasks, val_tasks 137 | 138 | def __iter__(self): 139 | for imgs, labels in iter(self._dataloader): 140 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels) 141 | yield train_tasks, val_tasks 142 | -------------------------------------------------------------------------------- /maml/datasets/metadataset.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, namedtuple 2 | 3 | Task = namedtuple('Task', ['x', 'y', 'task_info']) -------------------------------------------------------------------------------- /maml/datasets/miniimagenet.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | from collections import defaultdict 4 | 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | 10 | from maml.sampler import ClassBalancedSampler 11 | from maml.datasets.metadataset import Task 12 | 13 | 14 | class MiniimagenetMAMLSplit(): 15 | def __init__(self, root, train=True, num_train_classes=1200, 16 | transform=None, target_transform=None, **kwargs): 17 | self.transform = transform 18 | self.target_transform = target_transform 19 | self.root = root + '/miniimagenet' 20 | 21 | self._train = train 22 | 23 | if self._train: 24 | all_character_dirs = glob.glob(self.root + '/train/**') 25 | self._characters = all_character_dirs 26 | else: 27 | all_character_dirs = glob.glob(self.root + '/test/**') 28 | self._characters = all_character_dirs 29 | 30 | self._character_images = [] 31 | for i, char_path in enumerate(self._characters): 32 | img_list = [(cp, i) for cp in glob.glob(char_path + '/*')] 33 | self._character_images.append(img_list) 34 | 35 | self._flat_character_images = sum(self._character_images, []) 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | Returns: 42 | tuple: (image, target) where target is index of the 
target 43 | character class. 44 | """ 45 | image_path, character_class = self._flat_character_images[index] 46 | image = Image.open(image_path, mode='r') 47 | 48 | if self.transform: 49 | image = self.transform(image) 50 | 51 | if self.target_transform: 52 | character_class = self.target_transform(character_class) 53 | 54 | return image, character_class 55 | 56 | 57 | class MiniimagenetMetaDataset(object): 58 | """ 59 | TODO: Check if the data loader is fast enough. 60 | Args: 61 | root: path to miniimagenet dataset 62 | img_side_len: images are scaled to this size 63 | num_classes_per_batch: number of classes to sample for each batch 64 | num_samples_per_class: number of samples to sample for each class 65 | for each batch. For K shot learning this should be K + number 66 | of validation samples 67 | num_total_batches: total number of tasks to generate 68 | train: whether to create data loader from the test or validation data 69 | """ 70 | def __init__(self, name='MiniImageNet', root='data', 71 | img_side_len=84, img_channel=3, 72 | num_classes_per_batch=5, num_samples_per_class=6, 73 | num_total_batches=200000, 74 | num_val_samples=1, meta_batch_size=40, train=True, 75 | num_train_classes=1100, num_workers=0, device='cpu'): 76 | self.name = name 77 | self._root = root 78 | self._img_side_len = img_side_len 79 | self._img_channel = img_channel 80 | self._num_classes_per_batch = num_classes_per_batch 81 | self._num_samples_per_class = num_samples_per_class 82 | self._num_total_batches = num_total_batches 83 | self._num_val_samples = num_val_samples 84 | self._meta_batch_size = meta_batch_size 85 | self._num_train_classes = num_train_classes 86 | self._train = train 87 | self._num_workers = num_workers 88 | self._device = device 89 | 90 | self._total_samples_per_class = ( 91 | num_samples_per_class + num_val_samples) 92 | self._dataloader = self._get_miniimagenet_data_loader() 93 | 94 | self.input_size = (img_channel, img_side_len, img_side_len) 95 | self.output_size = self._num_classes_per_batch 96 | 97 | def _get_miniimagenet_data_loader(self): 98 | assert self._img_channel == 1 or self._img_channel == 3 99 | resize = transforms.Resize( 100 | (self._img_side_len, self._img_side_len), Image.LANCZOS) 101 | if self._img_channel == 1: 102 | img_transform = transforms.Compose( 103 | [resize, transforms.Grayscale(num_output_channels=1), 104 | transforms.ToTensor()]) 105 | else: 106 | img_transform = transforms.Compose( 107 | [resize, transforms.ToTensor()]) 108 | dset = MiniimagenetMAMLSplit( 109 | self._root, transform=img_transform, train=self._train, 110 | download=True, num_train_classes=self._num_train_classes) 111 | _, labels = zip(*dset._flat_character_images) 112 | sampler = ClassBalancedSampler(labels, self._num_classes_per_batch, 113 | self._total_samples_per_class, 114 | self._num_total_batches, self._train) 115 | 116 | batch_size = (self._num_classes_per_batch * 117 | self._total_samples_per_class * 118 | self._meta_batch_size) 119 | loader = DataLoader(dset, batch_size=batch_size, sampler=sampler, 120 | num_workers=self._num_workers, pin_memory=True) 121 | return loader 122 | 123 | def _make_single_batch(self, imgs, labels): 124 | """Split imgs and labels into train and validation set. 
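        Each class is first remapped to a random permutation of
        0..num_classes_per_batch-1; the first num_val_samples images of each
        class then form the validation (query) task and the remaining images
        form the training (support) task.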
125 | TODO: check if this might become the bottleneck""" 126 | # relabel classes randomly 127 | new_labels = list(range(self._num_classes_per_batch)) 128 | random.shuffle(new_labels) 129 | labels = labels.tolist() 130 | label_set = set(labels) 131 | label_map = {label: new_labels[i] for i, label in enumerate(label_set)} 132 | labels = [label_map[l] for l in labels] 133 | 134 | label_indices = defaultdict(list) 135 | for i, label in enumerate(labels): 136 | label_indices[label].append(i) 137 | 138 | # assign samples to train and validation sets 139 | val_indices = [] 140 | train_indices = [] 141 | for label, indices in label_indices.items(): 142 | val_indices.extend(indices[:self._num_val_samples]) 143 | train_indices.extend(indices[self._num_val_samples:]) 144 | label_tensor = torch.tensor(labels, device=self._device) 145 | imgs = imgs.to(self._device) 146 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name) 147 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name) 148 | 149 | return train_task, val_task 150 | 151 | def _make_meta_batch(self, imgs, labels): 152 | batches = [] 153 | inner_batch_size = ( 154 | self._total_samples_per_class * self._num_classes_per_batch) 155 | for i in range(0, len(imgs) - 1, inner_batch_size): 156 | batch_imgs = imgs[i:i+inner_batch_size] 157 | batch_labels = labels[i:i+inner_batch_size] 158 | batch = self._make_single_batch(batch_imgs, batch_labels) 159 | batches.append(batch) 160 | 161 | train_tasks, val_tasks = zip(*batches) 162 | 163 | return train_tasks, val_tasks 164 | 165 | def __iter__(self): 166 | for imgs, labels in iter(self._dataloader): 167 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels) 168 | yield train_tasks, val_tasks 169 | -------------------------------------------------------------------------------- /maml/datasets/multimodal_few_shot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import chain 3 | from maml.datasets.metadataset import Task 4 | 5 | 6 | class MultimodalFewShotDataset(object): 7 | 8 | def __init__(self, datasets, num_total_batches, 9 | name='MultimodalFewShot', 10 | mix_meta_batch=True, mix_mini_batch=False, 11 | train=True, verbose=False, txt_file=None): 12 | self._datasets = datasets 13 | self._num_total_batches = num_total_batches 14 | self.name = name 15 | self.num_dataset = len(datasets) 16 | self.dataset_names = [dataset.name for dataset in self._datasets] 17 | self._meta_batch_size = datasets[0]._meta_batch_size 18 | self._mix_meta_batch = mix_meta_batch 19 | self._mix_mini_batch = mix_mini_batch 20 | self._train = train 21 | self._verbose = verbose 22 | self._txt_file = open(txt_file, 'w') if not txt_file is None else None 23 | 24 | # make sure all input/output sizes match 25 | input_size_list = [dataset.input_size for dataset in self._datasets] 26 | assert input_size_list.count(input_size_list[0]) == len(input_size_list) 27 | output_size_list = [dataset.output_size for dataset in self._datasets] 28 | assert output_size_list.count(output_size_list[0]) == len(output_size_list) 29 | self.input_size = datasets[0].input_size 30 | self.output_size = datasets[0].output_size 31 | 32 | # build iterators 33 | self._datasets_iter = [iter(dataset) for dataset in self._datasets] 34 | self._iter_index = 0 35 | 36 | # print info 37 | print('Multimodal Few Shot Datasets: {}'.format(' '.join(self.dataset_names))) 38 | print('mix meta batch: {}'.format(mix_meta_batch)) 39 | print('mix mini 
batch: {}'.format(mix_mini_batch))
40 | 
41 |     def __next__(self):
42 |         if self.n < self._num_total_batches:
43 |             self.n += 1  # count batches so iteration stops at num_total_batches
44 |             if not self._mix_meta_batch and not self._mix_mini_batch:
45 |                 dataset_index = np.random.randint(len(self._datasets))
46 |                 if self._verbose:
47 |                     print('Sample from: {}'.format(self._datasets[dataset_index].name))
48 |                 train_tasks, val_tasks = next(self._datasets_iter[dataset_index])
49 |                 return train_tasks, val_tasks
50 |             else:
51 |                 # get all tasks
52 |                 all_train_tasks = []
53 |                 all_val_tasks = []
54 |                 for dataset_iter in self._datasets_iter:
55 |                     train_tasks, val_tasks = next(dataset_iter)
56 |                     all_train_tasks.extend(train_tasks)
57 |                     all_val_tasks.extend(val_tasks)
58 | 
59 |                 if not self._mix_mini_batch:
60 |                     # mix them to obtain a meta batch
61 |                     """
62 |                     # randomly sample task
63 |                     dataset_indexes = np.random.choice(
64 |                         len(all_train_tasks), size=self._meta_batch_size, replace=False)
65 |                     """
66 |                     # balancedly sample from all datasets
67 |                     dataset_indexes = []
68 |                     if self._train:
69 |                         dataset_start_idx = np.random.randint(0, self.num_dataset)
70 |                     else:
71 |                         dataset_start_idx = (self._iter_index + self._meta_batch_size) % self.num_dataset
72 |                         self._iter_index += self._meta_batch_size
73 |                         self._iter_index = self._iter_index % self.num_dataset
74 | 
75 |                     for i in range(self._meta_batch_size):
76 |                         dataset_indexes.append(
77 |                             np.random.randint(0, self._meta_batch_size) +
78 |                             ((i+dataset_start_idx)%self.num_dataset)*self._meta_batch_size)
79 | 
80 |                     train_tasks = []
81 |                     val_tasks = []
82 |                     dataset_names = []
83 |                     for dataset_index in dataset_indexes:
84 |                         train_tasks.append(all_train_tasks[dataset_index])
85 |                         val_tasks.append(all_val_tasks[dataset_index])
86 |                         dataset_names.append(self._datasets[dataset_index//self._meta_batch_size].name)
87 |                     if self._verbose:
88 |                         print('Sample from: {} (indexes: {})'.format(
89 |                             dataset_names, dataset_indexes))
90 |                     if self._txt_file is not None:
91 |                         for name in dataset_names:
92 |                             self._txt_file.write(name+'\n')
93 |                     return train_tasks, val_tasks
94 |                 else:
95 |                     # mix them to obtain a mini batch and make a meta batch
96 |                     raise NotImplementedError
97 |         else:
98 |             raise StopIteration
99 | 
100 |     def __iter__(self):
101 |         self.n = 0
102 |         return self
103 | 
--------------------------------------------------------------------------------
/maml/datasets/omniglot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | from collections import defaultdict
5 | 
6 | import torch
7 | import numpy as np
8 | from PIL import Image
9 | from torch.utils.data import DataLoader
10 | from torchvision import transforms
11 | from torchvision.datasets import Omniglot
12 | from torchvision.datasets.utils import list_files
13 | 
14 | from maml.sampler import ClassBalancedSampler
15 | from maml.datasets.metadataset import Task
16 | 
17 | 
18 | class OmniglotMAMLSplit(Omniglot):
19 |     """Implements similar train / test split for Omniglot as
20 |     https://github.com/cbfinn/maml/blob/master/data_generator.py
21 | 
22 |     Uses torchvision.datasets.Omniglot for downloading and checking
23 |     dataset integrity.
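
    A minimal usage sketch (hypothetical; assumes the Omniglot data can be
    downloaded under `root`):

        split = OmniglotMAMLSplit('data', train=True, num_train_classes=1100)
        image, label = split[0]  # image (PIL, or a tensor if a transform is set)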
24 | """ 25 | def __init__(self, root, train=True, num_train_classes=1100, **kwargs): 26 | super(OmniglotMAMLSplit, self).__init__(root, download=True, 27 | background=True, **kwargs) 28 | 29 | self._train = train 30 | self._num_train_classes = num_train_classes 31 | 32 | # download testing data and test integrity 33 | self.background = False 34 | self.download() 35 | if not self._check_integrity(): 36 | raise RuntimeError('Dataset not found or corrupted') 37 | 38 | all_character_dirs = glob.glob(self.root + '/**/**/**') 39 | if self._train: 40 | self._characters = all_character_dirs[:self._num_train_classes] 41 | else: 42 | self._characters = all_character_dirs[self._num_train_classes:] 43 | 44 | self._character_images = [] 45 | for i, char_path in enumerate(self._characters): 46 | img_list = [(cp, i) for cp in glob.glob(char_path + '/*')] 47 | self._character_images.append(img_list) 48 | 49 | self._flat_character_images = sum(self._character_images, []) 50 | 51 | def __getitem__(self, index): 52 | """ 53 | Args: 54 | index (int): Index 55 | Returns: 56 | tuple: (image, target) where target is index of the target 57 | character class. 58 | """ 59 | image_path, character_class = self._flat_character_images[index] 60 | image = Image.open(image_path, mode='r').convert('L') 61 | 62 | if self.transform: 63 | image = self.transform(image) 64 | 65 | if self.target_transform: 66 | character_class = self.target_transform(character_class) 67 | 68 | return image, character_class 69 | 70 | 71 | class OmniglotMetaDataset(object): 72 | """ 73 | TODO: Check if the data loader is fast enough. 74 | Args: 75 | root: path to omniglot dataset 76 | img_side_len: images are scaled to this size 77 | num_classes_per_batch: number of classes to sample for each batch 78 | num_samples_per_class: number of samples to sample for each class 79 | for each batch. 
For K shot learning this should be K + number 80 | of validation samples 81 | num_total_batches: total number of tasks to generate 82 | train: whether to create data loader from the test or validation data 83 | """ 84 | def __init__(self, name='Omniglot', root='data', 85 | img_side_len=28, img_channel=1, 86 | num_classes_per_batch=5, num_samples_per_class=6, 87 | num_total_batches=200000, 88 | num_val_samples=1, meta_batch_size=40, train=True, 89 | num_train_classes=1100, num_workers=0, device='cpu'): 90 | self.name = name 91 | self._root = root 92 | self._img_side_len = img_side_len 93 | self._img_channel = img_channel 94 | self._num_classes_per_batch = num_classes_per_batch 95 | self._num_samples_per_class = num_samples_per_class 96 | self._num_total_batches = num_total_batches 97 | self._num_val_samples = num_val_samples 98 | self._meta_batch_size = meta_batch_size 99 | self._num_train_classes = num_train_classes 100 | self._train = train 101 | self._num_workers = num_workers 102 | self._device = device 103 | 104 | self._total_samples_per_class = ( 105 | num_samples_per_class + num_val_samples) 106 | self._dataloader = self._get_omniglot_data_loader() 107 | 108 | self.input_size = (img_channel, img_side_len, img_side_len) 109 | self.output_size = self._num_classes_per_batch 110 | 111 | def _get_omniglot_data_loader(self): 112 | assert self._img_channel == 1 or self._img_channel == 3 113 | resize = transforms.Resize(self._img_side_len, Image.LANCZOS) 114 | invert = transforms.Lambda(lambda x: 1.0 - x) 115 | if self._img_channel > 1: 116 | # tile the image 117 | tile = transforms.Lambda(lambda x: x.repeat(self._img_channel, 1, 1)) 118 | img_transform = transforms.Compose( 119 | [resize, transforms.ToTensor(), invert, tile]) 120 | else: 121 | img_transform = transforms.Compose( 122 | [resize, transforms.ToTensor(), invert]) 123 | dset = OmniglotMAMLSplit(self._root, transform=img_transform, 124 | train=self._train, 125 | num_train_classes=self._num_train_classes) 126 | _, labels = zip(*dset._flat_character_images) 127 | sampler = ClassBalancedSampler(labels, self._num_classes_per_batch, 128 | self._total_samples_per_class, 129 | self._num_total_batches, self._train) 130 | 131 | batch_size = (self._num_classes_per_batch 132 | * self._total_samples_per_class 133 | * self._meta_batch_size) 134 | loader = DataLoader(dset, batch_size=batch_size, sampler=sampler, 135 | num_workers=self._num_workers, pin_memory=True) 136 | return loader 137 | 138 | def _make_single_batch(self, imgs, labels): 139 | """Split imgs and labels into train and validation set. 140 | TODO: check if this might become the bottleneck""" 141 | # relabel classes randomly 142 | new_labels = list(range(self._num_classes_per_batch)) 143 | random.shuffle(new_labels) 144 | labels = labels.tolist() 145 | label_set = set(labels) 146 | label_map = {label: i for i, label in zip(new_labels, label_set)} 147 | labels = [label_map[l] for l in labels] 148 | 149 | label_indices = defaultdict(list) 150 | for i, label in enumerate(labels): 151 | label_indices[label].append(i) 152 | 153 | # rotate randomly to create new classes 154 | # TODO: move this to torch once supported. 
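        # Each sampled class receives the same random multiple-of-90-degree
        # rotation for all of its images (a common Omniglot augmentation), so
        # the same character can appear at different orientations across tasks.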
155 | for label, indices in label_indices.items(): 156 | rotation = np.random.randint(4) 157 | for i in range(len(indices)): 158 | img = imgs[indices[i]].numpy() 159 | # copy here for contiguity 160 | img = np.copy(np.rot90(img, k=rotation, axes=(1,2))) 161 | imgs[indices[i]] = torch.from_numpy(img) 162 | 163 | # assign samples to train and validation sets 164 | val_indices = [] 165 | train_indices = [] 166 | for label, indices in label_indices.items(): 167 | val_indices.extend(indices[:self._num_val_samples]) 168 | train_indices.extend(indices[self._num_val_samples:]) 169 | label_tensor = torch.tensor(labels, device=self._device) 170 | imgs = imgs.to(self._device) 171 | train_task = Task(imgs[train_indices], label_tensor[train_indices], self.name) 172 | val_task = Task(imgs[val_indices], label_tensor[val_indices], self.name) 173 | 174 | return train_task, val_task 175 | 176 | def _make_meta_batch(self, imgs, labels): 177 | batches = [] 178 | inner_batch_size = (self._total_samples_per_class 179 | * self._num_classes_per_batch) 180 | for i in range(0, len(imgs) - 1, inner_batch_size): 181 | batch_imgs = imgs[i:i+inner_batch_size] 182 | batch_labels = labels[i:i+inner_batch_size] 183 | batch = self._make_single_batch(batch_imgs, batch_labels) 184 | batches.append(batch) 185 | 186 | train_tasks, val_tasks = zip(*batches) 187 | 188 | return train_tasks, val_tasks 189 | 190 | def __iter__(self): 191 | for imgs, labels in iter(self._dataloader): 192 | train_tasks, val_tasks = self._make_meta_batch(imgs, labels) 193 | yield train_tasks, val_tasks 194 | -------------------------------------------------------------------------------- /maml/metalearner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from torch.nn.utils.clip_grad import clip_grad_norm_ 5 | from maml.utils import accuracy 6 | 7 | def get_grad_norm(parameters, norm_type=2): 8 | if isinstance(parameters, torch.Tensor): 9 | parameters = [parameters] 10 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 11 | norm_type = float(norm_type) 12 | total_norm = 0 13 | for p in parameters: 14 | param_norm = p.grad.data.norm(norm_type) 15 | total_norm += param_norm.item() ** norm_type 16 | total_norm = total_norm ** (1. 
/ norm_type) 17 | 18 | return total_norm 19 | 20 | class MetaLearner(object): 21 | def __init__(self, model, embedding_model, optimizers, fast_lr, loss_func, 22 | first_order, num_updates, inner_loop_grad_clip, 23 | collect_accuracies, device, alternating=False, 24 | embedding_schedule=10, classifier_schedule=10, 25 | embedding_grad_clip=0): 26 | self._model = model 27 | self._embedding_model = embedding_model 28 | self._fast_lr = fast_lr 29 | self._optimizers = optimizers 30 | self._loss_func = loss_func 31 | self._first_order = first_order 32 | self._num_updates = num_updates 33 | self._inner_loop_grad_clip = inner_loop_grad_clip 34 | self._collect_accuracies = collect_accuracies 35 | self._device = device 36 | self._alternating = alternating 37 | self._alternating_schedules = (classifier_schedule, embedding_schedule) 38 | self._alternating_count = 0 39 | self._alternating_index = 1 40 | self._embedding_grad_clip = embedding_grad_clip 41 | self._grads_mean = [] 42 | 43 | self.to(device) 44 | 45 | self._reset_measurements() 46 | 47 | def _reset_measurements(self): 48 | self._count_iters = 0.0 49 | self._cum_loss = 0.0 50 | self._cum_accuracy = 0.0 51 | 52 | def _update_measurements(self, task, loss, preds): 53 | self._count_iters += 1.0 54 | self._cum_loss += loss.data.cpu().numpy() 55 | if self._collect_accuracies: 56 | self._cum_accuracy += accuracy( 57 | preds, task.y).data.cpu().numpy() 58 | 59 | def _pop_measurements(self): 60 | measurements = {} 61 | loss = self._cum_loss / self._count_iters 62 | measurements['loss'] = loss 63 | if self._collect_accuracies: 64 | accuracy = self._cum_accuracy / self._count_iters 65 | measurements['accuracy'] = accuracy 66 | self._reset_measurements() 67 | return measurements 68 | 69 | def measure(self, tasks, train_tasks=None, adapted_params_list=None, 70 | embeddings_list=None): 71 | """Measures performance on tasks. Either train_tasks has to be a list 72 | of training task for computing embeddings, or adapted_params_list and 73 | embeddings_list have to contain adapted_params and embeddings""" 74 | if adapted_params_list is None: 75 | adapted_params_list = [None] * len(tasks) 76 | if embeddings_list is None: 77 | embeddings_list = [None] * len(tasks) 78 | for i in range(len(tasks)): 79 | params = adapted_params_list[i] 80 | if params is None: 81 | params = self._model.param_dict 82 | embeddings = embeddings_list[i] 83 | task = tasks[i] 84 | preds = self._model(task, params=params, embeddings=embeddings) 85 | loss = self._loss_func(preds, task.y) 86 | self._update_measurements(task, loss, preds) 87 | 88 | measurements = self._pop_measurements() 89 | return measurements 90 | 91 | def measure_each(self, tasks, train_tasks=None, adapted_params_list=None, 92 | embeddings_list=None): 93 | """Measures performance on tasks. 
Either train_tasks has to be a list
94 |         of training tasks for computing embeddings, or adapted_params_list and
95 |         embeddings_list have to contain adapted_params and embeddings.
96 |         Returns a list of per-task accuracies."""
97 |         if adapted_params_list is None:
98 |             adapted_params_list = [None] * len(tasks)
99 |         if embeddings_list is None:
100 |             embeddings_list = [None] * len(tasks)
101 |         accuracies = []
102 |         for i in range(len(tasks)):
103 |             params = adapted_params_list[i]
104 |             if params is None:
105 |                 params = self._model.param_dict
106 |             embeddings = embeddings_list[i]
107 |             task = tasks[i]
108 |             preds = self._model(task, params=params, embeddings=embeddings)
109 |             pred_y = np.argmax(preds.data.cpu().numpy(), axis=-1)
110 |             accuracy = np.mean(
111 |                 task.y.data.cpu().numpy() ==
112 |                 pred_y)
113 |             accuracies.append(accuracy)
114 | 
115 |         return accuracies
116 | 
117 |     def update_params(self, loss, params):
118 |         """Apply one step of gradient descent on the loss function `loss`,
119 |         with step-size `self._fast_lr`, and return the updated parameters.
120 |         """
121 |         create_graph = not self._first_order
122 |         grads = torch.autograd.grad(loss, params.values(),
123 |                                     create_graph=create_graph, allow_unused=True)
124 |         for (name, param), grad in zip(params.items(), grads):
125 |             if self._inner_loop_grad_clip > 0 and grad is not None:
126 |                 grad = grad.clamp(min=-self._inner_loop_grad_clip,
127 |                                   max=self._inner_loop_grad_clip)
128 |             if grad is not None:
129 |                 params[name] = param - self._fast_lr * grad
130 | 
131 |         return params
132 | 
133 |     def adapt(self, train_tasks):
134 |         adapted_params = []
135 |         embeddings_list = []
136 | 
137 |         for task in train_tasks:
138 |             params = self._model.param_dict
139 |             embeddings = None
140 |             if self._embedding_model:
141 |                 embeddings = self._embedding_model(task)
142 |             for i in range(self._num_updates):
143 |                 preds = self._model(task, params=params, embeddings=embeddings)
144 |                 loss = self._loss_func(preds, task.y)
145 |                 params = self.update_params(loss, params=params)
146 |                 if i == 0:
147 |                     self._update_measurements(task, loss, preds)
148 |             adapted_params.append(params)
149 |             embeddings_list.append(embeddings)
150 | 
151 |         measurements = self._pop_measurements()
152 |         return measurements, adapted_params, embeddings_list
153 | 
154 |     def step(self, adapted_params_list, embeddings_list, val_tasks,
155 |              is_training):
156 |         for optimizer in self._optimizers:
157 |             optimizer.zero_grad()
158 |         post_update_losses = []
159 | 
160 |         for adapted_params, embeddings, task in zip(
161 |                 adapted_params_list, embeddings_list, val_tasks):
162 |             preds = self._model(task, params=adapted_params,
163 |                                 embeddings=embeddings)
164 |             loss = self._loss_func(preds, task.y)
165 |             post_update_losses.append(loss)
166 |             self._update_measurements(task, loss, preds)
167 | 
168 |         mean_loss = torch.mean(torch.stack(post_update_losses))
169 |         if is_training:
170 |             mean_loss.backward()
171 |             if self._alternating:
172 |                 self._optimizers[self._alternating_index].step()
173 |                 self._alternating_count += 1
174 |                 if self._alternating_count % self._alternating_schedules[self._alternating_index] == 0:
175 |                     self._alternating_index = (1 - self._alternating_index)
176 |                     self._alternating_count = 0
177 |             else:
178 |                 self._optimizers[0].step()
179 |                 if len(self._optimizers) > 1:
180 |                     if self._embedding_grad_clip > 0:
181 |                         _grad_norm = clip_grad_norm_(self._embedding_model.parameters(), self._embedding_grad_clip)
182 |                     else:
183 |                         _grad_norm = 
get_grad_norm(self._embedding_model.parameters()) 184 | # grad_norm 185 | self._grads_mean.append(_grad_norm) 186 | self._optimizers[1].step() 187 | 188 | measurements = self._pop_measurements() 189 | return measurements 190 | 191 | def to(self, device, **kwargs): 192 | self._device = device 193 | self._model.to(device, **kwargs) 194 | if self._embedding_model: 195 | self._embedding_model.to(device, **kwargs) 196 | 197 | def state_dict(self): 198 | state = { 199 | 'model_state_dict': self._model.state_dict(), 200 | 'optimizers': [ optimizer.state_dict() for optimizer in self._optimizers ] 201 | } 202 | if self._embedding_model: 203 | state.update( 204 | {'embedding_model_state_dict': 205 | self._embedding_model.state_dict()}) 206 | return state 207 | -------------------------------------------------------------------------------- /maml/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaohua0116/MMAML-Classification/bdf1a93e798ab81619563038b95a3c5aa18717e0/maml/models/__init__.py -------------------------------------------------------------------------------- /maml/models/conv_embedding_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class ConvEmbeddingModel(torch.nn.Module): 8 | def __init__(self, input_size, output_size, embedding_dims, 9 | hidden_size=128, num_layers=1, 10 | convolutional=False, num_conv=4, num_channels=32, num_channels_max=256, 11 | rnn_aggregation=False, linear_before_rnn=False, 12 | embedding_pooling='max', batch_norm=True, avgpool_after_conv=True, 13 | num_sample_embedding=0, sample_embedding_file='embedding.hdf5', 14 | img_size=(1, 28, 28), verbose=False): 15 | 16 | super(ConvEmbeddingModel, self).__init__() 17 | self._input_size = input_size 18 | self._output_size = output_size 19 | self._hidden_size = hidden_size 20 | self._num_layers = num_layers 21 | self._embedding_dims = embedding_dims 22 | self._bidirectional = True 23 | self._device = 'cpu' 24 | self._convolutional = convolutional 25 | self._num_conv = num_conv 26 | self._num_channels = num_channels 27 | self._num_channels_max = num_channels_max 28 | self._batch_norm = batch_norm 29 | self._img_size = img_size 30 | self._rnn_aggregation = rnn_aggregation 31 | self._embedding_pooling = embedding_pooling 32 | self._linear_before_rnn = linear_before_rnn 33 | self._embeddings_array = [] 34 | self._num_sample_embedding = num_sample_embedding 35 | self._sample_embedding_file = sample_embedding_file 36 | self._avgpool_after_conv = avgpool_after_conv 37 | self._reuse = False 38 | self._verbose = verbose 39 | 40 | if self._convolutional: 41 | conv_list = OrderedDict([]) 42 | num_ch = [self._img_size[0]] + [self._num_channels*2**i for i in range(self._num_conv)] 43 | num_ch = [min(num_channels_max, ch) for ch in num_ch] 44 | for i in range(self._num_conv): 45 | conv_list.update({ 46 | 'conv{}'.format(i+1): 47 | torch.nn.Conv2d(num_ch[i], num_ch[i+1], 48 | (3, 3), stride=2, padding=1)}) 49 | if self._batch_norm: 50 | conv_list.update({ 51 | 'bn{}'.format(i+1): 52 | torch.nn.BatchNorm2d(num_ch[i+1], momentum=0.001)}) 53 | conv_list.update({'relu{}'.format(i+1): torch.nn.ReLU(inplace=True)}) 54 | self.conv = torch.nn.Sequential(conv_list) 55 | self._num_layer_per_conv = len(conv_list) // self._num_conv 56 | 57 | if self._linear_before_rnn: 58 | linear_input_size = 
self.compute_input_size( 59 | 1, 3, 2, self.conv[self._num_layer_per_conv*(self._num_conv-1)].out_channels) 60 | rnn_input_size = 128 61 | else: 62 | if self._avgpool_after_conv: 63 | rnn_input_size = self.conv[self._num_layer_per_conv*(self._num_conv-1)].out_channels 64 | else: 65 | rnn_input_size = self.compute_input_size( 66 | 1, 3, 2, self.conv[self._num_layer_per_conv*(self._num_conv-1)].out_channels) 67 | else: 68 | rnn_input_size = int(input_size) 69 | 70 | if self._rnn_aggregation: 71 | if self._linear_before_rnn: 72 | self.linear = torch.nn.Linear(linear_input_size, rnn_input_size) 73 | self.relu_after_linear = torch.nn.ReLU(inplace=True) 74 | self.rnn = torch.nn.GRU(rnn_input_size, hidden_size, 75 | num_layers, bidirectional=self._bidirectional) 76 | embedding_input_size = hidden_size*(2 if self._bidirectional else 1) 77 | else: 78 | self.rnn = None 79 | embedding_input_size = hidden_size 80 | self.linear = torch.nn.Linear(rnn_input_size, embedding_input_size) 81 | self.relu_after_linear = torch.nn.ReLU(inplace=True) 82 | 83 | self._embeddings = torch.nn.ModuleList() 84 | for dim in embedding_dims: 85 | self._embeddings.append(torch.nn.Linear(embedding_input_size, dim)) 86 | 87 | def compute_input_size(self, p, k, s, ch): 88 | current_img_size = self._img_size[1] 89 | for _ in range(self._num_conv): 90 | current_img_size = (current_img_size+2*p-k)//s+1 91 | return ch * int(current_img_size) ** 2 92 | 93 | def forward(self, task, params=None): 94 | if not self._reuse and self._verbose: print('='*8 + ' Emb Model ' + '='*8) 95 | if params is None: 96 | params = OrderedDict(self.named_parameters()) 97 | 98 | if self._convolutional: 99 | x = task.x 100 | if not self._reuse and self._verbose: print('input size: {}'.format(x.size())) 101 | for layer_name, layer in self.conv.named_children(): 102 | weight = params.get('conv.' + layer_name + '.weight', None) 103 | bias = params.get('conv.' 
+ layer_name + '.bias', None) 104 | if 'conv' in layer_name: 105 | x = F.conv2d(x, weight=weight, bias=bias, stride=2, padding=1) 106 | elif 'relu' in layer_name: 107 | x = F.relu(x) 108 | elif 'bn' in layer_name: 109 | x = F.batch_norm(x, weight=weight, bias=bias, 110 | running_mean=layer.running_mean, 111 | running_var=layer.running_var, 112 | training=True) 113 | if not self._reuse and self._verbose: print('{}: {}'.format(layer_name, x.size())) 114 | if self._avgpool_after_conv: 115 | x = x.view(x.size(0), x.size(1), -1) 116 | if not self._reuse and self._verbose: print('reshape to: {}'.format(x.size())) 117 | x = torch.mean(x, dim=2) 118 | if not self._reuse and self._verbose: print('reduce mean: {}'.format(x.size())) 119 | 120 | else: 121 | x = task.x.view(task.x.size(0), -1) 122 | if not self._reuse and self._verbose: print('flatten: {}'.format(x.size())) 123 | else: 124 | x = task.x.view(task.x.size(0), -1) 125 | if not self._reuse and self._verbose: print('flatten: {}'.format(x.size())) 126 | 127 | if self._rnn_aggregation: 128 | # LSTM input dimensions are seq_len, batch, input_size 129 | batch_size = 1 130 | h0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1), 131 | batch_size, self._hidden_size, device=self._device) 132 | if self._linear_before_rnn: 133 | x = F.relu(self.linear(x)) 134 | inputs = x.view(x.size(0), 1, -1) 135 | output, hn = self.rnn(inputs, h0) 136 | if self._bidirectional: 137 | N, B, H = output.shape 138 | output = output.view(N, B, 2, H // 2) 139 | embedding_input = torch.cat([output[-1, :, 0], output[0, :, 1]], dim=1) 140 | 141 | else: 142 | inputs = F.relu(self.linear(x).view(1, x.size(0), -1).transpose(1, 2)) 143 | if not self._reuse and self._verbose: print('fc: {}'.format(inputs.size())) 144 | if self._embedding_pooling == 'max': 145 | embedding_input = F.max_pool1d(inputs, x.size(0)).view(1, -1) 146 | elif self._embedding_pooling == 'avg': 147 | embedding_input = F.avg_pool1d(inputs, x.size(0)).view(1, -1) 148 | else: 149 | raise NotImplementedError 150 | if not self._reuse and self._verbose: print('reshape after {}pool: {}'.format( 151 | self._embedding_pooling, embedding_input.size())) 152 | 153 | # randomly sample embedding vectors 154 | if not self._num_sample_embedding == 0: 155 | self._embeddings_array.append(embedding_input.cpu().clone().detach().numpy()) 156 | if len(self._embeddings_array) >= self._num_sample_embedding: 157 | if self._sample_embedding_file.split('.')[-1] == 'hdf5': 158 | import h5py 159 | f = h5py.File(self._sample_embedding_file, 'w') 160 | f['embedding'] = np.squeeze(np.stack(self._embeddings_array)) 161 | f.close() 162 | elif self._sample_embedding_file.split('.')[-1] == 'pt': 163 | torch.save(np.squeeze(np.stack(self._embeddings_array)), 164 | self._sample_embedding_file) 165 | else: 166 | raise NotImplementedError 167 | 168 | out_embeddings = [] 169 | for i, embedding in enumerate(self._embeddings): 170 | embedding_vec = embedding(embedding_input) 171 | out_embeddings.append(embedding_vec) 172 | if not self._reuse and self._verbose: print('emb vec {} size: {}'.format( 173 | i+1, embedding_vec.size())) 174 | if not self._reuse and self._verbose: print('='*27) 175 | self._reuse = True 176 | return out_embeddings 177 | 178 | def to(self, device, **kwargs): 179 | self._device = device 180 | super(ConvEmbeddingModel, self).to(device, **kwargs) 181 | -------------------------------------------------------------------------------- /maml/models/conv_net.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from maml.models.model import Model 7 | 8 | 9 | def weight_init(module): 10 | if (isinstance(module, torch.nn.Linear) 11 | or isinstance(module, torch.nn.Conv2d)): 12 | torch.nn.init.xavier_uniform_(module.weight) 13 | module.bias.data.zero_() 14 | 15 | 16 | class ConvModel(Model): 17 | """ 18 | NOTE: difference to tf implementation: batch norm scaling is enabled here 19 | TODO: enable 'non-transductive' setting as per 20 | https://arxiv.org/abs/1803.02999 21 | """ 22 | def __init__(self, input_channels, output_size, num_channels=64, 23 | kernel_size=3, padding=1, nonlinearity=F.relu, 24 | use_max_pool=False, img_side_len=28, verbose=False): 25 | super(ConvModel, self).__init__() 26 | self._input_channels = input_channels 27 | self._output_size = output_size 28 | self._num_channels = num_channels 29 | self._kernel_size = kernel_size 30 | self._nonlinearity = nonlinearity 31 | self._use_max_pool = use_max_pool 32 | self._padding = padding 33 | self._bn_affine = False 34 | self._reuse = False 35 | self._verbose = verbose 36 | 37 | if self._use_max_pool: 38 | self._conv_stride = 1 39 | self._features_size = 1 40 | self.features = torch.nn.Sequential(OrderedDict([ 41 | ('layer1_conv', torch.nn.Conv2d(self._input_channels, 42 | self._num_channels, 43 | self._kernel_size, 44 | stride=self._conv_stride, 45 | padding=self._padding)), 46 | ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels, 47 | affine=self._bn_affine, 48 | momentum=0.001)), 49 | ('layer1_max_pool', torch.nn.MaxPool2d(kernel_size=2, 50 | stride=2)), 51 | ('layer1_relu', torch.nn.ReLU(inplace=True)), 52 | ('layer2_conv', torch.nn.Conv2d(self._num_channels, 53 | self._num_channels*2, 54 | self._kernel_size, 55 | stride=self._conv_stride, 56 | padding=self._padding)), 57 | ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2, 58 | affine=self._bn_affine, 59 | momentum=0.001)), 60 | ('layer2_max_pool', torch.nn.MaxPool2d(kernel_size=2, 61 | stride=2)), 62 | ('layer2_relu', torch.nn.ReLU(inplace=True)), 63 | ('layer3_conv', torch.nn.Conv2d(self._num_channels*2, 64 | self._num_channels*4, 65 | self._kernel_size, 66 | stride=self._conv_stride, 67 | padding=self._padding)), 68 | ('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4, 69 | affine=self._bn_affine, 70 | momentum=0.001)), 71 | ('layer3_max_pool', torch.nn.MaxPool2d(kernel_size=2, 72 | stride=2)), 73 | ('layer3_relu', torch.nn.ReLU(inplace=True)), 74 | ('layer4_conv', torch.nn.Conv2d(self._num_channels*4, 75 | self._num_channels*8, 76 | self._kernel_size, 77 | stride=self._conv_stride, 78 | padding=self._padding)), 79 | ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8, 80 | affine=self._bn_affine, 81 | momentum=0.001)), 82 | ('layer4_max_pool', torch.nn.MaxPool2d(kernel_size=2, 83 | stride=2)), 84 | ('layer4_relu', torch.nn.ReLU(inplace=True)), 85 | ])) 86 | else: 87 | self._conv_stride = 2 88 | self._features_size = (img_side_len // 14)**2 89 | self.features = torch.nn.Sequential(OrderedDict([ 90 | ('layer1_conv', torch.nn.Conv2d(self._input_channels, 91 | self._num_channels, 92 | self._kernel_size, 93 | stride=self._conv_stride, 94 | padding=self._padding)), 95 | ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels, 96 | affine=self._bn_affine, 97 | momentum=0.001)), 98 | ('layer1_relu', torch.nn.ReLU(inplace=True)), 99 | ('layer2_conv', torch.nn.Conv2d(self._num_channels, 100 | 
self._num_channels*2, 101 | self._kernel_size, 102 | stride=self._conv_stride, 103 | padding=self._padding)), 104 | ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2, 105 | affine=self._bn_affine, 106 | momentum=0.001)), 107 | ('layer2_relu', torch.nn.ReLU(inplace=True)), 108 | ('layer3_conv', torch.nn.Conv2d(self._num_channels*2, 109 | self._num_channels*4, 110 | self._kernel_size, 111 | stride=self._conv_stride, 112 | padding=self._padding)), 113 | ('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4, 114 | affine=self._bn_affine, 115 | momentum=0.001)), 116 | ('layer3_relu', torch.nn.ReLU(inplace=True)), 117 | ('layer4_conv', torch.nn.Conv2d(self._num_channels*4, 118 | self._num_channels*8, 119 | self._kernel_size, 120 | stride=self._conv_stride, 121 | padding=self._padding)), 122 | ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8, 123 | affine=self._bn_affine, 124 | momentum=0.001)), 125 | ('layer4_relu', torch.nn.ReLU(inplace=True)), 126 | ])) 127 | 128 | self.classifier = torch.nn.Sequential(OrderedDict([ 129 | ('fully_connected', torch.nn.Linear(self._num_channels*8, 130 | self._output_size)) 131 | ])) 132 | self.apply(weight_init) 133 | 134 | def forward(self, task, params=None, embeddings=None): 135 | if not self._reuse and self._verbose: print('='*10 + ' Model ' + '='*10) 136 | if params is None: 137 | params = OrderedDict(self.named_parameters()) 138 | 139 | x = task.x 140 | if not self._reuse and self._verbose: print('input size: {}'.format(x.size())) 141 | for layer_name, layer in self.features.named_children(): 142 | weight = params.get('features.' + layer_name + '.weight', None) 143 | bias = params.get('features.' + layer_name + '.bias', None) 144 | if 'conv' in layer_name: 145 | x = F.conv2d(x, weight=weight, bias=bias, 146 | stride=self._conv_stride, padding=self._padding) 147 | elif 'bn' in layer_name: 148 | x = F.batch_norm(x, weight=weight, bias=bias, 149 | running_mean=layer.running_mean, 150 | running_var=layer.running_var, 151 | training=True) 152 | elif 'max_pool' in layer_name: 153 | x = F.max_pool2d(x, kernel_size=2, stride=2) 154 | elif 'relu' in layer_name: 155 | x = F.relu(x) 156 | elif 'fully_connected' in layer_name: 157 | break 158 | else: 159 | raise ValueError('Unrecognized layer {}'.format(layer_name)) 160 | if not self._reuse and self._verbose: print('{}: {}'.format(layer_name, x.size())) 161 | 162 | # in maml network the conv maps are average pooled 163 | x = x.view(x.size(0), self._num_channels*8, self._features_size) 164 | if not self._reuse and self._verbose: print('reshape to: {}'.format(x.size())) 165 | x = torch.mean(x, dim=2) 166 | if not self._reuse and self._verbose: print('reduce mean: {}'.format(x.size())) 167 | logits = F.linear( 168 | x, weight=params['classifier.fully_connected.weight'], 169 | bias=params['classifier.fully_connected.bias']) 170 | if not self._reuse and self._verbose: print('logits size: {}'.format(logits.size())) 171 | if not self._reuse and self._verbose: print('='*27) 172 | self._reuse = True 173 | return logits 174 | -------------------------------------------------------------------------------- /maml/models/fully_connected.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from maml.models.model import Model 7 | 8 | def weight_init(module): 9 | if isinstance(module, torch.nn.Linear): 10 | torch.nn.init.normal_(module.weight, mean=0, std=0.01) 11 | 
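#   (Sanity check for ConvModel._features_size above: a 3x3 conv with
#    stride 2 and padding 1 maps side -> (side + 2 - 3)//2 + 1, so four such
#    convs give 84 -> 42 -> 21 -> 11 -> 6 and 28 -> 14 -> 7 -> 4 -> 2,
#    matching (img_side_len // 14)**2 = 36 and 4 spatial cells respectively,
#    which the average-pool head then reduces over.)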
module.bias.data.zero_()
12 | 
13 | 
14 | class FullyConnectedModel(Model):
15 |     def __init__(self, input_size, output_size, hidden_sizes=(),
16 |                  nonlinearity=F.relu, disable_norm=False,
17 |                  bias_transformation_size=0):
18 |         super(FullyConnectedModel, self).__init__()
19 |         self.hidden_sizes = hidden_sizes
20 |         self.nonlinearity = nonlinearity
21 |         self.num_layers = len(hidden_sizes) + 1
22 |         self.disable_norm = disable_norm
23 |         self.bias_transformation_size = bias_transformation_size
24 | 
25 |         if bias_transformation_size > 0:
26 |             input_size = input_size + bias_transformation_size
27 |             self.bias_transformation = torch.nn.Parameter(
28 |                 torch.zeros(bias_transformation_size))
29 | 
30 |         layer_sizes = [input_size] + list(hidden_sizes) + [output_size]
31 |         for i in range(1, self.num_layers):
32 |             self.add_module(
33 |                 'layer{0}_linear'.format(i),
34 |                 torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
35 |             if not self.disable_norm:
36 |                 self.add_module(
37 |                     'layer{0}_bn'.format(i),
38 |                     torch.nn.BatchNorm1d(layer_sizes[i], momentum=0.001))
39 |         self.add_module(
40 |             'output_linear',
41 |             torch.nn.Linear(layer_sizes[self.num_layers - 1],
42 |                             layer_sizes[self.num_layers]))
43 |         self.apply(weight_init)
44 | 
45 |     def forward(self, task, params=None, training=True, embeddings=None):
46 |         if params is None:
47 |             params = OrderedDict(self.named_parameters())
48 |         x = task.x.view(task.x.size(0), -1)
49 | 
50 |         if self.bias_transformation_size > 0:
51 |             x = torch.cat((x, params['bias_transformation'].expand(
52 |                 x.size(0), params['bias_transformation'].size(0))), dim=1)
53 | 
54 |         for key, module in self.named_modules():
55 |             if 'linear' in key:
56 |                 x = F.linear(x, weight=params[key + '.weight'],
57 |                              bias=params[key + '.bias'])
58 |                 if self.disable_norm and 'output' not in key:
59 |                     x = self.nonlinearity(x)
60 |             if 'bn' in key:
61 |                 x = F.batch_norm(x, weight=params[key + '.weight'],
62 |                                  bias=params[key + '.bias'],
63 |                                  running_mean=module.running_mean,
64 |                                  running_var=module.running_var,
65 |                                  training=training)
66 |                 x = self.nonlinearity(x)
67 |         return x
68 | 
69 | class MultiFullyConnectedModel(Model):
70 |     def __init__(self, input_size, output_size, hidden_sizes=(),
71 |                  nonlinearity=F.relu, disable_norm=False, num_tasks=1,
72 |                  bias_transformation_size=0):
73 |         super(MultiFullyConnectedModel, self).__init__()
74 |         self.hidden_sizes = hidden_sizes
75 |         self.nonlinearity = nonlinearity
76 |         self.num_layers = len(hidden_sizes) + 1
77 |         self.disable_norm = disable_norm
78 |         self.bias_transformation_size = bias_transformation_size
79 |         self.num_tasks = num_tasks
80 | 
81 |         if bias_transformation_size > 0:
82 |             input_size = input_size + bias_transformation_size
83 |             self.bias_transformation = torch.nn.Parameter(
84 |                 torch.zeros(bias_transformation_size))
85 | 
86 |         layer_sizes = [input_size] + list(hidden_sizes) + [output_size]
87 |         for j in range(0, self.num_tasks):
88 |             for i in range(1, self.num_layers):
89 |                 self.add_module(
90 |                     'task{0}_layer{1}_linear'.format(j, i),
91 |                     torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
92 |                 if not self.disable_norm:
93 |                     self.add_module(
94 |                         'task{0}_layer{1}_bn'.format(j, i),
95 |                         torch.nn.BatchNorm1d(layer_sizes[i], momentum=0.001))
96 |             self.add_module(
97 |                 'task{0}_output_linear'.format(j),
98 |                 torch.nn.Linear(layer_sizes[self.num_layers - 1],
99 |                                 layer_sizes[self.num_layers]))
100 |         self.apply(weight_init)
101 | 
102 |     def forward(self, task, params=None, training=True, embeddings=None):
103 |         if params is None:
104 |             params =
OrderedDict(self.named_parameters()) 105 | x = task.x.view(task.x.size(0), -1) 106 | task_id = task.task_info['task_id'] 107 | 108 | if self.bias_transformation_size > 0: 109 | x = torch.cat((x, params['bias_transformation'].expand( 110 | x.size(0), params['bias_transformation'].size(0))), dim=1) 111 | 112 | for key, module in self.named_modules(): 113 | if 'task{0}'.format(task_id) in key: 114 | if 'linear' in key: 115 | x = F.linear(x, weight=params[key + '.weight'], 116 | bias=params[key + '.bias']) 117 | if self.disable_norm and 'output' not in key: 118 | x = self.nonlinearity(x) 119 | if 'bn' in key: 120 | x = F.batch_norm(x, weight=params[key + '.weight'], 121 | bias=params[key + '.bias'], 122 | running_mean=module.running_mean, 123 | running_var=module.running_var, 124 | training=training) 125 | x = self.nonlinearity(x) 126 | return x 127 | -------------------------------------------------------------------------------- /maml/models/gated_conv_net.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from maml.models.model import Model 7 | 8 | 9 | def weight_init(module): 10 | if (isinstance(module, torch.nn.Linear) 11 | or isinstance(module, torch.nn.Conv2d)): 12 | torch.nn.init.xavier_uniform_(module.weight) 13 | module.bias.data.zero_() 14 | 15 | 16 | class GatedConvModel(Model): 17 | """ 18 | NOTE: difference to tf implementation: batch norm scaling is enabled here 19 | TODO: enable 'non-transductive' setting as per 20 | https://arxiv.org/abs/1803.02999 21 | """ 22 | def __init__(self, input_channels, output_size, num_channels=64, 23 | kernel_size=3, padding=1, nonlinearity=F.relu, 24 | use_max_pool=False, img_side_len=28, 25 | condition_type='affine', condition_order='low2high', verbose=False): 26 | super(GatedConvModel, self).__init__() 27 | self._input_channels = input_channels 28 | self._output_size = output_size 29 | self._num_channels = num_channels 30 | self._kernel_size = kernel_size 31 | self._nonlinearity = nonlinearity 32 | self._use_max_pool = use_max_pool 33 | self._padding = padding 34 | self._condition_type = condition_type 35 | self._condition_order = condition_order 36 | self._bn_affine = False 37 | self._reuse = False 38 | self._verbose = verbose 39 | 40 | if self._use_max_pool: 41 | self._conv_stride = 1 42 | self._features_size = 1 43 | self.features = torch.nn.Sequential(OrderedDict([ 44 | ('layer1_conv', torch.nn.Conv2d(self._input_channels, 45 | self._num_channels, 46 | self._kernel_size, 47 | stride=self._conv_stride, 48 | padding=self._padding)), 49 | ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels, 50 | affine=self._bn_affine, 51 | momentum=0.001)), 52 | ('layer1_condition', None), 53 | ('layer1_max_pool', torch.nn.MaxPool2d(kernel_size=2, 54 | stride=2)), 55 | ('layer1_relu', torch.nn.ReLU(inplace=True)), 56 | ('layer2_conv', torch.nn.Conv2d(self._num_channels, 57 | self._num_channels*2, 58 | self._kernel_size, 59 | stride=self._conv_stride, 60 | padding=self._padding)), 61 | ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2, 62 | affine=self._bn_affine, 63 | momentum=0.001)), 64 | ('layer2_condition', None), 65 | ('layer2_max_pool', torch.nn.MaxPool2d(kernel_size=2, 66 | stride=2)), 67 | ('layer2_relu', torch.nn.ReLU(inplace=True)), 68 | ('layer3_conv', torch.nn.Conv2d(self._num_channels*2, 69 | self._num_channels*4, 70 | self._kernel_size, 71 | stride=self._conv_stride, 72 | padding=self._padding)), 73 | 
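                # (The 'layerN_condition' entries in this Sequential are
                #  placeholders: forward() dispatches on the name and applies
                #  conditional_layer() with the task embedding instead of
                #  calling the stored module. With condition_type='affine' the
                #  embedding is split into (gamma, beta) and the feature map
                #  becomes x * (1 + gamma) + beta, FiLM-style. One caveat:
                #  named_children() skips None entries, so in this max-pool
                #  branch the None placeholders never appear in forward()'s
                #  iteration and conditioning is silently skipped.)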
('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4, 74 | affine=self._bn_affine, 75 | momentum=0.001)), 76 | ('layer3_condition', None), 77 | ('layer3_max_pool', torch.nn.MaxPool2d(kernel_size=2, 78 | stride=2)), 79 | ('layer3_relu', torch.nn.ReLU(inplace=True)), 80 | ('layer4_conv', torch.nn.Conv2d(self._num_channels*4, 81 | self._num_channels*8, 82 | self._kernel_size, 83 | stride=self._conv_stride, 84 | padding=self._padding)), 85 | ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8, 86 | affine=self._bn_affine, 87 | momentum=0.001)), 88 | ('layer4_condition', None), 89 | ('layer4_max_pool', torch.nn.MaxPool2d(kernel_size=2, 90 | stride=2)), 91 | ('layer4_relu', torch.nn.ReLU(inplace=True)), 92 | ])) 93 | else: 94 | self._conv_stride = 2 95 | self._features_size = (img_side_len // 14)**2 96 | self.features = torch.nn.Sequential(OrderedDict([ 97 | ('layer1_conv', torch.nn.Conv2d(self._input_channels, 98 | self._num_channels, 99 | self._kernel_size, 100 | stride=self._conv_stride, 101 | padding=self._padding)), 102 | ('layer1_bn', torch.nn.BatchNorm2d(self._num_channels, 103 | affine=self._bn_affine, 104 | momentum=0.001)), 105 | ('layer1_condition', torch.nn.ReLU(inplace=True)), 106 | ('layer1_relu', torch.nn.ReLU(inplace=True)), 107 | ('layer2_conv', torch.nn.Conv2d(self._num_channels, 108 | self._num_channels*2, 109 | self._kernel_size, 110 | stride=self._conv_stride, 111 | padding=self._padding)), 112 | ('layer2_bn', torch.nn.BatchNorm2d(self._num_channels*2, 113 | affine=self._bn_affine, 114 | momentum=0.001)), 115 | ('layer2_condition', torch.nn.ReLU(inplace=True)), 116 | ('layer2_relu', torch.nn.ReLU(inplace=True)), 117 | ('layer3_conv', torch.nn.Conv2d(self._num_channels*2, 118 | self._num_channels*4, 119 | self._kernel_size, 120 | stride=self._conv_stride, 121 | padding=self._padding)), 122 | ('layer3_bn', torch.nn.BatchNorm2d(self._num_channels*4, 123 | affine=self._bn_affine, 124 | momentum=0.001)), 125 | ('layer3_condition', torch.nn.ReLU(inplace=True)), 126 | ('layer3_relu', torch.nn.ReLU(inplace=True)), 127 | ('layer4_conv', torch.nn.Conv2d(self._num_channels*4, 128 | self._num_channels*8, 129 | self._kernel_size, 130 | stride=self._conv_stride, 131 | padding=self._padding)), 132 | ('layer4_bn', torch.nn.BatchNorm2d(self._num_channels*8, 133 | affine=self._bn_affine, 134 | momentum=0.001)), 135 | ('layer4_condition', torch.nn.ReLU(inplace=True)), 136 | ('layer4_relu', torch.nn.ReLU(inplace=True)), 137 | ])) 138 | 139 | self.classifier = torch.nn.Sequential(OrderedDict([ 140 | ('fully_connected', torch.nn.Linear(self._num_channels*8, 141 | self._output_size)) 142 | ])) 143 | self.apply(weight_init) 144 | 145 | def conditional_layer(self, x, embedding): 146 | if self._condition_type == 'sigmoid_gate': 147 | x = x * F.sigmoid(embedding).expand_as(x) 148 | elif self._condition_type == 'affine': 149 | gammas, betas = torch.split(embedding, x.size(1), dim=-1) 150 | gammas = gammas.view(1, -1, 1, 1).expand_as(x) 151 | betas = betas.view(1, -1, 1, 1).expand_as(x) 152 | gammas = gammas + torch.ones_like(gammas) 153 | x = x * gammas + betas 154 | elif self._condition_type == 'softmax': 155 | x = x * F.softmax(embedding).view(1, -1, 1, 1).expand_as(x) 156 | else: 157 | raise ValueError('Unrecognized conditional layer type {}'.format( 158 | self._condition_type)) 159 | return x 160 | 161 | def forward(self, task, params=None, embeddings=None): 162 | if not self._reuse and self._verbose: print('='*10 + ' Model ' + '='*10) 163 | if params is None: 164 | params = 
OrderedDict(self.named_parameters()) 165 | 166 | if embeddings is not None: 167 | embeddings = {'layer{}_condition'.format(i): embedding 168 | for i, embedding in enumerate(embeddings, start=1)} 169 | 170 | x = task.x 171 | if not self._reuse and self._verbose: print('input size: {}'.format(x.size())) 172 | for layer_name, layer in self.features.named_children(): 173 | weight = params.get('features.' + layer_name + '.weight', None) 174 | bias = params.get('features.' + layer_name + '.bias', None) 175 | if 'conv' in layer_name: 176 | x = F.conv2d(x, weight=weight, bias=bias, 177 | stride=self._conv_stride, padding=self._padding) 178 | elif 'condition' in layer_name: 179 | x = self.conditional_layer(x, embeddings[layer_name]) if embeddings is not None else x 180 | elif 'bn' in layer_name: 181 | x = F.batch_norm(x, weight=weight, bias=bias, 182 | running_mean=layer.running_mean, 183 | running_var=layer.running_var, 184 | training=True) 185 | elif 'max_pool' in layer_name: 186 | x = F.max_pool2d(x, kernel_size=2, stride=2) 187 | elif 'relu' in layer_name: 188 | x = F.relu(x) 189 | elif 'fully_connected' in layer_name: 190 | break 191 | else: 192 | raise ValueError('Unrecognized layer {}'.format(layer_name)) 193 | if not self._reuse and self._verbose: print('{}: {}'.format(layer_name, x.size())) 194 | 195 | # in maml network the conv maps are average pooled 196 | x = x.view(x.size(0), self._num_channels*8, self._features_size) 197 | if not self._reuse and self._verbose: print('reshape to: {}'.format(x.size())) 198 | x = torch.mean(x, dim=2) 199 | if not self._reuse and self._verbose: print('reduce mean: {}'.format(x.size())) 200 | logits = F.linear( 201 | x, weight=params['classifier.fully_connected.weight'], 202 | bias=params['classifier.fully_connected.bias']) 203 | if not self._reuse and self._verbose: print('logits size: {}'.format(logits.size())) 204 | if not self._reuse and self._verbose: print('='*27) 205 | self._reuse = True 206 | return logits 207 | -------------------------------------------------------------------------------- /maml/models/gated_net.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from maml.models.model import Model 7 | from IPython import embed 8 | 9 | def weight_init(module): 10 | if isinstance(module, torch.nn.Linear): 11 | torch.nn.init.normal_(module.weight, mean=0, std=0.01) 12 | module.bias.data.zero_() 13 | 14 | 15 | class GatedNet(Model): 16 | def __init__(self, input_size, output_size, hidden_sizes=[40, 40], 17 | nonlinearity=F.relu, condition_type='sigmoid_gate', condition_order='low2high'): 18 | super(GatedNet, self).__init__() 19 | self._nonlinearity = nonlinearity 20 | self._condition_type = condition_type 21 | self._condition_order = condition_order 22 | 23 | self.num_layers = len(hidden_sizes) + 1 24 | 25 | layer_sizes = [input_size] + hidden_sizes + [output_size] 26 | for i in range(1, self.num_layers): 27 | self.add_module( 28 | 'layer{0}_linear'.format(i), 29 | torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i])) 30 | self.add_module( 31 | 'output_linear', 32 | torch.nn.Linear(layer_sizes[self.num_layers - 1], 33 | layer_sizes[self.num_layers])) 34 | self.apply(weight_init) 35 | 36 | def conditional_layer(self, x, embedding): 37 | if self._condition_type == 'sigmoid_gate': 38 | x = x * F.sigmoid(embedding).expand_as(x) 39 | elif self._condition_type == 'affine': 40 | gammas, betas = torch.split(embedding, 
x.size(1), dim=-1)
41 |             gammas = gammas + torch.ones_like(gammas)
42 |             x = x * gammas + betas
43 |         elif self._condition_type == 'softmax':
44 |             x = x * F.softmax(embedding).expand_as(x)
45 |         else:
46 |             raise ValueError('Unrecognized conditional layer type {}'.format(
47 |                 self._condition_type))
48 |         return x
49 | 
50 |     def forward(self, task, params=None, embeddings=None, training=True):
51 |         if params is None:
52 |             params = OrderedDict(self.named_parameters())
53 | 
54 |         if embeddings is not None:
55 |             if self._condition_order == 'high2low': ## High2Low
56 |                 embeddings = {'layer{}_linear'.format(len(params)-i): embedding
57 |                               for i, embedding in enumerate(embeddings[::-1])}
58 |             elif self._condition_order == 'low2high': ## Low2High
59 |                 embeddings = {'layer{}_linear'.format(i): embedding
60 |                               for i, embedding in enumerate(embeddings[::-1], start=1)}
61 |             else:
62 |                 raise NotImplementedError('Unsupported order for using conditional layers')
63 |         x = task.x.view(task.x.size(0), -1)
64 | 
65 |         for key, module in self.named_modules():
66 |             if 'linear' in key:
67 |                 x = F.linear(x, weight=params[key + '.weight'],
68 |                              bias=params[key + '.bias'])
69 |                 if 'output' not in key and embeddings is not None: # conditioning and nonlinearity
70 |                     if type(embeddings.get(key, -1)) != type(-1):
71 |                         x = self.conditional_layer(x, embeddings[key])
72 | 
73 |                     x = self._nonlinearity(x)
74 | 
75 |         return x
76 | 
--------------------------------------------------------------------------------
/maml/models/gru_embedding_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class GRUEmbeddingModel(torch.nn.Module):
4 |     def __init__(self, input_size, output_size, embedding_dims,
5 |                  hidden_size=40, num_layers=2):
6 |         super(GRUEmbeddingModel, self).__init__()
7 |         self._input_size = input_size
8 |         self._output_size = output_size
9 |         self._hidden_size = hidden_size
10 |         self._num_layers = num_layers
11 |         self._embedding_dims = embedding_dims
12 |         self._bidirectional = True
13 |         self._device = 'cpu'
14 | 
15 |         rnn_input_size = int(input_size + output_size)
16 |         self.rnn = torch.nn.GRU(rnn_input_size, hidden_size, num_layers, bidirectional=self._bidirectional)
17 | 
18 |         self._embeddings = torch.nn.ModuleList()
19 |         for dim in embedding_dims:
20 |             self._embeddings.append(torch.nn.Linear(hidden_size*(2 if self._bidirectional else 1), dim))
21 | 
22 |     def forward(self, task):
23 |         batch_size = 1
24 |         h0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1),
25 |                          batch_size, self._hidden_size, device=self._device)
26 | 
27 |         x = task.x.view(task.x.size(0), -1)
28 |         y = task.y.view(task.y.size(0), -1)
29 | 
30 |         # GRU input dimensions are seq_len, batch, input_size
31 |         inputs = torch.cat((x, y), dim=1).view(x.size(0), 1, -1)
32 |         output, _ = self.rnn(inputs, h0)
33 |         if self._bidirectional:
34 |             N, B, H = output.shape
35 |             output = output.view(N, B, 2, H // 2)
36 |             embedding_input = torch.cat([output[-1, :, 0], output[0, :, 1]], dim=1)
37 | 
38 |         out_embeddings = []
39 |         for embedding in self._embeddings:
40 |             out_embeddings.append(embedding(embedding_input))
41 |         return out_embeddings
42 | 
43 |     def to(self, device, **kwargs):
44 |         self._device = device
45 |         super(GRUEmbeddingModel, self).to(device, **kwargs)
46 | 
--------------------------------------------------------------------------------
/maml/models/lstm_embedding_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class
LSTMEmbeddingModel(torch.nn.Module): 4 | def __init__(self, input_size, output_size, embedding_dims, 5 | hidden_size=40, num_layers=2): 6 | super(LSTMEmbeddingModel, self).__init__() 7 | self._input_size = input_size 8 | self._output_size = output_size 9 | self._hidden_size = hidden_size 10 | self._num_layers = num_layers 11 | self._embedding_dims = embedding_dims 12 | self._bidirectional = True 13 | self._device = 'cpu' 14 | 15 | rnn_input_size = int(input_size + output_size) 16 | self.rnn = torch.nn.LSTM(rnn_input_size, hidden_size, num_layers, bidirectional=self._bidirectional) 17 | 18 | self._embeddings = torch.nn.ModuleList() 19 | for dim in embedding_dims: 20 | self._embeddings.append(torch.nn.Linear(hidden_size*(2 if self._bidirectional else 1), dim)) 21 | 22 | def forward(self, task): 23 | batch_size = 1 24 | h0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1), 25 | batch_size, self._hidden_size, device=self._device) 26 | c0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1), 27 | batch_size, self._hidden_size, device=self._device) 28 | 29 | x = task.x.view(task.x.size(0), -1) 30 | y = task.y.view(task.y.size(0), -1) 31 | 32 | # LSTM input dimensions are seq_len, batch, input_size 33 | inputs = torch.cat((x, y), dim=1).view(x.size(0), 1, -1) 34 | output, (hn, cn) = self.rnn(inputs, (h0, c0)) 35 | if self._bidirectional: 36 | N, B, H = output.shape 37 | output = output.view(N, B, 2, H // 2) 38 | embedding_input = torch.cat([output[-1, :, 0], output[0, :, 1]], dim=1) 39 | 40 | out_embeddings = [] 41 | for embedding in self._embeddings: 42 | out_embeddings.append(embedding(embedding_input)) 43 | return out_embeddings 44 | 45 | def to(self, device, **kwargs): 46 | self._device = device 47 | super(LSTMEmbeddingModel, self).to(device, **kwargs) 48 | -------------------------------------------------------------------------------- /maml/models/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | class Model(torch.nn.Module): 7 | def __init__(self): 8 | super(Model, self).__init__() 9 | 10 | @property 11 | def param_dict(self): 12 | return OrderedDict(self.named_parameters()) -------------------------------------------------------------------------------- /maml/models/simple_embedding_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SimpleEmbeddingModel(torch.nn.Module): 4 | def __init__(self, num_embeddings, embedding_dims): 5 | super(SimpleEmbeddingModel, self).__init__() 6 | self._embeddings = torch.nn.ModuleList() 7 | for dim in embedding_dims: 8 | self._embeddings.append(torch.nn.Embedding(num_embeddings, dim)) 9 | self._device = 'cpu' 10 | 11 | def forward(self, task): 12 | task_id = torch.tensor(task.task_id, dtype=torch.long, 13 | device=self._device) 14 | out_embeddings = [] 15 | for embedding in self._embeddings: 16 | out_embeddings.append(embedding(task_id)) 17 | return out_embeddings 18 | 19 | def to(self, device, **kwargs): 20 | self._device = device 21 | super(SimpleEmbeddingModel, self).to(device, **kwargs) -------------------------------------------------------------------------------- /maml/sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict, namedtuple 3 | 4 | from torch.utils.data.sampler import Sampler 5 | 6 | 7 | class ClassBalancedSampler(Sampler): 8 | 
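    # A usage sketch for this sampler (dataset and label names are
    # illustrative):
    #
    #   sampler = ClassBalancedSampler(dataset_labels, num_classes_per_batch=5,
    #                                  num_samples_per_class=1 + 15,
    #                                  num_total_batches=60000, train=True)
    #   loader = torch.utils.data.DataLoader(dataset,
    #                                        batch_size=5 * (1 + 15),
    #                                        sampler=sampler)
    #
    # Each DataLoader batch is then one 5-way episode of 80 class-balanced
    # indices (1 support + 15 query samples per class).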
"""Generates indices for class balanced batch by sampling with replacement. 9 | """ 10 | def __init__(self, dataset_labels, num_classes_per_batch, 11 | num_samples_per_class, num_total_batches, train): 12 | """ 13 | Args: 14 | dataset_labels: list of dataset labels 15 | num_classes_per_batch: number of classes to sample for each batch 16 | num_samples_per_class: number of samples to sample for each class 17 | for each batch. For K shot learning this should be K + number 18 | of validation samples 19 | num_total_batches: total number of batches to generate 20 | """ 21 | self._dataset_labels = dataset_labels 22 | self._classes = set(self._dataset_labels) 23 | self._class_to_samples = defaultdict(set) 24 | for i, c in enumerate(self._dataset_labels): 25 | self._class_to_samples[c].add(i) 26 | 27 | self._num_classes_per_batch = num_classes_per_batch 28 | self._num_samples_per_class = num_samples_per_class 29 | self._num_total_batches = num_total_batches 30 | self._train = train 31 | 32 | def __iter__(self): 33 | for i in range(self._num_total_batches): 34 | if len(self._class_to_samples.keys()) >= self._num_classes_per_batch: 35 | batch_classes = random.sample( 36 | self._class_to_samples.keys(), self._num_classes_per_batch) 37 | else: 38 | batch_classes = [random.choice(list(self._class_to_samples.keys())) 39 | for _ in range(self._num_classes_per_batch)] 40 | batch_samples = [] 41 | for c in batch_classes: 42 | if len(self._class_to_samples[c]) >= self._num_samples_per_class: 43 | class_samples = random.sample( 44 | self._class_to_samples[c], self._num_samples_per_class) 45 | else: 46 | class_samples = [random.choice(list(self._class_to_samples[c])) 47 | for _ in range(self._num_samples_per_class)] 48 | 49 | for sample in class_samples: 50 | batch_samples.append(sample) 51 | random.shuffle(batch_samples) 52 | for sample in batch_samples: 53 | yield sample 54 | 55 | def __len__(self): 56 | return self._num_total_batches 57 | -------------------------------------------------------------------------------- /maml/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | 8 | class Trainer(object): 9 | def __init__(self, meta_learner, meta_dataset, writer, log_interval, 10 | save_interval, model_type, save_folder, total_iter): 11 | self._meta_learner = meta_learner 12 | self._meta_dataset = meta_dataset 13 | self._writer = writer 14 | self._log_interval = log_interval 15 | self._save_interval = save_interval 16 | self._model_type = model_type 17 | self._save_folder = save_folder 18 | self._total_iter = total_iter 19 | 20 | 21 | def run(self, is_training): 22 | if not is_training: 23 | all_pre_val_measurements = defaultdict(list) 24 | all_pre_train_measurements = defaultdict(list) 25 | all_post_val_measurements = defaultdict(list) 26 | all_post_train_measurements = defaultdict(list) 27 | 28 | # compute running accuracies for all datasets 29 | if self._meta_dataset.name == 'MultimodalFewShot': 30 | accuracies = [[] for i in range(self._meta_dataset.num_dataset)] 31 | 32 | for i, (train_tasks, val_tasks) in enumerate( 33 | iter(self._meta_dataset), start=1): 34 | 35 | # Save model 36 | if (i % self._save_interval == 0 or i == 1) and is_training: 37 | save_name = 'maml_{0}_{1}.pt'.format(self._model_type, i) 38 | save_path = os.path.join(self._save_folder, save_name) 39 | with open(save_path, 'wb') as f: 40 | 
torch.save(self._meta_learner.state_dict(), f)
41 | 
42 |             (pre_train_measurements, adapted_params, embeddings
43 |             ) = self._meta_learner.adapt(train_tasks)
44 |             post_val_measurements = self._meta_learner.step(
45 |                 adapted_params, embeddings, val_tasks, is_training)
46 | 
47 |             # Tensorboard
48 |             if (i % self._log_interval == 0 or i == 1):
49 |                 pre_val_measurements = self._meta_learner.measure(
50 |                     tasks=val_tasks, embeddings_list=embeddings)
51 |                 post_train_measurements = self._meta_learner.measure(
52 |                     tasks=train_tasks, adapted_params_list=adapted_params,
53 |                     embeddings_list=embeddings)
54 | 
55 |                 _grads_mean = np.mean(self._meta_learner._grads_mean)
56 |                 self._meta_learner._grads_mean = []
57 | 
58 |                 self.log_output(
59 |                     pre_val_measurements, pre_train_measurements,
60 |                     post_val_measurements, post_train_measurements,
61 |                     i, _grads_mean)
62 | 
63 |                 if is_training:
64 |                     self.write_tensorboard(
65 |                         pre_val_measurements, pre_train_measurements,
66 |                         post_val_measurements, post_train_measurements,
67 |                         i, _grads_mean)
68 | 
69 |                 if self._meta_dataset.name == 'MultimodalFewShot':
70 | 
71 |                     post_val_accuracies = self._meta_learner.measure_each(
72 |                         tasks=val_tasks, adapted_params_list=adapted_params,
73 |                         embeddings_list=embeddings)
74 | 
75 |                     if is_training:
76 |                         accuracies = [[] for _ in range(self._meta_dataset.num_dataset)]
77 |                     for j, accuracy in enumerate(post_val_accuracies):
78 |                         accuracies[self._meta_dataset.dataset_names.index(
79 |                             val_tasks[j].task_info)].append(accuracy)
80 | 
81 |                     accuracy_str = []
82 |                     for j, accuracy in enumerate(accuracies):
83 |                         accuracy_str.append('{}: {}'.format(
84 |                             self._meta_dataset.dataset_names[j],
85 |                             'NaN' if len(accuracy) == 0 \
86 |                             else '{:.3f}%'.format(100*np.mean(accuracy))))
87 | 
88 |                     print('Individual accuracies: {}'.format(' '.join(accuracy_str)))
89 |                     print('All accuracy: {:.3f}%'.format(100*np.mean(
90 |                         [item for accuracy in accuracies for item in accuracy])))
91 | 
92 |             # Collect evaluation statistics over full dataset
93 |             if not is_training:
94 |                 for key, value in sorted(pre_val_measurements.items()):
95 |                     all_pre_val_measurements[key].append(value)
96 |                 for key, value in sorted(pre_train_measurements.items()):
97 |                     all_pre_train_measurements[key].append(value)
98 |                 for key, value in sorted(post_val_measurements.items()):
99 |                     all_post_val_measurements[key].append(value)
100 |                 for key, value in sorted(post_train_measurements.items()):
101 |                     all_post_train_measurements[key].append(value)
102 | 
103 |         # Compute evaluation statistics assuming all batches were the same size
104 |         if not is_training:
105 |             results = {'num_batches': i}
106 |             for key, value in sorted(all_pre_val_measurements.items()):
107 |                 results['pre_val_' + key] = value
108 |             for key, value in sorted(all_pre_train_measurements.items()):
109 |                 results['pre_train_' + key] = value
110 |             for key, value in sorted(all_post_val_measurements.items()):
111 |                 results['post_val_' + key] = value
112 |             for key, value in sorted(all_post_train_measurements.items()):
113 |                 results['post_train_' + key] = value
114 | 
115 |             print('Evaluation results:')
116 |             for key, value in sorted(results.items()):
117 |                 if not isinstance(value, int):
118 |                     print('{}: {} +- {}'.format(
119 |                         key, np.mean(value), self.compute_confidence_interval(value)))
120 |                 else:
121 |                     print('{}: {}'.format(key, value))
122 | 
123 |             results_path = os.path.join(self._save_folder, 'results.json')
124 |             with open(results_path, 'w') as f:
125 |                 json.dump(results, f)
126 | 
127 |     def compute_confidence_interval(self, value):
128 | """ 129 | Compute 95% +- confidence intervals over tasks 130 | change 1.960 to 2.576 for 99% +- confidence intervals 131 | """ 132 | return np.std(value) * 1.960 / np.sqrt(len(value)) 133 | 134 | def train(self): 135 | self.run(is_training=True) 136 | 137 | def eval(self): 138 | self.run(is_training=False) 139 | 140 | def write_tensorboard(self, pre_val_measurements, pre_train_measurements, 141 | post_val_measurements, post_train_measurements, 142 | iteration, embedding_grads_mean=None): 143 | for key, value in pre_val_measurements.items(): 144 | self._writer.add_scalar( 145 | '{}/before_update/meta_val'.format(key), value, iteration) 146 | for key, value in pre_train_measurements.items(): 147 | self._writer.add_scalar( 148 | '{}/before_update/meta_train'.format(key), value, iteration) 149 | for key, value in post_train_measurements.items(): 150 | self._writer.add_scalar( 151 | '{}/after_update/meta_train'.format(key), value, iteration) 152 | for key, value in post_val_measurements.items(): 153 | self._writer.add_scalar( 154 | '{}/after_update/meta_val'.format(key), value, iteration) 155 | if embedding_grads_mean is not None: 156 | self._writer.add_scalar( 157 | 'embedding_grads_mean', embedding_grads_mean, iteration) 158 | 159 | def log_output(self, pre_val_measurements, pre_train_measurements, 160 | post_val_measurements, post_train_measurements, 161 | iteration, embedding_grads_mean=None): 162 | log_str = 'Iteration: {}/{} '.format(iteration, self._total_iter) 163 | for key, value in sorted(pre_val_measurements.items()): 164 | log_str = (log_str + '{} meta_val before: {:.3f} ' 165 | ''.format(key, value)) 166 | for key, value in sorted(pre_train_measurements.items()): 167 | log_str = (log_str + '{} meta_train before: {:.3f} ' 168 | ''.format(key, value)) 169 | for key, value in sorted(post_train_measurements.items()): 170 | log_str = (log_str + '{} meta_train after: {:.3f} ' 171 | ''.format(key, value)) 172 | for key, value in sorted(post_val_measurements.items()): 173 | log_str = (log_str + '{} meta_val after: {:.3f} ' 174 | ''.format(key, value)) 175 | if embedding_grads_mean is not None: 176 | log_str = (log_str + 'embedding_grad_norm after: {:.3f} ' 177 | ''.format(embedding_grads_mean)) 178 | print(log_str) 179 | -------------------------------------------------------------------------------- /maml/utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import torch 4 | 5 | 6 | def accuracy(preds, y): 7 | _, preds = torch.max(preds.data, 1) 8 | total = y.size(0) 9 | correct = (preds == y).sum().float() 10 | return correct / total 11 | 12 | 13 | def optimizer_to_device(optimizer, device): 14 | for state in optimizer.state.values(): 15 | for k, v in state.items(): 16 | if isinstance(v, torch.Tensor): 17 | state[k] = v.to(device) 18 | 19 | 20 | def get_git_revision_hash(): 21 | return str(subprocess.check_output(['git', 'rev-parse', 'HEAD'])) -------------------------------------------------------------------------------- /miniimagenet-data/download_mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | def download_file_from_google_drive(id, destination): 4 | URL = "https://drive.google.com/uc?export=download" 5 | 6 | session = requests.Session() 7 | 8 | response = session.get(URL, params = { 'id' : id }, stream = True) 9 | token = get_confirm_token(response) 10 | 11 | if token: 12 | params = { 'id' : id, 'confirm' : token } 13 | response = 
session.get(URL, params = params, stream = True)
14 | 
15 |     save_response_content(response, destination)
16 | 
17 | def get_confirm_token(response):
18 |     for key, value in response.cookies.items():
19 |         if key.startswith('download_warning'):
20 |             return value
21 | 
22 |     return None
23 | 
24 | def save_response_content(response, destination):
25 |     CHUNK_SIZE = 32768
26 | 
27 |     with open(destination, "wb") as f:
28 |         for chunk in response.iter_content(CHUNK_SIZE):
29 |             if chunk: # filter out keep-alive new chunks
30 |                 f.write(chunk)
31 | 
32 | if __name__ == "__main__":
33 |     file_id = "0B3Irx3uQNoBMQ1FlNXJsZUdYWEE"
34 |     destination = './images.zip'
35 |     download_file_from_google_drive(file_id, destination)
36 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipython
2 | torch==1.2.0
3 | torchvision==0.4.0
4 | imageio
5 | scikit-image
6 | tqdm
7 | Pillow
8 | numpy
9 | scipy
10 | requests
11 | tensorboardX
12 | 
--------------------------------------------------------------------------------
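Trainer.compute_confidence_interval above reports the usual normal-approximation half-width, std(value) * 1.960 / sqrt(n). A quick self-contained check with synthetic per-task accuracies:

import numpy as np

rng = np.random.RandomState(0)
accs = rng.normal(0.6, 0.05, size=600)    # synthetic per-task accuracies
half_width = np.std(accs) * 1.960 / np.sqrt(len(accs))
print('{:.3f} +- {:.3f}'.format(np.mean(accs), half_width))   # about 0.600 +- 0.004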