├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── model.png
├── environment.yml
├── main.py
└── maml
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── metadataset.py
    │   └── simple_functions.py
    ├── metalearner.py
    ├── models
    │   ├── __init__.py
    │   ├── fully_connected.py
    │   ├── gated_net.py
    │   ├── lstm_embedding_model.py
    │   ├── model.py
    │   └── simple_embedding_model.py
    ├── sampler.py
    ├── trainer.py
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # project specific 2 | data/ 3 | saves/ 4 | logs/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | *.txt 112 | *.sw[opn] 113 | *.pt 114 | 115 | 116 | train_dir -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tristan Deleu 4 | Copyright (c) 2018 Risto Vuorio 5 | Copyright (c) 2019 Hexiang Hu 6 | Copyright (c) 2019 Shao-Hua Sun 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Model-Agnostic Meta-Learning for Few-shot Regression 2 | 3 | This project is an implementation of [**Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation**](https://arxiv.org/abs/1910.13616), which was published at [**NeurIPS 2019**](https://neurips.cc/Conferences/2019/). Please visit our [project page](https://vuoristo.github.io/MMAML/) for more information and contact [Hexiang Hu](http://hexianghu.com/) for any questions. 4 | 5 | Model-agnostic meta-learners aim to acquire meta-prior parameters from a distribution of tasks and adapt to novel tasks with few gradient updates. Yet, seeking a common initialization shared across the entire task distribution substantially limits the diversity of the task distributions that they are able to learn from. We propose a multimodal MAML (MMAML) framework, which is able to modulate its meta-learned prior according to the identified mode, allowing more efficient fast adaptation. An illustration of the proposed framework is as follows. 6 | 7 |

8 | <img src="assets/model.png" width="100%"> 9 |

10 | 11 | 12 | ## Getting started 13 | 14 | Use of a [Conda environment](https://docs.conda.io/en/latest/) is suggested for straightforward handling of the dependencies. 15 | 16 | ```bash 17 | conda env create -f environment.yml 18 | conda activate mmaml_regression 19 | ``` 20 | 21 | ## Usage 22 | 23 | After installation, models can be trained with the following commands. 24 | 25 | ## Linear + Sinusoid Functions 26 | 27 | ### MAML 28 | ``` 29 | python main.py --dataset mixed --num-batches 70000 --model-type fc --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-maml-5steps --bias-transformation-size 20 --disable-norm 30 | ``` 31 | 32 | 33 | ### Multi-MAML 34 | ``` 35 | python main.py --dataset mixed --num-batches 70000 --model-type multi --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-multi-maml-5steps --bias-transformation-size 20 --disable-norm 36 | ``` 37 | 38 | ### MMAML-postupdate 39 | 40 | #### FiLM 41 | ``` 42 | python main.py --dataset mixed --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-mmaml-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 --inner-loop-grad-clip 10 43 | ``` 44 | 45 | #### Sigmoid 46 | ``` 47 | python main.py --dataset mixed --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-mmaml-sigmoid-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 --inner-loop-grad-clip 10 --condition-type sigmoid_gate 48 | ``` 49 | 50 | #### Softmax 51 | ``` 52 | python main.py --dataset mixed --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-mmaml-softmax-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 --inner-loop-grad-clip 10 --condition-type softmax 53 | ``` 54 | 55 | ### MMAML-preupdate 56 | ``` 57 | python main.py --dataset mixed --num-batches 70000 --model-type gated --fast-lr 0.0 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 1 --output-folder 2mods-mmaml-pre-1steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 --inner-loop-grad-clip 10 58 | ``` 59 |
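The flag conventions above carry over to every command below: `--num-samples-per-class 10 --num-val-samples 5` samples 10 points per task and holds out 5 of them as the post-update (validation) set, `--fast-lr` is the inner-loop step size, and `--num-updates` is the number of inner-loop gradient steps. In the MMAML-preupdate runs, `--fast-lr 0.0` with `--num-updates 1` disables gradient-based adaptation, so only the task-embedding modulation acts.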
60 | ## Linear + Quadratic + Sinusoid Functions 61 | 62 | ### MAML 63 | ``` 64 | python main.py --dataset many --num-batches 70000 --model-type fc --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-maml-5steps --bias-transformation-size 20 --disable-norm 65 | ``` 66 | 67 | 68 | ### Multi-MAML 69 | ``` 70 | python main.py --dataset many --num-batches 70000 --model-type multi --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-multi-maml-5steps --bias-transformation-size 20 --disable-norm 71 | ``` 72 | 73 | ### MMAML-postupdate 74 | 75 | #### FiLM 76 | ``` 77 | python main.py --dataset many --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-mmaml-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 200 200 --inner-loop-grad-clip 10 78 | ``` 79 | 80 | #### Sigmoid 81 | ``` 82 | python main.py --dataset many --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-mmaml-sigmoid-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 100 100 --inner-loop-grad-clip 10 --condition-type sigmoid_gate 83 | ``` 84 | 85 | #### Softmax 86 | ``` 87 | python main.py --dataset many --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-mmaml-softmax-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 100 100 --inner-loop-grad-clip 10 --condition-type softmax 88 | ``` 89 | 90 | ### MMAML-preupdate 91 | ``` 92 | python main.py --dataset many --num-batches 70000 --model-type gated --fast-lr 0.00 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 1 --output-folder 3mods-mmaml-pre-1steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 200 200 --inner-loop-grad-clip 10 93 | ``` 94 | 95 | 96 | ## Linear + Quadratic + Sinusoid + Tanh + Absolute Functions 97 | 98 | ### MAML 99 | ``` 100 | python main.py --dataset five --num-batches 70000 --model-type fc --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-maml-5steps --bias-transformation-size 20 --disable-norm 101 | ``` 102 | 103 | 104 | ### Multi-MAML 105 | ``` 106 | python main.py --dataset five --num-batches 70000 --model-type multi --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-multi-maml-5steps --bias-transformation-size 20 --disable-norm 107 | ``` 108 | 109 | ### MMAML-postupdate 110 | 111 | #### FiLM 112 | ``` 113 | python main.py --dataset five --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-mmaml-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 200 200 --inner-loop-grad-clip 10 114 | ``` 115 | 116 | #### Sigmoid 117 | ``` 118 | python main.py --dataset five --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-mmaml-sigmoid-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 100 100 --inner-loop-grad-clip 10 --condition-type sigmoid_gate 119 | ``` 120 | 121 | #### Softmax 122 | ``` 123 | python main.py --dataset five --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-mmaml-softmax-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 100 100 --inner-loop-grad-clip 10 --condition-type softmax 124 | ``` 125 | 126 | ### MMAML-preupdate 127 | ``` 128 | python main.py --dataset five --num-batches 70000 --model-type gated --fast-lr 0.00 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 1 --output-folder 5mods-mmaml-pre-1steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 200 200 --inner-loop-grad-clip 10 129 | ```
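To evaluate a trained model on freshly sampled tasks, point `main.py` at a saved checkpoint with `--eval` and the same model flags used for training (the checkpoint filename below is a hypothetical placeholder for whatever was saved under `./train_dir/<output-folder>/`):
```
python main.py --eval --checkpoint ./train_dir/2mods-mmaml-5steps/<checkpoint>.pt --dataset mixed --model-type gated --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-mmaml-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 --inner-loop-grad-clip 10
```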
130 | 131 | 132 | 133 | ## Authors 134 | 135 | [Hexiang Hu](http://hexianghu.com/), [Shao-Hua Sun](http://shaohua0116.github.io/), [Risto Vuorio](https://vuoristo.github.io/) 136 | -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vuoristo/MMAML-Regression/1a8bea4d60461d8814e3c9f91427ed56378716ce/assets/model.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mmaml_regression 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - python=3.7 6 | - pytorch=0.4.1 7 | - numpy 8 | - pip 9 | - pip: 10 | - tensorboardX -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | import torch 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | 9 | from maml.datasets.simple_functions import ( 10 | SinusoidMetaDataset, 11 | LinearMetaDataset, 12 | MixedFunctionsMetaDataset, 13 | ManyFunctionsMetaDataset, 14 | FiveFunctionsMetaDataset, 15 | MultiSinusoidsMetaDataset, 16 | ) 17 | from maml.models.fully_connected import FullyConnectedModel, MultiFullyConnectedModel 18 | from maml.models.gated_net import GatedNet 19 | from maml.models.lstm_embedding_model import LSTMEmbeddingModel 20 | from maml.metalearner import MetaLearner 21 | from maml.trainer import Trainer 22 | from maml.utils import optimizer_to_device, get_git_revision_hash 23 | 24 | 25 | def main(args): 26 | is_training = not args.eval 27 | run_name = 'train' if is_training else 'eval' 28 | 29 | if is_training: 30 | writer = SummaryWriter('./train_dir/{0}/{1}'.format( 31 | args.output_folder, run_name)) 32 | else: 33 | writer = None 34 | 35 | save_folder = './train_dir/{0}'.format(args.output_folder) 36 | if not os.path.exists(save_folder): 37 | os.makedirs(save_folder) 38 | 39 | config_name = '{0}_config.json'.format(run_name) 40 | with open(os.path.join(save_folder, config_name), 'w') as f: 41 | config = {k: v for (k, v) in vars(args).items() if k != 'device'} 42 | config.update(device=args.device.type) 43 | config.update({'git_hash': get_git_revision_hash()}) 44 | json.dump(config, f, indent=2) 45 | 46 |
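    # Number of distinct function modes in the chosen dataset; the 'multi'
    # model type below builds one sub-network per mode.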
_num_tasks = 1 47 | if args.dataset == 'sinusoid': 48 | dataset = SinusoidMetaDataset( 49 | num_total_batches=args.num_batches, 50 | num_samples_per_function=args.num_samples_per_class, 51 | num_val_samples=args.num_val_samples, 52 | meta_batch_size=args.meta_batch_size, 53 | amp_range=args.amp_range, 54 | phase_range=args.phase_range, 55 | input_range=args.input_range, 56 | oracle=args.oracle, 57 | train=is_training, 58 | device=args.device) 59 | loss_func = torch.nn.MSELoss() 60 | collect_accuracies = False 61 | elif args.dataset == 'linear': 62 | dataset = LinearMetaDataset( 63 | num_total_batches=args.num_batches, 64 | num_samples_per_function=args.num_samples_per_class, 65 | num_val_samples=args.num_val_samples, 66 | meta_batch_size=args.meta_batch_size, 67 | slope_range=args.slope_range, 68 | intersect_range=args.intersect_range, 69 | input_range=args.input_range, 70 | oracle=args.oracle, 71 | train=is_training, 72 | device=args.device) 73 | loss_func = torch.nn.MSELoss() 74 | collect_accuracies = False 75 | elif args.dataset == 'mixed': 76 | dataset = MixedFunctionsMetaDataset( 77 | num_total_batches=args.num_batches, 78 | num_samples_per_function=args.num_samples_per_class, 79 | num_val_samples=args.num_val_samples, 80 | meta_batch_size=args.meta_batch_size, 81 | amp_range=args.amp_range, 82 | phase_range=args.phase_range, 83 | slope_range=args.slope_range, 84 | intersect_range=args.intersect_range, 85 | input_range=args.input_range, 86 | noise_std=args.noise_std, 87 | oracle=args.oracle, 88 | task_oracle=args.task_oracle, 89 | train=is_training, 90 | device=args.device) 91 | loss_func = torch.nn.MSELoss() 92 | collect_accuracies = False 93 | _num_tasks = 2 94 | elif args.dataset == 'five': 95 | dataset = FiveFunctionsMetaDataset( 96 | num_total_batches=args.num_batches, 97 | num_samples_per_function=args.num_samples_per_class, 98 | num_val_samples=args.num_val_samples, 99 | meta_batch_size=args.meta_batch_size, 100 | amp_range=args.amp_range, 101 | phase_range=args.phase_range, 102 | slope_range=args.slope_range, 103 | intersect_range=args.intersect_range, 104 | input_range=args.input_range, 105 | noise_std=args.noise_std, 106 | oracle=args.oracle, 107 | task_oracle=args.task_oracle, 108 | train=is_training, 109 | device=args.device) 110 | loss_func = torch.nn.MSELoss() 111 | collect_accuracies = False 112 | _num_tasks = 5 113 | elif args.dataset == 'many': 114 | dataset = ManyFunctionsMetaDataset( 115 | num_total_batches=args.num_batches, 116 | num_samples_per_function=args.num_samples_per_class, 117 | num_val_samples=args.num_val_samples, 118 | meta_batch_size=args.meta_batch_size, 119 | amp_range=args.amp_range, 120 | phase_range=args.phase_range, 121 | slope_range=args.slope_range, 122 | intersect_range=args.intersect_range, 123 | input_range=args.input_range, 124 | noise_std=args.noise_std, 125 | oracle=args.oracle, 126 | task_oracle=args.task_oracle, 127 | train=is_training, 128 | device=args.device) 129 | loss_func = torch.nn.MSELoss() 130 | collect_accuracies = False 131 | _num_tasks = 3 132 | elif args.dataset == 'multisinusoids': 133 | dataset = MultiSinusoidsMetaDataset( 134 | num_total_batches=args.num_batches, 135 | num_samples_per_function=args.num_samples_per_class, 136 | num_val_samples=args.num_val_samples, 137 | meta_batch_size=args.meta_batch_size, 138 | amp_range=args.amp_range, 139 | phase_range=args.phase_range, 140 | slope_range=args.slope_range, 141 | intersect_range=args.intersect_range, 142 | input_range=args.input_range, 143 | noise_std=args.noise_std, 
144 | oracle=args.oracle, 145 | task_oracle=args.task_oracle, 146 | train=is_training, 147 | device=args.device) 148 | loss_func = torch.nn.MSELoss() 149 | collect_accuracies = False 150 | else: 151 | raise ValueError('Unrecognized dataset {}'.format(args.dataset)) 152 | 153 | embedding_model = None 154 | 155 | if args.model_type == 'fc': 156 | model = FullyConnectedModel( 157 | input_size=np.prod(dataset.input_size), 158 | output_size=dataset.output_size, 159 | hidden_sizes=args.hidden_sizes, 160 | disable_norm=args.disable_norm, 161 | bias_transformation_size=args.bias_transformation_size) 162 | elif args.model_type == 'multi': 163 | model = MultiFullyConnectedModel( 164 | input_size=np.prod(dataset.input_size), 165 | output_size=dataset.output_size, 166 | hidden_sizes=args.hidden_sizes, 167 | disable_norm=args.disable_norm, 168 | num_tasks=_num_tasks, 169 | bias_transformation_size=args.bias_transformation_size) 170 | elif args.model_type == 'gated': 171 | model = GatedNet( 172 | input_size=np.prod(dataset.input_size), 173 | output_size=dataset.output_size, 174 | hidden_sizes=args.hidden_sizes, 175 | condition_type=args.condition_type, 176 | condition_order=args.condition_order) 177 | else: 178 | raise ValueError('Unrecognized model type {}'.format(args.model_type)) 179 | model_parameters = list(model.parameters()) 180 | 181 | if args.embedding_type == '': 182 | embedding_model = None 183 | elif args.embedding_type == 'LSTM': 184 | embedding_model = LSTMEmbeddingModel( 185 | input_size=np.prod(dataset.input_size), 186 | output_size=dataset.output_size, 187 | embedding_dims=args.embedding_dims, 188 | hidden_size=args.embedding_hidden_size, 189 | num_layers=args.embedding_num_layers) 190 | embedding_parameters = list(embedding_model.parameters()) 191 | else: 192 | raise ValueError('Unrecognized embedding type {}'.format( 193 | args.embedding_type)) 194 | 195 | optimizers = None 196 | if embedding_model: 197 | optimizers = (torch.optim.Adam(model_parameters, lr=args.slow_lr), 198 | torch.optim.Adam(embedding_parameters, lr=args.slow_lr)) 199 | else: 200 | optimizers = (torch.optim.Adam(model_parameters, lr=args.slow_lr), ) 201 | 202 | if args.checkpoint != '': 203 | checkpoint = torch.load(args.checkpoint) 204 | model.load_state_dict(checkpoint['model_state_dict']) 205 | model.to(args.device) 206 | if 'optimizer' in checkpoint: 207 | pass 208 | else: 209 | optimizers[0].load_state_dict(checkpoint['optimizers'][0]) 210 | optimizer_to_device(optimizers[0], args.device) 211 | 212 | if embedding_model: 213 | embedding_model.load_state_dict( 214 | checkpoint['embedding_model_state_dict']) 215 | optimizers[1].load_state_dict(checkpoint['optimizers'][1]) 216 | optimizer_to_device(optimizers[1], args.device) 217 | 218 | meta_learner = MetaLearner( 219 | model, embedding_model, optimizers, fast_lr=args.fast_lr, 220 | loss_func=loss_func, first_order=args.first_order, 221 | num_updates=args.num_updates, 222 | inner_loop_grad_clip=args.inner_loop_grad_clip, 223 | collect_accuracies=collect_accuracies, device=args.device, 224 | embedding_grad_clip=args.embedding_grad_clip, 225 | model_grad_clip=args.model_grad_clip) 226 | 227 | trainer = Trainer( 228 | meta_learner=meta_learner, meta_dataset=dataset, writer=writer, 229 | log_interval=args.log_interval, save_interval=args.save_interval, 230 | model_type=args.model_type, save_folder=save_folder) 231 | 232 | if is_training: 233 | trainer.train() 234 | else: 235 | trainer.eval() 236 | 237 | 238 | if __name__ == '__main__': 239 | 240 | def str2bool(arg): 
241 | return arg.lower() == 'true' 242 | 243 | parser = argparse.ArgumentParser( 244 | description='Multimodal Model-Agnostic Meta-Learning (MAML)') 245 | 246 | # Model 247 | parser.add_argument('--hidden-sizes', type=int, 248 | default=[256, 128, 64, 64], nargs='+', 249 | help='number of hidden units per layer') 250 | parser.add_argument('--model-type', type=str, default='fc', 251 | help='type of the model') 252 | parser.add_argument('--condition-type', type=str, default='affine', 253 | help='type of the conditional layers') 254 | parser.add_argument('--use-max-pool', action='store_true', 255 | help='choose whether to use max pooling with convolutional model') 256 | parser.add_argument('--num-channels', type=int, default=64, 257 | help='number of channels in convolutional layers') 258 | parser.add_argument('--disable-norm', action='store_true', 259 | help='disable batchnorm after linear layers in a fully connected model') 260 | parser.add_argument('--bias-transformation-size', type=int, default=0, 261 | help='size of bias transformation vector that is concatenated with ' 262 | 'input') 263 | parser.add_argument('--condition-order', type=str, default='low2high', 264 | help='order of the conditional layers to be used') 265 | 266 | # Embedding 267 | parser.add_argument('--embedding-type', type=str, default='', 268 | help='type of the embedding') 269 | parser.add_argument('--embedding-hidden-size', type=int, default=40, 270 | help='number of hidden units per layer in recurrent embedding model') 271 | parser.add_argument('--embedding-num-layers', type=int, default=2, 272 | help='number of layers in recurrent embedding model') 273 | parser.add_argument('--embedding-dims', type=int, nargs='+', default=0, 274 | help='dimensions of the embeddings') 275 | 276 | # Randomly sampled embedding vectors 277 | parser.add_argument('--num-sample-embedding', type=int, default=0, 278 | help='number of randomly sampled embedding vectors') 279 | parser.add_argument( 280 | '--sample-embedding-file', type=str, default='embeddings', 281 | help='the file name of randomly sampled embedding vectors') 282 | parser.add_argument( 283 | '--sample-embedding-file-type', type=str, default='hdf5') 284 | 285 | # Inner loop 286 | parser.add_argument('--first-order', action='store_true', 287 | help='use the first-order approximation of MAML') 288 | parser.add_argument('--fast-lr', type=float, default=0.4, 289 | help='learning rate for the 1-step gradient update of MAML') 290 | parser.add_argument('--inner-loop-grad-clip', type=float, default=0.0, 291 | help='enable gradient clipping in the inner loop') 292 | parser.add_argument('--num-updates', type=int, default=1, 293 | help='how many update steps in the inner loop') 294 | 295 | # Optimization 296 | parser.add_argument('--num-batches', type=int, default=1920000, 297 | help='number of batches') 298 | parser.add_argument('--meta-batch-size', type=int, default=32, 299 | help='number of tasks per batch') 300 | parser.add_argument('--slow-lr', type=float, default=0.001, 301 | help='learning rate for the global update of MAML') 302 | 303 | # Miscellaneous 304 | parser.add_argument('--output-folder', type=str, default='maml', 305 | help='name of the output folder') 306 | parser.add_argument('--device', type=str, default='cpu', 307 | help='set the device (cpu or cuda)') 308 | parser.add_argument('--num-workers', type=int, default=4, 309 | help='how many DataLoader workers to use') 310 | parser.add_argument('--log-interval', type=int, default=100, 311 | help='number of batches between 
tensorboard writes') 312 | parser.add_argument('--save-interval', type=int, default=1000, 313 | help='number of batches between model saves') 314 | parser.add_argument('--eval', action='store_true', default=False, 315 | help='evaluate model') 316 | parser.add_argument('--checkpoint', type=str, default='', 317 | help='path to saved parameters.') 318 | 319 | # Dataset 320 | parser.add_argument('--dataset', type=str, default='sinusoid', 321 | help='which dataset to use (sinusoid, linear, mixed, many, five, multisinusoids)') 322 | parser.add_argument('--data-root', type=str, default='data', 323 | help='path to store datasets') 324 | parser.add_argument('--num-samples-per-class', type=int, default=1, 325 | help='how many samples per class for training') 326 | parser.add_argument('--num-val-samples', type=int, default=1, 327 | help='how many samples per class for validation') 328 | parser.add_argument('--input-range', type=float, default=[-5.0, 5.0], 329 | nargs='+', help='input range of simple functions') 330 | parser.add_argument('--phase-range', type=float, default=[0, np.pi], 331 | nargs='+', help='phase range of sinusoids') 332 | parser.add_argument('--amp-range', type=float, default=[0.1, 5.0], 333 | nargs='+', help='amp range of sinusoids') 334 | parser.add_argument('--slope-range', type=float, default=[-3.0, 3.0], 335 | nargs='+', help='slope range of linear functions') 336 | parser.add_argument('--intersect-range', type=float, default=[-3.0, 3.0], 337 | nargs='+', help='intersect range of linear functions') 338 | parser.add_argument('--noise-std', type=float, default=0.0, 339 | help='add gaussian noise to mixed functions') 340 | parser.add_argument('--oracle', action='store_true', 341 | help='concatenate phase and amp to sinusoid inputs') 342 | parser.add_argument('--task-oracle', action='store_true', 343 | help='uses task id for prediction in some models') 344 | 345 | parser.add_argument('--embedding-grad-clip', type=float, default=2.0, 346 | help='max gradient norm for the embedding model (0 disables clipping)') 347 | parser.add_argument('--model-grad-clip', type=float, default=2.0, 348 | help='max gradient norm for the task network (0 disables clipping)') 349 | 350 | args = parser.parse_args() 351 | 352 | if args.embedding_dims == 0: 353 | args.embedding_dims = args.hidden_sizes 354 | 355 | # Create the training directory if it does not exist 356 | if not os.path.exists('./train_dir'): 357 | os.makedirs('./train_dir') 358 | 359 | # Make sure num sample embedding < num sample tasks 360 | args.num_sample_embedding = min( 361 | args.num_sample_embedding, args.num_batches) 362 | 363 | # Device 364 | args.device = torch.device( 365 | args.device if torch.cuda.is_available() else 'cpu') 366 | 367 | main(args) 368 | -------------------------------------------------------------------------------- /maml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vuoristo/MMAML-Regression/1a8bea4d60461d8814e3c9f91427ed56378716ce/maml/__init__.py -------------------------------------------------------------------------------- /maml/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vuoristo/MMAML-Regression/1a8bea4d60461d8814e3c9f91427ed56378716ce/maml/datasets/__init__.py -------------------------------------------------------------------------------- /maml/datasets/metadataset.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Task = namedtuple('Task', ['x', 'y', 'task_info']) 4 |
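Every dataset in `simple_functions.py` (next) yields meta batches as two lists of these `Task` tuples, one for adaptation and one for post-update evaluation; a minimal sketch of that contract, with illustrative shapes and metadata:

```python
import torch

from maml.datasets.metadataset import Task

# One regression task: five (x, y) pairs plus free-form metadata describing
# the generating function (the datasets fill in fields such as 'amp').
task = Task(x=torch.randn(5, 1), y=torch.randn(5, 1),
            task_info={'task_id': 0, 'amp': 2.3, 'phase': 1.1})
print(task.x.shape, task.task_info['task_id'])
```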
-------------------------------------------------------------------------------- /maml/datasets/simple_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from maml.datasets.metadataset import Task 5 | 6 | 7 | def generate_sinusoid_batch(amp_range, phase_range, input_range, num_samples, 8 | batch_size, oracle, bias=0): 9 | amp = np.random.uniform(amp_range[0], amp_range[1], [batch_size]) 10 | phase = np.random.uniform(phase_range[0], phase_range[1], [batch_size]) 11 | outputs = np.zeros([batch_size, num_samples, 1]) 12 | inputs = np.zeros([batch_size, num_samples, 1]) 13 | for i in range(batch_size): 14 | inputs[i] = np.random.uniform(input_range[0], input_range[1], 15 | [num_samples, 1]) 16 | outputs[i] = amp[i] * np.sin(inputs[i] - phase[i]) + bias 17 | 18 | if oracle: 19 | amps = np.ones_like(inputs) * amp.reshape(-1, 1, 1) 20 | phases = np.ones_like(inputs) * phase.reshape(-1, 1, 1) 21 | inputs = np.concatenate((inputs, amps, phases), axis=2) 22 | 23 | return inputs, outputs, amp, phase 24 | 25 | 26 | def generate_linear_batch(slope_range, intersect_range, input_range, 27 | num_samples, batch_size, oracle): 28 | slope = np.random.uniform(slope_range[0], slope_range[1], [batch_size]) 29 | intersect = np.random.uniform(intersect_range[0], intersect_range[1], 30 | [batch_size]) 31 | outputs = np.zeros([batch_size, num_samples, 1]) 32 | inputs = np.zeros([batch_size, num_samples, 1]) 33 | for i in range(batch_size): 34 | inputs[i] = np.random.uniform(input_range[0], input_range[1], 35 | [num_samples, 1]) 36 | outputs[i] = inputs[i] * slope[i] + intersect[i] 37 | 38 | if oracle: 39 | slopes = np.ones_like(inputs) * slope.reshape(-1, 1, 1) 40 | intersects = np.ones_like(inputs) * intersect.reshape(-1, 1, 1) 41 | inputs = np.concatenate((inputs, slopes, intersects), axis=2) 42 | 43 | return inputs, outputs, slope, intersect 44 | 45 | 46 | class SimpleFunctionDataset(object): 47 | def __init__(self, num_total_batches=200000, num_samples_per_function=5, 48 | num_val_samples=5, meta_batch_size=75, oracle=False, 49 | train=True, device='cpu', dtype=torch.float, **kwargs): 50 | self._num_total_batches = num_total_batches 51 | self._num_samples_per_function = num_samples_per_function 52 | self._num_val_samples = num_val_samples 53 | self._num_total_samples = num_samples_per_function 54 | self._meta_batch_size = meta_batch_size 55 | self._oracle = oracle 56 | self._train = train 57 | self._device = device 58 | self._dtype = dtype 59 | 60 | def _generate_batch(self): 61 | raise NotImplementedError('Subclass should implement _generate_batch') 62 | 63 | def __iter__(self): 64 | for batch in range(self._num_total_batches): 65 | inputs, outputs, infos = self._generate_batch() 66 | 67 | train_tasks = [] 68 | val_tasks = [] 69 | for task in range(self._meta_batch_size): 70 | task_inputs = torch.tensor( 71 | inputs[task], device=self._device, dtype=self._dtype) 72 | task_outputs = torch.tensor( 73 | outputs[task], device=self._device, dtype=self._dtype) 74 | task_infos = infos[task] 75 | train_task = Task(task_inputs[self._num_val_samples:], 76 | task_outputs[self._num_val_samples:], 77 | task_infos) 78 | train_tasks.append(train_task) 79 | val_task = Task(task_inputs[:self._num_val_samples], 80 | task_outputs[:self._num_val_samples], 81 | task_infos) 82 | val_tasks.append(val_task) 83 | yield train_tasks, val_tasks 84 | 85 | class BiasedSinusoidMetaDataset(SimpleFunctionDataset): 86 | def __init__(self, 
amp_range=[0.1, 5.0], phase_range=[0, np.pi], 87 | input_range=[-5.0, 5.0], bias=0, **kwargs): 88 | super(BiasedSinusoidMetaDataset, self).__init__(**kwargs) 89 | self._amp_range = amp_range 90 | self._phase_range = phase_range 91 | self._input_range = input_range 92 | self._bias = bias 93 | 94 | if self._oracle: 95 | self.input_size = 3 96 | else: 97 | self.input_size = 1 98 | self.output_size = 1 99 | 100 | def _generate_batch(self): 101 | inputs, outputs, amp, phase = generate_sinusoid_batch( 102 | amp_range=self._amp_range, phase_range=self._phase_range, 103 | input_range=self._input_range, 104 | num_samples=self._num_total_samples, 105 | batch_size=self._meta_batch_size, oracle=self._oracle, bias=self._bias) 106 | task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]} 107 | for i in range(len(amp))] 108 | return inputs, outputs, task_infos 109 | 110 | class SinusoidMetaDataset(SimpleFunctionDataset): 111 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi], 112 | input_range=[-5.0, 5.0], **kwargs): 113 | super(SinusoidMetaDataset, self).__init__(**kwargs) 114 | self._amp_range = amp_range 115 | self._phase_range = phase_range 116 | self._input_range = input_range 117 | 118 | if self._oracle: 119 | self.input_size = 3 120 | else: 121 | self.input_size = 1 122 | self.output_size = 1 123 | 124 | def _generate_batch(self): 125 | inputs, outputs, amp, phase = generate_sinusoid_batch( 126 | amp_range=self._amp_range, phase_range=self._phase_range, 127 | input_range=self._input_range, 128 | num_samples=self._num_total_samples, 129 | batch_size=self._meta_batch_size, oracle=self._oracle) 130 | task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]} 131 | for i in range(len(amp))] 132 | return inputs, outputs, task_infos 133 | 134 | 135 | class LinearMetaDataset(SimpleFunctionDataset): 136 | def __init__(self, slope_range=[-3.0, 3.0], intersect_range=[-3, 3], 137 | input_range=[-5.0, 5.0], **kwargs): 138 | super(LinearMetaDataset, self).__init__(**kwargs) 139 | self._slope_range = slope_range 140 | self._intersect_range = intersect_range 141 | self._input_range = input_range 142 | 143 | if self._oracle: 144 | self.input_size = 3 145 | else: 146 | self.input_size = 1 147 | self.output_size = 1 148 | 149 | def _generate_batch(self): 150 | inputs, outputs, slope, intersect = generate_linear_batch( 151 | slope_range=self._slope_range, 152 | intersect_range=self._intersect_range, 153 | input_range=self._input_range, 154 | num_samples=self._num_total_samples, 155 | batch_size=self._meta_batch_size, oracle=self._oracle) 156 | task_infos = [{'task_id': 0, 'slope': slope[i], 'intersect': intersect[i]} 157 | for i in range(len(slope))] 158 | return inputs, outputs, task_infos 159 | 160 | def generate_quadratic_batch(center_range, bias_range, sign_range, input_range, 161 | num_samples, batch_size, oracle): 162 | center = np.random.uniform(center_range[0], center_range[1], [batch_size]) 163 | bias = np.random.uniform(bias_range[0], bias_range[1], [batch_size]) 164 | 165 | # alpha range 166 | alpha = np.random.uniform(sign_range[0], sign_range[1], [batch_size]) 167 | sign = np.random.randint(2, size=[batch_size]) 168 | sign[sign == 0] = -1 169 | sign = alpha * sign 170 | 171 | outputs = np.zeros([batch_size, num_samples, 1]) 172 | inputs = np.zeros([batch_size, num_samples, 1]) 173 | for i in range(batch_size): 174 | inputs[i] = np.random.uniform(input_range[0], input_range[1], 175 | [num_samples, 1]) 176 | outputs[i] = sign[i] * (inputs[i] - center[i])**2 + bias[i] 177 | 178 | 
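    # Oracle mode appends the generating parameters to each input, mirroring
    # the sinusoid and linear generators above.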
if oracle: 179 | centers = np.ones_like(inputs) * center.reshape(-1, 1, 1) 180 | biases = np.ones_like(inputs) * bias.reshape(-1, 1, 1) 181 | inputs = np.concatenate((inputs, centers, biases), axis=2) 182 | 183 | return inputs, outputs, sign, center, bias 184 | 185 | 186 | class QuadraticMetaDataset(SimpleFunctionDataset): 187 | """ Quadratic function like: sign * (x - center)^2 + bias 188 | """ 189 | def __init__(self, center_range=[-3.0, 3.0], bias_range=[-3, 3], sign_range=[0.02, 0.15], 190 | input_range=[-5.0, 5.0], **kwargs): 191 | super(QuadraticMetaDataset, self).__init__(**kwargs) 192 | self._center_range = center_range 193 | self._bias_range = bias_range 194 | self._input_range = input_range 195 | self._sign_range = sign_range 196 | 197 | if self._oracle: 198 | self.input_size = 3 199 | else: 200 | self.input_size = 1 201 | self.output_size = 1 202 | 203 | def _generate_batch(self): 204 | inputs, outputs, sign, center, bias = generate_quadratic_batch( 205 | center_range=self._center_range, 206 | bias_range=self._bias_range, 207 | sign_range=self._sign_range, 208 | input_range=self._input_range, 209 | num_samples=self._num_total_samples, 210 | batch_size=self._meta_batch_size, oracle=self._oracle) 211 | task_infos = [{'task_id': 2, 'sign': sign[i], 'center': center[i], 'bias': bias[i]} 212 | for i in range(len(sign))] 213 | return inputs, outputs, task_infos 214 | 215 | 216 | class MixedFunctionsMetaDataset(SimpleFunctionDataset): 217 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi], 218 | input_range=[-5.0, 5.0], slope_range=[-3.0, 3.0], 219 | intersect_range=[-3.0, 3.0], task_oracle=False, 220 | noise_std=0, **kwargs): 221 | super(MixedFunctionsMetaDataset, self).__init__(**kwargs) 222 | self._amp_range = amp_range 223 | self._phase_range = phase_range 224 | self._slope_range = slope_range 225 | self._intersect_range = intersect_range 226 | self._input_range = input_range 227 | self._task_oracle = task_oracle 228 | self._noise_std = noise_std 229 | 230 | if not self._oracle: 231 | if not self._task_oracle: 232 | self.input_size = 1 233 | else: 234 | self.input_size = 2 235 | else: 236 | if not self._task_oracle: 237 | self.input_size = 3 238 | else: 239 | self.input_size = 4 240 | 241 | self.output_size = 1 242 | self.num_tasks = 2 243 | 244 | def _generate_batch(self): 245 | half_batch_size = self._meta_batch_size // 2 246 | sin_inputs, sin_outputs, amp, phase = generate_sinusoid_batch( 247 | amp_range=self._amp_range, phase_range=self._phase_range, 248 | input_range=self._input_range, 249 | num_samples=self._num_total_samples, 250 | batch_size=half_batch_size, oracle=self._oracle) 251 | sin_task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]} 252 | for i in range(len(amp))] 253 | if self._task_oracle: 254 | sin_inputs = np.concatenate( 255 | (sin_inputs, np.zeros(sin_inputs.shape[:2] + (1,))), axis=2) 256 | 257 | lin_inputs, lin_outputs, slope, intersect = generate_linear_batch( 258 | slope_range=self._slope_range, 259 | intersect_range=self._intersect_range, 260 | input_range=self._input_range, 261 | num_samples=self._num_total_samples, 262 | batch_size=half_batch_size, oracle=self._oracle) 263 | lin_task_infos = [{'task_id': 1, 'slope': slope[i], 'intersect': intersect[i]} 264 | for i in range(len(slope))] 265 | if self._task_oracle: 266 | lin_inputs = np.concatenate( 267 | (lin_inputs, np.ones(lin_inputs.shape[:2] + (1,))), axis=2) 268 | inputs = np.concatenate((sin_inputs, lin_inputs)) 269 | outputs = np.concatenate((sin_outputs, lin_outputs)) 270 
| 271 | if self._noise_std > 0: 272 | outputs = outputs + np.random.normal(scale=self._noise_std, size=outputs.shape) 273 | task_infos = sin_task_infos + lin_task_infos 274 | return inputs, outputs, task_infos 275 | 276 | class ManyFunctionsMetaDataset(SimpleFunctionDataset): 277 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi], 278 | input_range=[-5.0, 5.0], slope_range=[-3.0, 3.0], 279 | intersect_range=[-3.0, 3.0], center_range=[-3.0, 3.0], 280 | bias_range=[-3.0, 3.0], sign_range=[0.02, 0.15], task_oracle=False, 281 | noise_std=0, **kwargs): 282 | super(ManyFunctionsMetaDataset, self).__init__(**kwargs) 283 | self._amp_range = amp_range 284 | self._phase_range = phase_range 285 | self._slope_range = slope_range 286 | self._intersect_range = intersect_range 287 | self._input_range = input_range 288 | self._center_range = center_range 289 | self._bias_range = bias_range 290 | self._sign_range = sign_range 291 | self._task_oracle = task_oracle 292 | self._noise_std = noise_std 293 | 294 | if not self._oracle: 295 | if not self._task_oracle: 296 | self.input_size = 1 297 | else: 298 | self.input_size = 2 299 | else: 300 | if not self._task_oracle: 301 | self.input_size = 3 302 | else: 303 | self.input_size = 4 304 | 305 | self.output_size = 1 306 | self.num_tasks = 3 307 | 308 | def _generate_batch(self): 309 | sub_batch_size = self._meta_batch_size // 3 310 | sin_inputs, sin_outputs, amp, phase = generate_sinusoid_batch( 311 | amp_range=self._amp_range, phase_range=self._phase_range, 312 | input_range=self._input_range, 313 | num_samples=self._num_total_samples, 314 | batch_size=sub_batch_size, oracle=self._oracle) 315 | sin_task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]} 316 | for i in range(len(amp))] 317 | if self._task_oracle: 318 | sin_inputs = np.concatenate( 319 | (sin_inputs, np.zeros(sin_inputs.shape[:2] + (1,))), axis=2) 320 | 321 | lin_inputs, lin_outputs, slope, intersect = generate_linear_batch( 322 | slope_range=self._slope_range, 323 | intersect_range=self._intersect_range, 324 | input_range=self._input_range, 325 | num_samples=self._num_total_samples, 326 | batch_size=sub_batch_size, oracle=self._oracle) 327 | lin_task_infos = [{'task_id': 1, 'slope': slope[i], 'intersect': intersect[i]} 328 | for i in range(len(slope))] 329 | if self._task_oracle: 330 | lin_inputs = np.concatenate( 331 | (lin_inputs, np.ones(lin_inputs.shape[:2] + (1,))), axis=2) 332 | 333 | qua_inputs, qua_outputs, sign, center, bias = generate_quadratic_batch( 334 | center_range=self._center_range, 335 | bias_range=self._bias_range, 336 | sign_range=self._sign_range, 337 | input_range=self._input_range, 338 | num_samples=self._num_total_samples, 339 | batch_size=sub_batch_size, oracle=self._oracle) 340 | qua_task_infos = [{'task_id': 2, 'sign': sign[i], 'center': center[i], 'bias': bias[i]} 341 | for i in range(len(sign))] 342 | 343 | if self._task_oracle: 344 | qua_inputs = np.concatenate( 345 | (qua_inputs, 2 * np.ones(qua_inputs.shape[:2] + (1,))), axis=2) 346 | 347 | inputs = np.concatenate((sin_inputs, lin_inputs, qua_inputs)) 348 | outputs = np.concatenate((sin_outputs, lin_outputs, qua_outputs)) 349 | 350 | if self._noise_std > 0: 351 | outputs = outputs + np.random.normal(scale=self._noise_std, size=outputs.shape) 352 | task_infos = sin_task_infos + lin_task_infos + qua_task_infos 353 | return inputs, outputs, task_infos 354 | 355 |
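Each of these mixed datasets interleaves its modes within a single meta batch; a quick, illustrative way to inspect the composition (parameter values mirror the README commands):

```python
from maml.datasets.simple_functions import ManyFunctionsMetaDataset

# 75 tasks split evenly across the three function modes.
dataset = ManyFunctionsMetaDataset(num_total_batches=1, meta_batch_size=75,
                                   num_samples_per_function=10,
                                   num_val_samples=5, noise_std=0.3)
train_tasks, val_tasks = next(iter(dataset))
print([t.task_info['task_id'] for t in train_tasks])  # 25x 0, 25x 1, 25x 2
```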
356 | class MultiSinusoidsMetaDataset(SimpleFunctionDataset): 357 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi], biases=(-5, 5), 358 | input_range=[-5.0, 5.0], task_oracle=False, 359 | noise_std=0, **kwargs): 360 | super(MultiSinusoidsMetaDataset, self).__init__(**kwargs) 361 | self._amp_range = amp_range 362 | self._phase_range = phase_range 363 | self._input_range = input_range 364 | self._task_oracle = task_oracle 365 | self._noise_std = noise_std 366 | self._biases = biases 367 | 368 | if not self._oracle: 369 | if not self._task_oracle: 370 | self.input_size = 1 371 | else: 372 | self.input_size = 2 373 | else: 374 | if not self._task_oracle: 375 | self.input_size = 3 376 | else: 377 | self.input_size = 4 378 | 379 | self.output_size = 1 380 | self.num_tasks = 2 381 | 382 | def _generate_batch(self): 383 | half_batch_size = self._meta_batch_size // 2 384 | sin1_inputs, sin1_outputs, amp1, phase1 = generate_sinusoid_batch( 385 | amp_range=self._amp_range, phase_range=self._phase_range, 386 | input_range=self._input_range, 387 | num_samples=self._num_total_samples, 388 | batch_size=half_batch_size, oracle=self._oracle, bias=self._biases[0]) 389 | sin1_task_infos = [{'task_id': 0, 'amp': amp1[i], 'phase': phase1[i]} 390 | for i in range(len(amp1))] 391 | if self._task_oracle: 392 | sin1_inputs = np.concatenate( 393 | (sin1_inputs, np.zeros(sin1_inputs.shape[:2] + (1,))), axis=2) 394 | 395 | sin2_inputs, sin2_outputs, amp2, phase2 = generate_sinusoid_batch( 396 | amp_range=self._amp_range, phase_range=self._phase_range, 397 | input_range=self._input_range, 398 | num_samples=self._num_total_samples, 399 | batch_size=half_batch_size, oracle=self._oracle, bias=self._biases[1]) 400 | sin2_task_infos = [{'task_id': 1, 'amp': amp2[i], 'phase': phase2[i]} 401 | for i in range(len(amp2))] 402 | if self._task_oracle: 403 | sin2_inputs = np.concatenate( 404 | (sin2_inputs, np.ones(sin2_inputs.shape[:2] + (1,))), axis=2) 405 | 406 | inputs = np.concatenate((sin1_inputs, sin2_inputs)) 407 | outputs = np.concatenate((sin1_outputs, sin2_outputs)) 408 | 409 | if self._noise_std > 0: 410 | outputs = outputs + np.random.normal(scale=self._noise_std, 411 | size=outputs.shape) 412 | task_infos = sin1_task_infos + sin2_task_infos 413 | return inputs, outputs, task_infos 414 | 415 | def generate_tanh_batch(center_range, bias_range, slope_range, input_range, 416 | num_samples, batch_size, oracle): 417 | center = np.random.uniform(center_range[0], center_range[1], [batch_size]) 418 | bias = np.random.uniform(bias_range[0], bias_range[1], [batch_size]) 419 | 420 | # slope range 421 | slope = np.random.uniform(slope_range[0], slope_range[1], [batch_size]) 422 | 423 | outputs = np.zeros([batch_size, num_samples, 1]) 424 | inputs = np.zeros([batch_size, num_samples, 1]) 425 | for i in range(batch_size): 426 | inputs[i] = np.random.uniform(input_range[0], input_range[1], 427 | [num_samples, 1]) 428 | outputs[i] = slope[i] * np.tanh(inputs[i] - center[i]) + bias[i] 429 | 430 | if oracle: 431 | centers = np.ones_like(inputs) * center.reshape(-1, 1, 1) 432 | biases = np.ones_like(inputs) * bias.reshape(-1, 1, 1) 433 | inputs = np.concatenate((inputs, centers, biases), axis=2) 434 | 435 | return inputs, outputs, slope, center, bias 436 | 437 | def generate_abs_batch(slope_range, center_range, bias_range, input_range, 438 | num_samples, batch_size, oracle): 439 | slope = np.random.uniform(slope_range[0], slope_range[1], [batch_size]) 440 | bias = np.random.uniform(bias_range[0], bias_range[1], [batch_size]) 441 | center = np.random.uniform(center_range[0], center_range[1], [batch_size]) 442 | 443 | outputs = np.zeros([batch_size,
num_samples, 1]) 444 | inputs = np.zeros([batch_size, num_samples, 1]) 445 | for i in range(batch_size): 446 | inputs[i] = np.random.uniform(input_range[0], input_range[1], 447 | [num_samples, 1]) 448 | outputs[i] = np.abs(inputs[i] - center[i]) * slope[i] + bias[i] 449 | 450 | if oracle: 451 | centers = np.ones_like(inputs) * center.reshape(-1, 1, 1) 452 | biases = np.ones_like(inputs) * bias.reshape(-1, 1, 1) 453 | inputs = np.concatenate((inputs, centers, biases), axis=2) 454 | 455 | return inputs, outputs, slope, center, bias 456 | 457 | 458 | class FiveFunctionsMetaDataset(SimpleFunctionDataset): 459 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi], 460 | input_range=[-5.0, 5.0], slope_range=[-3.0, 3.0], 461 | intersect_range=[-3.0, 3.0], center_range=[-3.0, 3.0], 462 | bias_range=[-3.0, 3.0], sign_range=[0.02, 0.15], task_oracle=False, 463 | noise_std=0, **kwargs): 464 | super(FiveFunctionsMetaDataset, self).__init__(**kwargs) 465 | self._amp_range = amp_range 466 | self._phase_range = phase_range 467 | self._slope_range = slope_range 468 | self._intersect_range = intersect_range 469 | self._input_range = input_range 470 | self._center_range = center_range 471 | self._bias_range = bias_range 472 | self._sign_range = sign_range 473 | self._task_oracle = task_oracle 474 | self._noise_std = noise_std 475 | 476 | assert not self._task_oracle 477 | if not self._oracle: 478 | if not self._task_oracle: 479 | self.input_size = 1 480 | else: 481 | self.input_size = 2 482 | else: 483 | if not self._task_oracle: 484 | self.input_size = 3 485 | else: 486 | self.input_size = 4 487 | 488 | self.output_size = 1 489 | self.num_tasks = 5 490 | 491 | def _generate_batch(self): 492 | assert self._meta_batch_size % 5 == 0, 'meta_batch_size must be divisible by 5'
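        # Each of the five function families contributes an equal share of the meta batch.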
493 | sub_batch_size = self._meta_batch_size // 5 494 | sin_inputs, sin_outputs, amp, phase = generate_sinusoid_batch( 495 | amp_range=self._amp_range, phase_range=self._phase_range, 496 | input_range=self._input_range, 497 | num_samples=self._num_total_samples, 498 | batch_size=sub_batch_size, oracle=self._oracle) 499 | sin_task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]} 500 | for i in range(len(amp))] 501 | 502 | lin_inputs, lin_outputs, slope, intersect = generate_linear_batch( 503 | slope_range=self._slope_range, 504 | intersect_range=self._intersect_range, 505 | input_range=self._input_range, 506 | num_samples=self._num_total_samples, 507 | batch_size=sub_batch_size, oracle=self._oracle) 508 | lin_task_infos = [{'task_id': 1, 'slope': slope[i], 'intersect': intersect[i]} 509 | for i in range(len(slope))] 510 | 511 | qua_inputs, qua_outputs, sign, center, bias = generate_quadratic_batch( 512 | center_range=self._center_range, 513 | bias_range=self._bias_range, 514 | sign_range=self._sign_range, 515 | input_range=self._input_range, 516 | num_samples=self._num_total_samples, 517 | batch_size=sub_batch_size, oracle=self._oracle) 518 | qua_task_infos = [{'task_id': 2, 'sign': sign[i], 'center': center[i], 'bias': bias[i]} 519 | for i in range(len(sign))] 520 | 521 | tanh_inputs, tanh_outputs, slope, center, bias = generate_tanh_batch( 522 | center_range=self._center_range, 523 | bias_range=self._bias_range, 524 | slope_range=self._slope_range, 525 | input_range=self._input_range, 526 | num_samples=self._num_total_samples, 527 | batch_size=sub_batch_size, oracle=self._oracle) 528 | tanh_task_infos = [{'task_id': 3, 'slope': slope[i], 'center': center[i], 'bias': bias[i]} 529 | for i in range(len(slope))] 530 | 531 | abs_inputs, abs_outputs, slope, center, bias = generate_abs_batch( 532 | slope_range=self._slope_range, 533 | center_range=self._center_range, 534 | bias_range=self._bias_range, 535 | input_range=self._input_range, 536 | num_samples=self._num_total_samples, 537 | batch_size=sub_batch_size, oracle=self._oracle) 538 | abs_task_infos = [{'task_id': 4, 'slope': slope[i], 'center': center[i], 'bias': bias[i]} 539 | for i in range(len(slope))] 540 | 541 | inputs = np.concatenate((sin_inputs, lin_inputs, qua_inputs, tanh_inputs, abs_inputs)) 542 | outputs = np.concatenate((sin_outputs, lin_outputs, qua_outputs, tanh_outputs, abs_outputs)) 543 | 544 | if self._noise_std > 0: 545 | outputs = outputs + np.random.normal(scale=self._noise_std, size=outputs.shape) 546 | task_infos = sin_task_infos + lin_task_infos + qua_task_infos + tanh_task_infos + abs_task_infos 547 | return inputs, outputs, task_infos 548 | -------------------------------------------------------------------------------- /maml/metalearner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.nn.utils.clip_grad import clip_grad_norm_ 4 | from maml.utils import accuracy 5 | 6 | 7 | def get_grad_norm(parameters, norm_type=2): 8 | if isinstance(parameters, torch.Tensor): 9 | parameters = [parameters] 10 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 11 | norm_type = float(norm_type) 12 | total_norm = 0 13 | for p in parameters: 14 | param_norm = p.grad.data.norm(norm_type) 15 | total_norm += param_norm.item() ** norm_type 16 | total_norm = total_norm ** (1.
/ norm_type) 17 | 18 | return total_norm 19 | 20 | 21 | class MetaLearner(object): 22 | def __init__(self, model, embedding_model, optimizers, fast_lr, loss_func, 23 | first_order, num_updates, inner_loop_grad_clip, 24 | collect_accuracies, device, embedding_grad_clip=0, 25 | model_grad_clip=0): 26 | self._model = model 27 | self._embedding_model = embedding_model 28 | self._fast_lr = fast_lr 29 | self._optimizers = optimizers 30 | self._loss_func = loss_func 31 | self._first_order = first_order 32 | self._num_updates = num_updates 33 | self._inner_loop_grad_clip = inner_loop_grad_clip 34 | self._collect_accuracies = collect_accuracies 35 | self._device = device 36 | self._embedding_grad_clip = embedding_grad_clip 37 | self._model_grad_clip = model_grad_clip 38 | self._grads_mean = [] 39 | 40 | self.to(device) 41 | 42 | self._reset_measurements() 43 | 44 | def _reset_measurements(self): 45 | self._count_iters = 0.0 46 | self._cum_loss = 0.0 47 | self._cum_accuracy = 0.0 48 | 49 | def _update_measurements(self, task, loss, preds): 50 | self._count_iters += 1.0 51 | self._cum_loss += loss.data.cpu().numpy() 52 | if self._collect_accuracies: 53 | self._cum_accuracy += accuracy( 54 | preds, task.y).data.cpu().numpy() 55 | 56 | def _pop_measurements(self): 57 | measurements = {} 58 | loss = self._cum_loss / (self._count_iters + 1e-32) 59 | measurements['loss'] = loss 60 | if self._collect_accuracies: 61 | accuracy = self._cum_accuracy / (self._count_iters + 1e-32) 62 | measurements['accuracy'] = accuracy 63 | self._reset_measurements() 64 | return measurements 65 | 66 | def measure(self, tasks, train_tasks=None, adapted_params_list=None, 67 | embeddings_list=None): 68 | """Measures performance on tasks. Either train_tasks has to be a list 69 | of training task for computing embeddings, or adapted_params_list and 70 | embeddings_list have to contain adapted_params and embeddings""" 71 | if adapted_params_list is None: 72 | adapted_params_list = [None] * len(tasks) 73 | if embeddings_list is None: 74 | embeddings_list = [None] * len(tasks) 75 | for i in range(len(tasks)): 76 | params = adapted_params_list[i] 77 | if params is None: 78 | params = self._model.param_dict 79 | embeddings = embeddings_list[i] 80 | task = tasks[i] 81 | preds = self._model(task, params=params, embeddings=embeddings) 82 | loss = self._loss_func(preds, task.y) 83 | self._update_measurements(task, loss, preds) 84 | 85 | measurements = self._pop_measurements() 86 | return measurements 87 | 88 | def update_params(self, loss, params): 89 | """Apply one step of gradient descent on the loss function `loss`, 90 | with step-size `self._fast_lr`, and returns the updated parameters. 
91 | """ 92 | create_graph = not self._first_order 93 | grads = torch.autograd.grad(loss, params.values(), 94 | create_graph=create_graph, allow_unused=True) 95 | for (name, param), grad in zip(params.items(), grads): 96 | if self._inner_loop_grad_clip > 0 and grad is not None: 97 | grad = grad.clamp(min=-self._inner_loop_grad_clip, 98 | max=self._inner_loop_grad_clip) 99 | if grad is not None: 100 | params[name] = param - self._fast_lr * grad 101 | 102 | return params 103 | 104 | def adapt(self, train_tasks, return_task_embedding=False): 105 | adapted_params = [] 106 | embeddings_list = [] 107 | task_embeddings_list = [] 108 | 109 | for task in train_tasks: 110 | params = self._model.param_dict 111 | embeddings = None 112 | if self._embedding_model: 113 | if return_task_embedding: 114 | embeddings, task_embedding = self._embedding_model( 115 | task, return_task_embedding=True) 116 | task_embeddings_list.append(task_embedding) 117 | else: 118 | embeddings = self._embedding_model( 119 | task, return_task_embedding=False) 120 | for i in range(self._num_updates): 121 | preds = self._model(task, params=params, embeddings=embeddings) 122 | loss = self._loss_func(preds, task.y) 123 | params = self.update_params(loss, params=params) 124 | if i == 0: 125 | self._update_measurements(task, loss, preds) 126 | adapted_params.append(params) 127 | embeddings_list.append(embeddings) 128 | 129 | measurements = self._pop_measurements() 130 | if return_task_embedding: 131 | return measurements, adapted_params, embeddings_list, task_embeddings_list 132 | else: 133 | return measurements, adapted_params, embeddings_list 134 | 135 | def step(self, adapted_params_list, embeddings_list, val_tasks, 136 | is_training): 137 | for optimizer in self._optimizers: 138 | optimizer.zero_grad() 139 | post_update_losses = [] 140 | 141 | for adapted_params, embeddings, task in zip( 142 | adapted_params_list, embeddings_list, val_tasks): 143 | preds = self._model(task, params=adapted_params, 144 | embeddings=embeddings) 145 | loss = self._loss_func(preds, task.y) 146 | post_update_losses.append(loss) 147 | self._update_measurements(task, loss, preds) 148 | 149 | mean_loss = torch.mean(torch.stack(post_update_losses)) 150 | if is_training: 151 | self._optimizers[0].zero_grad() 152 | if len(self._optimizers) > 1: 153 | self._optimizers[1].zero_grad() 154 | 155 | mean_loss.backward() 156 | if len(self._optimizers) > 1: 157 | if self._embedding_grad_clip > 0: 158 | _grad_norm = clip_grad_norm_( 159 | self._embedding_model.parameters(), self._embedding_grad_clip) 160 | else: 161 | _grad_norm = get_grad_norm( 162 | self._embedding_model.parameters()) 163 | # grad_norm 164 | self._grads_mean.append(_grad_norm) 165 | self._optimizers[1].step() 166 | 167 | if self._model_grad_clip > 0: 168 | _grad_norm = clip_grad_norm_( 169 | self._model.parameters(), self._model_grad_clip) 170 | self._optimizers[0].step() 171 | 172 | measurements = self._pop_measurements() 173 | return measurements 174 | 175 | def to(self, device, **kwargs): 176 | self._device = device 177 | self._model.to(device, **kwargs) 178 | if self._embedding_model: 179 | self._embedding_model.to(device, **kwargs) 180 | 181 | def state_dict(self): 182 | state = { 183 | 'model_state_dict': self._model.state_dict(), 184 | 'optimizers': [optimizer.state_dict() for optimizer in self._optimizers] 185 | } 186 | if self._embedding_model: 187 | state.update( 188 | {'embedding_model_state_dict': 189 | self._embedding_model.state_dict()}) 190 | return state 191 | 
--------------------------------------------------------------------------------
/maml/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vuoristo/MMAML-Regression/1a8bea4d60461d8814e3c9f91427ed56378716ce/maml/models/__init__.py
--------------------------------------------------------------------------------
/maml/models/fully_connected.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

import torch
import torch.nn.functional as F

from maml.models.model import Model


def weight_init(module):
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0, std=0.01)
        module.bias.data.zero_()


class FullyConnectedModel(Model):
    def __init__(self, input_size, output_size, hidden_sizes=(),
                 nonlinearity=F.relu, disable_norm=False,
                 bias_transformation_size=0):
        super(FullyConnectedModel, self).__init__()
        self.hidden_sizes = hidden_sizes
        self.nonlinearity = nonlinearity
        self.num_layers = len(hidden_sizes) + 1
        self.disable_norm = disable_norm
        self.bias_transformation_size = bias_transformation_size

        if bias_transformation_size > 0:
            input_size = input_size + bias_transformation_size
            self.bias_transformation = torch.nn.Parameter(
                torch.zeros(bias_transformation_size))

        # list() guards against the tuple default for hidden_sizes.
        layer_sizes = [input_size] + list(hidden_sizes) + [output_size]
        for i in range(1, self.num_layers):
            self.add_module(
                'layer{0}_linear'.format(i),
                torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
            if not self.disable_norm:
                self.add_module(
                    'layer{0}_bn'.format(i),
                    torch.nn.BatchNorm1d(layer_sizes[i], momentum=0.001))
        self.add_module(
            'output_linear',
            torch.nn.Linear(layer_sizes[self.num_layers - 1],
                            layer_sizes[self.num_layers]))
        self.apply(weight_init)

    def forward(self, task, params=None, training=True, embeddings=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        x = task.x.view(task.x.size(0), -1)

        if self.bias_transformation_size > 0:
            x = torch.cat((x, params['bias_transformation'].expand(
                x.size(0), params['bias_transformation'].size(0))), dim=1)

        for key, module in self.named_modules():
            if 'linear' in key:
                x = F.linear(x, weight=params[key + '.weight'],
                             bias=params[key + '.bias'])
                if self.disable_norm and 'output' not in key:
                    x = self.nonlinearity(x)
            if 'bn' in key:
                x = F.batch_norm(x, weight=params[key + '.weight'],
                                 bias=params[key + '.bias'],
                                 running_mean=module.running_mean,
                                 running_var=module.running_var,
                                 training=training)
                x = self.nonlinearity(x)
        return x


class MultiFullyConnectedModel(Model):
    def __init__(self, input_size, output_size, hidden_sizes=(),
                 nonlinearity=F.relu, disable_norm=False, num_tasks=1,
                 bias_transformation_size=0):
        super(MultiFullyConnectedModel, self).__init__()
        self.hidden_sizes = hidden_sizes
        self.nonlinearity = nonlinearity
        self.num_layers = len(hidden_sizes) + 1
        self.disable_norm = disable_norm
        self.bias_transformation_size = bias_transformation_size
        self.num_tasks = num_tasks

        if bias_transformation_size > 0:
            input_size = input_size + bias_transformation_size
            self.bias_transformation = torch.nn.Embedding(
                self.num_tasks, bias_transformation_size
            )

        # list() guards against the tuple default for hidden_sizes.
        layer_sizes = [input_size] + list(hidden_sizes) + [output_size]
        for j in range(0, self.num_tasks):
            for i in range(1, self.num_layers):
                self.add_module(
                    'task{0}_layer{1}_linear'.format(j, i),
                    torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
                if not self.disable_norm:
                    self.add_module(
                        'task{0}_layer{1}_bn'.format(j, i),
                        torch.nn.BatchNorm1d(layer_sizes[i], momentum=0.001))
            self.add_module(
                'task{0}_output_linear'.format(j),
                torch.nn.Linear(layer_sizes[self.num_layers - 1],
                                layer_sizes[self.num_layers]))
        self.apply(weight_init)

    def forward(self, task, params=None, training=True, embeddings=None):
        if params is None:
            params = OrderedDict(self.named_parameters())
        x = task.x.view(task.x.size(0), -1)
        task_id = task.task_info['task_id']

        if self.bias_transformation_size > 0:
            bias_trans = self.bias_transformation(
                torch.LongTensor([task_id]).to(x.device))
            x = torch.cat((
                x,
                bias_trans.expand(x.size(0), bias_trans.size(1))
            ), dim=1)

        for key, module in self.named_modules():
            if 'task{0}'.format(task_id) in key:
                if 'linear' in key:
                    x = F.linear(x, weight=params[key + '.weight'],
                                 bias=params[key + '.bias'])
                    if self.disable_norm and 'output' not in key:
                        x = self.nonlinearity(x)
                if 'bn' in key:
                    x = F.batch_norm(x, weight=params[key + '.weight'],
                                     bias=params[key + '.bias'],
                                     running_mean=module.running_mean,
                                     running_var=module.running_var,
                                     training=training)
                    x = self.nonlinearity(x)
        return x
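forward() deliberately reads its weights from an explicit params dictionary rather than from the module itself, which is what lets MetaLearner.update_params swap in adapted "fast" weights without mutating the meta-parameters. A minimal sketch of that pattern, with hypothetical task sizes and a bare namedtuple standing in for the dataset's task type:

from collections import OrderedDict, namedtuple

import torch

Task = namedtuple('Task', ['x', 'y'])  # illustrative stand-in for a task
task = Task(x=torch.randn(8, 1), y=torch.randn(8, 1))

model = FullyConnectedModel(input_size=1, output_size=1,
                            hidden_sizes=[40, 40], disable_norm=True)
loss = torch.nn.functional.mse_loss(model(task), task.y)
grads = torch.autograd.grad(loss, model.param_dict.values())
# One SGD step on copies of the parameters, leaving the module untouched.
fast_weights = OrderedDict(
    (name, param - 0.01 * grad)
    for (name, param), grad in zip(model.param_dict.items(), grads))
preds = model(task, params=fast_weights)  # forward with adapted weights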
--------------------------------------------------------------------------------
/maml/models/gated_net.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

import torch
import torch.nn.functional as F

from maml.models.model import Model


def weight_init(module):
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0, std=0.01)
        module.bias.data.zero_()


class GatedNet(Model):
    def __init__(self, input_size, output_size, hidden_sizes=[40, 40],
                 nonlinearity=F.relu, condition_type='sigmoid_gate',
                 condition_order='low2high'):
        super(GatedNet, self).__init__()
        self._nonlinearity = nonlinearity
        self._condition_type = condition_type
        self._condition_order = condition_order

        self.num_layers = len(hidden_sizes) + 1

        layer_sizes = [input_size] + hidden_sizes + [output_size]
        for i in range(1, self.num_layers):
            self.add_module(
                'layer{0}_linear'.format(i),
                torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
        self.add_module(
            'output_linear',
            torch.nn.Linear(layer_sizes[self.num_layers - 1],
                            layer_sizes[self.num_layers]))
        self.apply(weight_init)

    def conditional_layer(self, x, embedding):
        if self._condition_type == 'sigmoid_gate':
            x = x * torch.sigmoid(embedding).expand_as(x)
        elif self._condition_type == 'affine':
            gammas, betas = torch.split(embedding, x.size(1), dim=-1)
            gammas = gammas + torch.ones_like(gammas)
            x = x * gammas + betas
        elif self._condition_type == 'softmax':
            x = x * F.softmax(embedding, dim=-1).expand_as(x)
        else:
            raise ValueError('Unrecognized conditional layer type {}'.format(
                self._condition_type))
        return x
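conditional_layer implements the three modulation operators: sigmoid gating x * sigmoid(e), a FiLM-like affine transform x * (1 + gamma) + beta, and a softmax-weighted gate. A standalone shape check of the affine branch, with illustrative sizes:

import torch

hidden, batch = 40, 8
embedding = torch.randn(1, 2 * hidden)  # packs [gammas, betas] for 40 units
x = torch.randn(batch, hidden)          # pre-activation hidden features
gammas, betas = torch.split(embedding, x.size(1), dim=-1)
out = x * (gammas + torch.ones_like(gammas)) + betas
assert out.shape == x.shape             # modulation preserves the shape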

    def forward(self, task, params=None, embeddings=None, training=True):
        if params is None:
            params = OrderedDict(self.named_parameters())

        if embeddings is not None:
            if self._condition_order == 'high2low':
                embeddings = {'layer{}_linear'.format(len(params) - i): embedding
                              for i, embedding in enumerate(embeddings[::-1])}
            elif self._condition_order == 'low2high':
                embeddings = {'layer{}_linear'.format(i): embedding
                              for i, embedding in enumerate(embeddings[::-1],
                                                            start=1)}
            else:
                raise NotImplementedError(
                    'Unsupported order for using conditional layers')
        x = task.x.view(task.x.size(0), -1)

        for key, module in self.named_modules():
            if 'linear' in key:
                x = F.linear(x, weight=params[key + '.weight'],
                             bias=params[key + '.bias'])
                if 'output' not in key and embeddings is not None:
                    # Modulate the hidden activation with the task embedding,
                    # then apply the nonlinearity.
                    if key in embeddings:
                        x = self.conditional_layer(x, embeddings[key])

                    x = self._nonlinearity(x)

        return x
--------------------------------------------------------------------------------
/maml/models/lstm_embedding_model.py:
--------------------------------------------------------------------------------
import torch


class LSTMEmbeddingModel(torch.nn.Module):
    def __init__(self, input_size, output_size, embedding_dims,
                 hidden_size=40, num_layers=2):
        super(LSTMEmbeddingModel, self).__init__()
        self._input_size = input_size
        self._output_size = output_size
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._embedding_dims = embedding_dims
        self._bidirectional = True
        self._device = 'cpu'

        rnn_input_size = int(input_size + output_size)
        self.rnn = torch.nn.LSTM(
            rnn_input_size, hidden_size, num_layers,
            bidirectional=self._bidirectional)

        self._embeddings = torch.nn.ModuleList()
        for dim in embedding_dims:
            self._embeddings.append(torch.nn.Linear(
                hidden_size * (2 if self._bidirectional else 1), dim))

    def forward(self, task, return_task_embedding=True):
        batch_size = 1
        h0 = torch.zeros(self._num_layers * (2 if self._bidirectional else 1),
                         batch_size, self._hidden_size, device=self._device)
        c0 = torch.zeros(self._num_layers * (2 if self._bidirectional else 1),
                         batch_size, self._hidden_size, device=self._device)

        x = task.x.view(task.x.size(0), -1)
        y = task.y.view(task.y.size(0), -1)

        # LSTM input dimensions are seq_len, batch, input_size
        inputs = torch.cat((x, y), dim=1).view(x.size(0), 1, -1)
        output, (hn, cn) = self.rnn(inputs, (h0, c0))
        if self._bidirectional:
            N, B, H = output.shape
            output = output.view(N, B, 2, H // 2)
            # Concatenate the last forward state with the first backward state.
            embedding_input = torch.cat(
                [output[-1, :, 0], output[0, :, 1]], dim=1)

        out_embeddings = []
        for embedding in self._embeddings:
            out_embeddings.append(embedding(embedding_input))
        if return_task_embedding:
            return out_embeddings, embedding_input
        else:
            return out_embeddings

    def to(self, device, **kwargs):
        self._device = device
        super(LSTMEmbeddingModel, self).to(device, **kwargs)
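The embedding network treats the K support pairs as a length-K sequence of concatenated (x, y) vectors, so the bidirectional LSTM sees the whole task before the per-layer embeddings are produced. A quick shape walk-through with assumed dimensions:

import torch

K, in_dim, out_dim, hidden = 10, 1, 1, 40
x, y = torch.randn(K, in_dim), torch.randn(K, out_dim)
inputs = torch.cat((x, y), dim=1).view(K, 1, -1)  # (seq_len, batch, input)
rnn = torch.nn.LSTM(in_dim + out_dim, hidden, num_layers=2, bidirectional=True)
output, _ = rnn(inputs)
assert output.shape == (K, 1, 2 * hidden)  # forward/backward states stacked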
--------------------------------------------------------------------------------
/maml/models/model.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

import torch


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    @property
    def param_dict(self):
        return OrderedDict(self.named_parameters())
--------------------------------------------------------------------------------
/maml/models/simple_embedding_model.py:
--------------------------------------------------------------------------------
import torch


class SimpleEmbeddingModel(torch.nn.Module):
    def __init__(self, num_embeddings, embedding_dims):
        super(SimpleEmbeddingModel, self).__init__()
        self._embeddings = torch.nn.ModuleList()
        for dim in embedding_dims:
            self._embeddings.append(torch.nn.Embedding(num_embeddings, dim))
        self._device = 'cpu'

    def forward(self, task):
        task_id = torch.tensor(task.task_id, dtype=torch.long,
                               device=self._device)
        out_embeddings = []
        for embedding in self._embeddings:
            out_embeddings.append(embedding(task_id))
        return out_embeddings

    def to(self, device, **kwargs):
        self._device = device
        super(SimpleEmbeddingModel, self).to(device, **kwargs)
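SimpleEmbeddingModel skips task inference entirely and, in effect acting as an oracle over task identity, looks the modulation vectors up by ground-truth task id. The lookup itself is plain torch.nn.Embedding indexing, with illustrative sizes:

import torch

table = torch.nn.Embedding(num_embeddings=5, embedding_dim=40)
task_id = torch.tensor(3, dtype=torch.long)
vector = table(task_id)  # one modulation vector for one conditioned layer
assert vector.shape == (40,)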
--------------------------------------------------------------------------------
/maml/sampler.py:
--------------------------------------------------------------------------------
import random
from collections import defaultdict, namedtuple

from torch.utils.data.sampler import Sampler


class ClassBalancedSampler(Sampler):
    """Generates indices for class-balanced batches by sampling with
    replacement across batches.
    """

    def __init__(self, dataset_labels, num_classes_per_batch,
                 num_samples_per_class, num_total_batches, train):
        """
        Args:
            dataset_labels: list of dataset labels
            num_classes_per_batch: number of classes to sample for each batch
            num_samples_per_class: number of samples to sample for each class
                in each batch. For K-shot learning this should be K plus the
                number of validation samples.
            num_total_batches: total number of batches to generate
        """
        self._dataset_labels = dataset_labels
        self._classes = set(self._dataset_labels)
        self._class_to_samples = defaultdict(set)
        for i, c in enumerate(self._dataset_labels):
            self._class_to_samples[c].add(i)

        self._num_classes_per_batch = num_classes_per_batch
        self._num_samples_per_class = num_samples_per_class
        self._num_total_batches = num_total_batches
        self._train = train

    def __iter__(self):
        for i in range(self._num_total_batches):
            # random.sample requires a sequence, so materialize the sets.
            batch_classes = random.sample(
                list(self._class_to_samples.keys()),
                self._num_classes_per_batch)
            batch_samples = []
            for c in batch_classes:
                class_samples = random.sample(
                    list(self._class_to_samples[c]),
                    self._num_samples_per_class)
                for sample in class_samples:
                    batch_samples.append(sample)
            random.shuffle(batch_samples)
            for sample in batch_samples:
                yield sample

    def __len__(self):
        return self._num_total_batches
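A minimal sketch of driving the sampler directly; the labels and batch geometry here are made up, and in practice the indices would be consumed through a torch.utils.data.DataLoader via its sampler argument:

labels = [i % 10 for i in range(100)]  # 100 samples over 10 classes
sampler = ClassBalancedSampler(labels, num_classes_per_batch=5,
                               num_samples_per_class=10,
                               num_total_batches=4, train=True)
indices = list(sampler)  # 4 batches x 5 classes x 10 samples per class
assert len(indices) == 4 * 5 * 10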
--------------------------------------------------------------------------------
/maml/trainer.py:
--------------------------------------------------------------------------------
import os
import json
from collections import defaultdict

import numpy as np
import torch


class Trainer(object):
    def __init__(self, meta_learner, meta_dataset, writer, log_interval,
                 save_interval, model_type, save_folder):
        self._meta_learner = meta_learner
        self._meta_dataset = meta_dataset
        self._writer = writer
        self._log_interval = log_interval
        self._save_interval = save_interval
        self._model_type = model_type
        self._save_folder = save_folder

    def run(self, is_training):
        if not is_training:
            all_pre_val_measurements = defaultdict(list)
            all_pre_train_measurements = defaultdict(list)
            all_post_val_measurements = defaultdict(list)
            all_post_train_measurements = defaultdict(list)

        for i, (train_tasks, val_tasks) in enumerate(
                iter(self._meta_dataset), start=1):
            (pre_train_measurements, adapted_params, embeddings
             ) = self._meta_learner.adapt(train_tasks)
            post_val_measurements = self._meta_learner.step(
                adapted_params, embeddings, val_tasks, is_training)

            # Periodic logging and Tensorboard summaries
            if (i % self._log_interval == 0 or i == 1):
                pre_val_measurements = self._meta_learner.measure(
                    tasks=val_tasks, embeddings_list=embeddings)
                post_train_measurements = self._meta_learner.measure(
                    tasks=train_tasks, adapted_params_list=adapted_params,
                    embeddings_list=embeddings)

                _grads_mean = np.mean(self._meta_learner._grads_mean)
                self._meta_learner._grads_mean = []

                self.log_output(
                    pre_val_measurements, pre_train_measurements,
                    post_val_measurements, post_train_measurements,
                    i, _grads_mean)

                if is_training:
                    self.write_tensorboard(
                        pre_val_measurements, pre_train_measurements,
                        post_val_measurements, post_train_measurements,
                        i, _grads_mean)

            # Save model
            if i % self._save_interval == 0 and is_training:
                save_name = 'maml_{0}_{1}.pt'.format(self._model_type, i)
                save_path = os.path.join(self._save_folder, save_name)
                with open(save_path, 'wb') as f:
                    torch.save(self._meta_learner.state_dict(), f)

            # Collect evaluation statistics over the full dataset
            if not is_training:
                for key, value in sorted(pre_val_measurements.items()):
                    all_pre_val_measurements[key].append(value)
                for key, value in sorted(pre_train_measurements.items()):
                    all_pre_train_measurements[key].append(value)
                for key, value in sorted(post_val_measurements.items()):
                    all_post_val_measurements[key].append(value)
                for key, value in sorted(post_train_measurements.items()):
                    all_post_train_measurements[key].append(value)

        # Compute evaluation statistics assuming all batches were the same size
        if not is_training:
            results = {'num_batches': i}
            for key, value in sorted(all_pre_val_measurements.items()):
                results['pre_val_' + key] = value
            for key, value in sorted(all_pre_train_measurements.items()):
                results['pre_train_' + key] = value
            for key, value in sorted(all_post_val_measurements.items()):
                results['post_val_' + key] = value
            for key, value in sorted(all_post_train_measurements.items()):
                results['post_train_' + key] = value

            print('Evaluation results:')
            for key, value in sorted(results.items()):
                if not isinstance(value, int):
                    print('{}: {:.6f} +- {:.6e}, std={:.6f}'.format(
                        key,
                        float(np.mean(value)),
                        float(self.compute_confidence_interval(value)),
                        float(np.std(value)),
                    ))
                else:
                    print('{}: {}'.format(key, value))

            results_path = os.path.join(self._save_folder, 'results.json')
            with open(results_path, 'w') as f:
                json.dump(results, f)

    def compute_confidence_interval(self, value):
        """Computes the half-width of a 95% confidence interval over tasks;
        change 1.960 to 2.576 for a 99% confidence interval.
        """
        return np.std(value) * 1.960 / np.sqrt(len(value))

    def train(self):
        self.run(is_training=True)

    def eval(self):
        self.run(is_training=False)

    def write_tensorboard(self, pre_val_measurements, pre_train_measurements,
                          post_val_measurements, post_train_measurements,
                          iteration, embedding_grads_mean=None):
        for key, value in pre_val_measurements.items():
            self._writer.add_scalar(
                '{}/before_update/meta_val'.format(key), value, iteration)
        for key, value in pre_train_measurements.items():
            self._writer.add_scalar(
                '{}/before_update/meta_train'.format(key), value, iteration)
        for key, value in post_train_measurements.items():
            self._writer.add_scalar(
                '{}/after_update/meta_train'.format(key), value, iteration)
        for key, value in post_val_measurements.items():
            self._writer.add_scalar(
                '{}/after_update/meta_val'.format(key), value, iteration)
        if embedding_grads_mean is not None:
            self._writer.add_scalar(
                'embedding_grads_mean', embedding_grads_mean, iteration)

    def log_output(self, pre_val_measurements, pre_train_measurements,
                   post_val_measurements, post_train_measurements,
                   iteration, embedding_grads_mean=None):
        log_str = 'Iteration: {} '.format(iteration)
        for key, value in sorted(pre_val_measurements.items()):
            log_str = (log_str + '{} meta_val before: {:.3f} '
                       ''.format(key, value))
        for key, value in sorted(pre_train_measurements.items()):
            log_str = (log_str + '{} meta_train before: {:.3f} '
                       ''.format(key, value))
        for key, value in sorted(post_train_measurements.items()):
            log_str = (log_str + '{} meta_train after: {:.3f} '
                       ''.format(key, value))
        for key, value in sorted(post_val_measurements.items()):
            log_str = (log_str + '{} meta_val after: {:.3f} '
                       ''.format(key, value))
        if embedding_grads_mean is not None:
            log_str = (log_str + 'embedding_grad_norm after: {:.3f} '
                       ''.format(embedding_grads_mean))
        print(log_str)
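The "+-" figures printed during evaluation use the normal-approximation interval, a half-width of z * std / sqrt(n) with z = 1.960 for 95% coverage. The same computation in isolation, on illustrative numbers:

import numpy as np

values = [0.21, 0.25, 0.19, 0.23]  # e.g. per-batch post-update losses
half_width = 1.960 * np.std(values) / np.sqrt(len(values))
print('{:.4f} +- {:.4f}'.format(np.mean(values), half_width))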
--------------------------------------------------------------------------------
/maml/utils.py:
--------------------------------------------------------------------------------
import subprocess

import torch


def accuracy(preds, y):
    _, preds = torch.max(preds.data, 1)
    total = y.size(0)
    correct = (preds == y).sum().float()
    return correct / total


def optimizer_to_device(optimizer, device):
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)


def get_git_revision_hash():
    # Decode so the hash is a plain string rather than the repr of bytes.
    return subprocess.check_output(
        ['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
--------------------------------------------------------------------------------