├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── model.png
├── environment.yml
├── main.py
└── maml
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── metadataset.py
    │   └── simple_functions.py
    ├── metalearner.py
    ├── models
    │   ├── __init__.py
    │   ├── fully_connected.py
    │   ├── gated_net.py
    │   ├── lstm_embedding_model.py
    │   ├── model.py
    │   └── simple_embedding_model.py
    ├── sampler.py
    ├── trainer.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # project specific
2 | data/
3 | saves/
4 | logs/
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 |
111 | *.txt
112 | *.sw[opn]
113 | *.pt
114 |
115 |
116 | train_dir
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Tristan Deleu
4 | Copyright (c) 2018 Risto Vuorio
5 | Copyright (c) 2019 Hexiang Hu
6 | Copyright (c) 2019 Shao-Hua Sun
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal Model-Agnostic Meta-Learning for Few-shot Regression
2 |
3 | This project is an implementation of [**Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation**](https://arxiv.org/abs/1910.13616), which is published in [**NeurIPS 2019**](https://neurips.cc/Conferences/2019/). Please visit our [project page](https://vuoristo.github.io/MMAML/) for more information and contact [Hexiang Hu](http://hexianghu.com/) for any questions.
4 |
5 | Model-agnostic meta-learners aim to acquire meta-prior parameters from a distribution of tasks and adapt to novel tasks with few gradient updates. Yet, seeking a common initialization shared across the entire task distribution substantially limits the diversity of the task distributions that they are able to learn from. We propose a multimodal MAML (MMAML) framework, which is able to modulate its meta-learned prior according to the identified mode, allowing more efficient fast adaptation. An illustration of the proposed framework is as follows.
6 |
7 |
8 | ![An illustration of the MMAML framework](assets/model.png)
9 |
10 |
11 |
12 | ## Getting started
13 |
14 | Use of a [Conda environment](https://docs.conda.io/en/latest/) is suggested for straightforward handling of the dependencies.
15 |
16 | ```bash
17 | conda env create -f environment.yml
18 | conda activate mmaml_regression
19 | ```
20 |
21 | # Usage
22 |
23 | After installation, we can start training models with the following commands.
24 |
25 | ## Linear + Sinusoid Functions
26 |
27 | ### MAML
28 | ```
29 | python main.py --dataset mixed --num-batches 70000 --model-type fc --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-maml-5steps --bias-transformation-size 20 --disable-norm
30 | ```
31 |
32 |
33 | ### Multi-MAML
34 | ```
35 | python main.py --dataset mixed --num-batches 70000 --model-type multi --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-multi-maml-5steps --bias-transformation-size 20 --disable-norm
36 | ```
37 |
38 | ### MMAML-postupdate
39 |
40 | #### FiLM
41 | ```
42 | python main.py --dataset mixed --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-mmaml-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 --inner-loop-grad-clip 10
43 | ```
44 |
45 | #### Sigmoid
46 | ```
47 | python main.py --dataset mixed --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-mmaml-sigmoid-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 --inner-loop-grad-clip 10 --condition-type sigmoid_gate
48 | ```
49 |
50 | #### Softmax
51 | ```
52 | python main.py --dataset mixed --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 2mods-mmaml-softmax-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 --inner-loop-grad-clip 10 --condition-type softmax
53 | ```
54 |
55 | ### MMAML-preupdate
56 | ```
57 | python main.py --dataset mixed --num-batches 70000 --model-type gated --fast-lr 0.0 --meta-batch-size 50 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 1 --output-folder 2mods-mmaml-pre-1steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 --inner-loop-grad-clip 10
58 | ```
59 |
60 | ## Linear + Quadratic + Sinusoid Functions
61 |
62 | ### MAML
63 | ```
64 | python main.py --dataset many --num-batches 70000 --model-type fc --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-maml-5steps --bias-transformation-size 20 --disable-norm
65 | ```
66 |
67 |
68 | ### Multi-MAML
69 | ```
70 | python main.py --dataset many --num-batches 70000 --model-type multi --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-multi-maml-5steps --bias-transformation-size 20 --disable-norm
71 | ```
72 |
73 | ### MMAML-postupdate
74 |
75 | #### FiLM
76 | ```
77 | python main.py --dataset many --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-mmaml-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 200 200 --inner-loop-grad-clip 10
78 | ```
79 |
80 | #### Sigmoid
81 | ```
82 | python main.py --dataset many --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-mmaml-sigmoid-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 100 100 --inner-loop-grad-clip 10 --condition-type sigmoid_gate
83 | ```
84 |
85 | #### Softmax
86 | ```
87 | python main.py --dataset many --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 3mods-mmaml-softmax-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 100 100 --inner-loop-grad-clip 10 --condition-type softmax
88 | ```
89 |
90 | ### MMAML-preupdate
91 | ```
92 | python main.py --dataset many --num-batches 70000 --model-type gated --fast-lr 0.00 --meta-batch-size 75 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 1 --output-folder 3mods-mmaml-pre-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 200 200 --inner-loop-grad-clip 10
93 | ```
94 |
95 |
96 | ## Linear + Quadratic + Sinusoid + Tanh + Absolute Functions
97 |
98 | ### MAML
99 | ```
100 | python main.py --dataset five --num-batches 70000 --model-type fc --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-maml-5steps --bias-transformation-size 20 --disable-norm
101 | ```
102 |
103 |
104 | ### Multi-MAML
105 | ```
106 | python main.py --dataset five --num-batches 70000 --model-type multi --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-multi-maml-5steps --bias-transformation-size 20 --disable-norm
107 | ```
108 |
109 | ### MMAML-postupdate
110 |
111 | #### FiLM
112 | ```
113 | python main.py --dataset five --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-mmaml-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 200 200 --inner-loop-grad-clip 10
114 | ```
115 |
116 | #### Sigmoid
117 | ```
118 | python main.py --dataset five --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-mmaml-sigmoid-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 100 100 --inner-loop-grad-clip 10 --condition-type sigmoid_gate
119 | ```
120 |
121 | #### Softmax
122 | ```
123 | python main.py --dataset five --num-batches 70000 --model-type gated --fast-lr 0.001 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 5 --output-folder 5mods-mmaml-softmax-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 100 100 100 --inner-loop-grad-clip 10 --condition-type softmax
124 | ```
125 |
126 | ### MMAML-preupdate
127 | ```
128 | python main.py --dataset five --num-batches 70000 --model-type gated --fast-lr 0.00 --meta-batch-size 125 --num-samples-per-class 10 --num-val-samples 5 --noise-std 0.3 --hidden-sizes 100 100 100 --device cuda --num-updates 1 --output-folder 5mods-mmaml-pre-5steps --bias-transformation-size 20 --disable-norm --embedding-type LSTM --embedding-dims 200 200 200 --inner-loop-grad-clip 10
129 | ```
130 |
131 |
132 |
133 | ## Authors
134 |
135 | [Hexiang Hu](http://hexianghu.com/), [Shao-Hua Sun](http://shaohua0116.github.io/), [Risto Vuorio](https://vuoristo.github.io/)
136 |
--------------------------------------------------------------------------------
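The modulation idea above, in code: a minimal, self-contained sketch assuming FiLM-style (scale-and-shift) conditioning. `make_param`, `task_encoder`, and `forward` are illustrative stand-ins, not this repo's API; the real components are `maml/models/gated_net.py`, `maml/models/lstm_embedding_model.py`, and the inner loop in `maml/metalearner.py`.

```python
import torch
import torch.nn.functional as F

hidden = 100
fast_lr = 0.001

def make_param(*shape):
    # Small random init; leaf tensors so autograd.grad can differentiate w.r.t. them.
    return (0.01 * torch.randn(*shape)).requires_grad_()

# Base network parameters (the meta-learned prior).
params = {'w1': make_param(hidden, 1), 'b1': make_param(hidden),
          'w2': make_param(1, hidden), 'b2': make_param(1)}

# Task encoder: maps the support set to one FiLM (scale, shift) pair.
task_encoder = torch.nn.Sequential(
    torch.nn.Linear(2, 40), torch.nn.ReLU(), torch.nn.Linear(40, 2 * hidden))

def forward(x, p, gamma, beta):
    # FiLM: modulate the hidden activation with task-specific scale and shift.
    h = F.relu(gamma * F.linear(x, p['w1'], p['b1']) + beta)
    return F.linear(h, p['w2'], p['b2'])

x_s, y_s = torch.randn(10, 1), torch.randn(10, 1)  # support set of one task

# 1) Identify the mode: embed the support set and predict modulation parameters.
emb = task_encoder(torch.cat([x_s, y_s], dim=1)).mean(dim=0)
gamma, beta = 1 + emb[:hidden], emb[hidden:]

# 2) Fast adaptation: a few MAML inner-loop steps on the modulated network.
for _ in range(5):
    loss = F.mse_loss(forward(x_s, params, gamma, beta), y_s)
    grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
    params = {k: p - fast_lr * g for (k, p), g in zip(params.items(), grads)}
```

The `--condition-type` flag below selects between this affine (FiLM) conditioning and the sigmoid/softmax gating variants.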
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vuoristo/MMAML-Regression/1a8bea4d60461d8814e3c9f91427ed56378716ce/assets/model.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: mmaml_regression
2 | channels:
3 | - pytorch
4 | dependencies:
5 | - python=3.7
6 | - pytorch=0.4.1
7 | - numpy
8 | - pip
9 | - pip:
10 | - tensorboardX
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | import torch
6 | import numpy as np
7 | from tensorboardX import SummaryWriter
8 |
9 | from maml.datasets.simple_functions import (
10 | SinusoidMetaDataset,
11 | LinearMetaDataset,
12 | MixedFunctionsMetaDataset,
13 | ManyFunctionsMetaDataset,
14 | FiveFunctionsMetaDataset,
15 | MultiSinusoidsMetaDataset,
16 | )
17 | from maml.models.fully_connected import FullyConnectedModel, MultiFullyConnectedModel
18 | from maml.models.gated_net import GatedNet
19 | from maml.models.lstm_embedding_model import LSTMEmbeddingModel
20 | from maml.metalearner import MetaLearner
21 | from maml.trainer import Trainer
22 | from maml.utils import optimizer_to_device, get_git_revision_hash
23 |
24 |
25 | def main(args):
26 | is_training = not args.eval
27 | run_name = 'train' if is_training else 'eval'
28 |
29 | if is_training:
30 | writer = SummaryWriter('./train_dir/{0}/{1}'.format(
31 | args.output_folder, run_name))
32 | else:
33 | writer = None
34 |
35 | save_folder = './train_dir/{0}'.format(args.output_folder)
36 | if not os.path.exists(save_folder):
37 | os.makedirs(save_folder)
38 |
39 | config_name = '{0}_config.json'.format(run_name)
40 | with open(os.path.join(save_folder, config_name), 'w') as f:
41 | config = {k: v for (k, v) in vars(args).items() if k != 'device'}
42 | config.update(device=args.device.type)
43 | config.update({'git_hash': get_git_revision_hash()})
44 | json.dump(config, f, indent=2)
45 |
46 | _num_tasks = 1
47 | if args.dataset == 'sinusoid':
48 | dataset = SinusoidMetaDataset(
49 | num_total_batches=args.num_batches,
50 | num_samples_per_function=args.num_samples_per_class,
51 | num_val_samples=args.num_val_samples,
52 | meta_batch_size=args.meta_batch_size,
53 | amp_range=args.amp_range,
54 | phase_range=args.phase_range,
55 | input_range=args.input_range,
56 | oracle=args.oracle,
57 | train=is_training,
58 | device=args.device)
59 | loss_func = torch.nn.MSELoss()
60 | collect_accuracies = False
61 | elif args.dataset == 'linear':
62 | dataset = LinearMetaDataset(
63 | num_total_batches=args.num_batches,
64 | num_samples_per_function=args.num_samples_per_class,
65 | num_val_samples=args.num_val_samples,
66 | meta_batch_size=args.meta_batch_size,
67 | slope_range=args.slope_range,
68 | intersect_range=args.intersect_range,
69 | input_range=args.input_range,
70 | oracle=args.oracle,
71 | train=is_training,
72 | device=args.device)
73 | loss_func = torch.nn.MSELoss()
74 | collect_accuracies = False
75 | elif args.dataset == 'mixed':
76 | dataset = MixedFunctionsMetaDataset(
77 | num_total_batches=args.num_batches,
78 | num_samples_per_function=args.num_samples_per_class,
79 | num_val_samples=args.num_val_samples,
80 | meta_batch_size=args.meta_batch_size,
81 | amp_range=args.amp_range,
82 | phase_range=args.phase_range,
83 | slope_range=args.slope_range,
84 | intersect_range=args.intersect_range,
85 | input_range=args.input_range,
86 | noise_std=args.noise_std,
87 | oracle=args.oracle,
88 | task_oracle=args.task_oracle,
89 | train=is_training,
90 | device=args.device)
91 | loss_func = torch.nn.MSELoss()
92 | collect_accuracies = False
93 | _num_tasks = 2
94 | elif args.dataset == 'five':
95 | dataset = FiveFunctionsMetaDataset(
96 | num_total_batches=args.num_batches,
97 | num_samples_per_function=args.num_samples_per_class,
98 | num_val_samples=args.num_val_samples,
99 | meta_batch_size=args.meta_batch_size,
100 | amp_range=args.amp_range,
101 | phase_range=args.phase_range,
102 | slope_range=args.slope_range,
103 | intersect_range=args.intersect_range,
104 | input_range=args.input_range,
105 | noise_std=args.noise_std,
106 | oracle=args.oracle,
107 | task_oracle=args.task_oracle,
108 | train=is_training,
109 | device=args.device)
110 | loss_func = torch.nn.MSELoss()
111 | collect_accuracies = False
112 | _num_tasks = 5
113 | elif args.dataset == 'many':
114 | dataset = ManyFunctionsMetaDataset(
115 | num_total_batches=args.num_batches,
116 | num_samples_per_function=args.num_samples_per_class,
117 | num_val_samples=args.num_val_samples,
118 | meta_batch_size=args.meta_batch_size,
119 | amp_range=args.amp_range,
120 | phase_range=args.phase_range,
121 | slope_range=args.slope_range,
122 | intersect_range=args.intersect_range,
123 | input_range=args.input_range,
124 | noise_std=args.noise_std,
125 | oracle=args.oracle,
126 | task_oracle=args.task_oracle,
127 | train=is_training,
128 | device=args.device)
129 | loss_func = torch.nn.MSELoss()
130 | collect_accuracies = False
131 | _num_tasks = 3
132 | elif args.dataset == 'multisinusoids':
133 | dataset = MultiSinusoidsMetaDataset(
134 | num_total_batches=args.num_batches,
135 | num_samples_per_function=args.num_samples_per_class,
136 | num_val_samples=args.num_val_samples,
137 | meta_batch_size=args.meta_batch_size,
138 | amp_range=args.amp_range,
139 | phase_range=args.phase_range,
140 | slope_range=args.slope_range,
141 | intersect_range=args.intersect_range,
142 | input_range=args.input_range,
143 | noise_std=args.noise_std,
144 | oracle=args.oracle,
145 | task_oracle=args.task_oracle,
146 | train=is_training,
147 | device=args.device)
148 | loss_func = torch.nn.MSELoss()
149 | collect_accuracies = False
150 | else:
151 | raise ValueError('Unrecognized dataset {}'.format(args.dataset))
152 |
153 | embedding_model = None
154 |
155 | if args.model_type == 'fc':
156 | model = FullyConnectedModel(
157 | input_size=np.prod(dataset.input_size),
158 | output_size=dataset.output_size,
159 | hidden_sizes=args.hidden_sizes,
160 | disable_norm=args.disable_norm,
161 | bias_transformation_size=args.bias_transformation_size)
162 | elif args.model_type == 'multi':
163 | model = MultiFullyConnectedModel(
164 | input_size=np.prod(dataset.input_size),
165 | output_size=dataset.output_size,
166 | hidden_sizes=args.hidden_sizes,
167 | disable_norm=args.disable_norm,
168 | num_tasks=_num_tasks,
169 | bias_transformation_size=args.bias_transformation_size)
170 | elif args.model_type == 'gated':
171 | model = GatedNet(
172 | input_size=np.prod(dataset.input_size),
173 | output_size=dataset.output_size,
174 | hidden_sizes=args.hidden_sizes,
175 | condition_type=args.condition_type,
176 | condition_order=args.condition_order)
177 | else:
178 | raise ValueError('Unrecognized model type {}'.format(args.model_type))
179 | model_parameters = list(model.parameters())
180 |
181 | if args.embedding_type == '':
182 | embedding_model = None
183 | elif args.embedding_type == 'LSTM':
184 | embedding_model = LSTMEmbeddingModel(
185 | input_size=np.prod(dataset.input_size),
186 | output_size=dataset.output_size,
187 | embedding_dims=args.embedding_dims,
188 | hidden_size=args.embedding_hidden_size,
189 | num_layers=args.embedding_num_layers)
190 | embedding_parameters = list(embedding_model.parameters())
191 | else:
192 | raise ValueError('Unrecognized embedding type {}'.format(
193 | args.embedding_type))
194 |
195 | optimizers = None
196 | if embedding_model:
197 | optimizers = (torch.optim.Adam(model_parameters, lr=args.slow_lr),
198 | torch.optim.Adam(embedding_parameters, lr=args.slow_lr))
199 | else:
200 | optimizers = (torch.optim.Adam(model_parameters, lr=args.slow_lr), )
201 |
202 | if args.checkpoint != '':
203 | checkpoint = torch.load(args.checkpoint)
204 | model.load_state_dict(checkpoint['model_state_dict'])
205 | model.to(args.device)
206 |         # MetaLearner.state_dict() stores optimizer state under 'optimizers';
207 |         # skip restoring when a checkpoint has only a legacy 'optimizer' key.
208 |         if 'optimizer' not in checkpoint:
209 |             optimizers[0].load_state_dict(checkpoint['optimizers'][0])
210 |             optimizer_to_device(optimizers[0], args.device)
211 |
212 | if embedding_model:
213 | embedding_model.load_state_dict(
214 | checkpoint['embedding_model_state_dict'])
215 | optimizers[1].load_state_dict(checkpoint['optimizers'][1])
216 | optimizer_to_device(optimizers[1], args.device)
217 |
218 | meta_learner = MetaLearner(
219 | model, embedding_model, optimizers, fast_lr=args.fast_lr,
220 | loss_func=loss_func, first_order=args.first_order,
221 | num_updates=args.num_updates,
222 | inner_loop_grad_clip=args.inner_loop_grad_clip,
223 | collect_accuracies=collect_accuracies, device=args.device,
224 | embedding_grad_clip=args.embedding_grad_clip,
225 | model_grad_clip=args.model_grad_clip)
226 |
227 | trainer = Trainer(
228 | meta_learner=meta_learner, meta_dataset=dataset, writer=writer,
229 | log_interval=args.log_interval, save_interval=args.save_interval,
230 | model_type=args.model_type, save_folder=save_folder)
231 |
232 | if is_training:
233 | trainer.train()
234 | else:
235 | trainer.eval()
236 |
237 |
238 | if __name__ == '__main__':
239 |
240 | def str2bool(arg):
241 | return arg.lower() == 'true'
242 |
243 | parser = argparse.ArgumentParser(
244 | description='Multimodal Model-Agnostic Meta-Learning (MAML)')
245 |
246 | # Model
247 | parser.add_argument('--hidden-sizes', type=int,
248 | default=[256, 128, 64, 64], nargs='+',
249 | help='number of hidden units per layer')
250 | parser.add_argument('--model-type', type=str, default='fc',
251 | help='type of the model')
252 | parser.add_argument('--condition-type', type=str, default='affine',
253 | help='type of the conditional layers')
254 | parser.add_argument('--use-max-pool', action='store_true',
255 | help='choose whether to use max pooling with convolutional model')
256 | parser.add_argument('--num-channels', type=int, default=64,
257 | help='number of channels in convolutional layers')
258 | parser.add_argument('--disable-norm', action='store_true',
259 | help='disable batchnorm after linear layers in a fully connected model')
260 | parser.add_argument('--bias-transformation-size', type=int, default=0,
261 | help='size of bias transformation vector that is concatenated with '
262 | 'input')
263 | parser.add_argument('--condition-order', type=str, default='low2high',
264 | help='order of the conditional layers to be used')
265 |
266 | # Embedding
267 | parser.add_argument('--embedding-type', type=str, default='',
268 | help='type of the embedding')
269 | parser.add_argument('--embedding-hidden-size', type=int, default=40,
270 | help='number of hidden units per layer in recurrent embedding model')
271 | parser.add_argument('--embedding-num-layers', type=int, default=2,
272 | help='number of layers in recurrent embedding model')
273 | parser.add_argument('--embedding-dims', type=int, nargs='+', default=0,
274 | help='dimensions of the embeddings')
275 |
276 | # Randomly sampled embedding vectors
277 | parser.add_argument('--num-sample-embedding', type=int, default=0,
278 | help='number of randomly sampled embedding vectors')
279 | parser.add_argument(
280 | '--sample-embedding-file', type=str, default='embeddings',
281 | help='the file name of randomly sampled embedding vectors')
282 | parser.add_argument(
283 | '--sample-embedding-file-type', type=str, default='hdf5')
284 |
285 | # Inner loop
286 | parser.add_argument('--first-order', action='store_true',
287 | help='use the first-order approximation of MAML')
288 | parser.add_argument('--fast-lr', type=float, default=0.4,
289 | help='learning rate for the 1-step gradient update of MAML')
290 | parser.add_argument('--inner-loop-grad-clip', type=float, default=0.0,
291 | help='enable gradient clipping in the inner loop')
292 | parser.add_argument('--num-updates', type=int, default=1,
293 | help='how many update steps in the inner loop')
294 |
295 | # Optimization
296 | parser.add_argument('--num-batches', type=int, default=1920000,
297 | help='number of batches')
298 | parser.add_argument('--meta-batch-size', type=int, default=32,
299 | help='number of tasks per batch')
300 | parser.add_argument('--slow-lr', type=float, default=0.001,
301 | help='learning rate for the global update of MAML')
302 |
303 | # Miscellaneous
304 | parser.add_argument('--output-folder', type=str, default='maml',
305 | help='name of the output folder')
306 | parser.add_argument('--device', type=str, default='cpu',
307 | help='set the device (cpu or cuda)')
308 | parser.add_argument('--num-workers', type=int, default=4,
309 | help='how many DataLoader workers to use')
310 | parser.add_argument('--log-interval', type=int, default=100,
311 | help='number of batches between tensorboard writes')
312 | parser.add_argument('--save-interval', type=int, default=1000,
313 | help='number of batches between model saves')
314 | parser.add_argument('--eval', action='store_true', default=False,
315 | help='evaluate model')
316 | parser.add_argument('--checkpoint', type=str, default='',
317 | help='path to saved parameters.')
318 |
319 | # Dataset
320 | parser.add_argument('--dataset', type=str, default='omniglot',
321 | help='which dataset to use')
322 | parser.add_argument('--data-root', type=str, default='data',
323 | help='path to store datasets')
324 | parser.add_argument('--num-samples-per-class', type=int, default=1,
325 | help='how many samples per class for training')
326 | parser.add_argument('--num-val-samples', type=int, default=1,
327 | help='how many samples per class for validation')
328 | parser.add_argument('--input-range', type=float, default=[-5.0, 5.0],
329 | nargs='+', help='input range of simple functions')
330 | parser.add_argument('--phase-range', type=float, default=[0, np.pi],
331 | nargs='+', help='phase range of sinusoids')
332 | parser.add_argument('--amp-range', type=float, default=[0.1, 5.0],
333 | nargs='+', help='amp range of sinusoids')
334 | parser.add_argument('--slope-range', type=float, default=[-3.0, 3.0],
335 | nargs='+', help='slope range of linear functions')
336 | parser.add_argument('--intersect-range', type=float, default=[-3.0, 3.0],
337 | nargs='+', help='intersect range of linear functions')
338 | parser.add_argument('--noise-std', type=float, default=0.0,
339 | help='add gaussian noise to mixed functions')
340 | parser.add_argument('--oracle', action='store_true',
341 | help='concatenate phase and amp to sinusoid inputs')
342 | parser.add_argument('--task-oracle', action='store_true',
343 | help='uses task id for prediction in some models')
344 |
345 |     parser.add_argument('--embedding-grad-clip', type=float, default=2.0,
346 |                         help='max gradient norm for clipping the embedding model')
347 |     parser.add_argument('--model-grad-clip', type=float, default=2.0,
348 |                         help='max gradient norm for clipping the model')
349 |
350 | args = parser.parse_args()
351 |
352 | if args.embedding_dims == 0:
353 | args.embedding_dims = args.hidden_sizes
354 |
355 |     # Create the train_dir folder if it doesn't exist
356 | if not os.path.exists('./train_dir'):
357 | os.makedirs('./train_dir')
358 |
359 |     # Cap num_sample_embedding at the number of sampled batches
360 | args.num_sample_embedding = min(
361 | args.num_sample_embedding, args.num_batches)
362 |
363 | # Device
364 | args.device = torch.device(
365 | args.device if torch.cuda.is_available() else 'cpu')
366 |
367 | main(args)
368 |
--------------------------------------------------------------------------------
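Each run writes its full configuration (including the git hash) to `<save_folder>/{run_name}_config.json`, which makes runs easy to audit and reproduce. A small sketch of reading one back; the folder name is just an example taken from the README commands:

```python
import json

# Path assumes a prior training run with --output-folder 2mods-mmaml-5steps.
with open('./train_dir/2mods-mmaml-5steps/train_config.json') as f:
    config = json.load(f)

print(config['dataset'], config['fast_lr'], config['git_hash'])
```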
/maml/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vuoristo/MMAML-Regression/1a8bea4d60461d8814e3c9f91427ed56378716ce/maml/__init__.py
--------------------------------------------------------------------------------
/maml/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vuoristo/MMAML-Regression/1a8bea4d60461d8814e3c9f91427ed56378716ce/maml/datasets/__init__.py
--------------------------------------------------------------------------------
/maml/datasets/metadataset.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, namedtuple
2 |
3 | Task = namedtuple('Task', ['x', 'y', 'task_info'])
4 |
--------------------------------------------------------------------------------
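All datasets in `maml/datasets` yield meta-batches as lists of these `Task` tuples. A minimal construction example (the tensor values here are made up for illustration):

```python
import torch
from maml.datasets.metadataset import Task

# x: inputs, y: targets, task_info: a dict describing the generating function.
task = Task(x=torch.randn(10, 1), y=torch.randn(10, 1),
            task_info={'task_id': 0, 'amp': 2.5, 'phase': 1.1})
print(task.x.shape, task.task_info['task_id'])
```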
/maml/datasets/simple_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from maml.datasets.metadataset import Task
5 |
6 |
7 | def generate_sinusoid_batch(amp_range, phase_range, input_range, num_samples,
8 | batch_size, oracle, bias=0):
9 | amp = np.random.uniform(amp_range[0], amp_range[1], [batch_size])
10 | phase = np.random.uniform(phase_range[0], phase_range[1], [batch_size])
11 | outputs = np.zeros([batch_size, num_samples, 1])
12 | inputs = np.zeros([batch_size, num_samples, 1])
13 | for i in range(batch_size):
14 | inputs[i] = np.random.uniform(input_range[0], input_range[1],
15 | [num_samples, 1])
16 | outputs[i] = amp[i] * np.sin(inputs[i] - phase[i]) + bias
17 |
18 | if oracle:
19 | amps = np.ones_like(inputs) * amp.reshape(-1, 1, 1)
20 | phases = np.ones_like(inputs) * phase.reshape(-1, 1, 1)
21 | inputs = np.concatenate((inputs, amps, phases), axis=2)
22 |
23 | return inputs, outputs, amp, phase
24 |
25 |
26 | def generate_linear_batch(slope_range, intersect_range, input_range,
27 | num_samples, batch_size, oracle):
28 | slope = np.random.uniform(slope_range[0], slope_range[1], [batch_size])
29 | intersect = np.random.uniform(intersect_range[0], intersect_range[1],
30 | [batch_size])
31 | outputs = np.zeros([batch_size, num_samples, 1])
32 | inputs = np.zeros([batch_size, num_samples, 1])
33 | for i in range(batch_size):
34 | inputs[i] = np.random.uniform(input_range[0], input_range[1],
35 | [num_samples, 1])
36 | outputs[i] = inputs[i] * slope[i] + intersect[i]
37 |
38 | if oracle:
39 | slopes = np.ones_like(inputs) * slope.reshape(-1, 1, 1)
40 | intersects = np.ones_like(inputs) * intersect.reshape(-1, 1, 1)
41 | inputs = np.concatenate((inputs, slopes, intersects), axis=2)
42 |
43 | return inputs, outputs, slope, intersect
44 |
45 |
46 | class SimpleFunctionDataset(object):
47 | def __init__(self, num_total_batches=200000, num_samples_per_function=5,
48 | num_val_samples=5, meta_batch_size=75, oracle=False,
49 | train=True, device='cpu', dtype=torch.float, **kwargs):
50 | self._num_total_batches = num_total_batches
51 | self._num_samples_per_function = num_samples_per_function
52 | self._num_val_samples = num_val_samples
53 | self._num_total_samples = num_samples_per_function
54 | self._meta_batch_size = meta_batch_size
55 | self._oracle = oracle
56 | self._train = train
57 | self._device = device
58 | self._dtype = dtype
59 |
60 | def _generate_batch(self):
61 | raise NotImplementedError('Subclass should implement _generate_batch')
62 |
63 | def __iter__(self):
64 | for batch in range(self._num_total_batches):
65 | inputs, outputs, infos = self._generate_batch()
66 |
67 | train_tasks = []
68 | val_tasks = []
69 | for task in range(self._meta_batch_size):
70 | task_inputs = torch.tensor(
71 | inputs[task], device=self._device, dtype=self._dtype)
72 | task_outputs = torch.tensor(
73 | outputs[task], device=self._device, dtype=self._dtype)
74 | task_infos = infos[task]
75 | train_task = Task(task_inputs[self._num_val_samples:],
76 | task_outputs[self._num_val_samples:],
77 | task_infos)
78 | train_tasks.append(train_task)
79 | val_task = Task(task_inputs[:self._num_val_samples],
80 | task_outputs[:self._num_val_samples],
81 | task_infos)
82 | val_tasks.append(val_task)
83 | yield train_tasks, val_tasks
84 |
85 | class BiasedSinusoidMetaDataset(SimpleFunctionDataset):
86 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi],
87 | input_range=[-5.0, 5.0], bias=0, **kwargs):
88 | super(BiasedSinusoidMetaDataset, self).__init__(**kwargs)
89 | self._amp_range = amp_range
90 | self._phase_range = phase_range
91 | self._input_range = input_range
92 | self._bias = bias
93 |
94 | if self._oracle:
95 | self.input_size = 3
96 | else:
97 | self.input_size = 1
98 | self.output_size = 1
99 |
100 | def _generate_batch(self):
101 | inputs, outputs, amp, phase = generate_sinusoid_batch(
102 | amp_range=self._amp_range, phase_range=self._phase_range,
103 | input_range=self._input_range,
104 | num_samples=self._num_total_samples,
105 | batch_size=self._meta_batch_size, oracle=self._oracle, bias=self._bias)
106 | task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]}
107 | for i in range(len(amp))]
108 | return inputs, outputs, task_infos
109 |
110 | class SinusoidMetaDataset(SimpleFunctionDataset):
111 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi],
112 | input_range=[-5.0, 5.0], **kwargs):
113 | super(SinusoidMetaDataset, self).__init__(**kwargs)
114 | self._amp_range = amp_range
115 | self._phase_range = phase_range
116 | self._input_range = input_range
117 |
118 | if self._oracle:
119 | self.input_size = 3
120 | else:
121 | self.input_size = 1
122 | self.output_size = 1
123 |
124 | def _generate_batch(self):
125 | inputs, outputs, amp, phase = generate_sinusoid_batch(
126 | amp_range=self._amp_range, phase_range=self._phase_range,
127 | input_range=self._input_range,
128 | num_samples=self._num_total_samples,
129 | batch_size=self._meta_batch_size, oracle=self._oracle)
130 | task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]}
131 | for i in range(len(amp))]
132 | return inputs, outputs, task_infos
133 |
134 |
135 | class LinearMetaDataset(SimpleFunctionDataset):
136 | def __init__(self, slope_range=[-3.0, 3.0], intersect_range=[-3, 3],
137 | input_range=[-5.0, 5.0], **kwargs):
138 | super(LinearMetaDataset, self).__init__(**kwargs)
139 | self._slope_range = slope_range
140 | self._intersect_range = intersect_range
141 | self._input_range = input_range
142 |
143 | if self._oracle:
144 | self.input_size = 3
145 | else:
146 | self.input_size = 1
147 | self.output_size = 1
148 |
149 | def _generate_batch(self):
150 | inputs, outputs, slope, intersect = generate_linear_batch(
151 | slope_range=self._slope_range,
152 | intersect_range=self._intersect_range,
153 | input_range=self._input_range,
154 | num_samples=self._num_total_samples,
155 | batch_size=self._meta_batch_size, oracle=self._oracle)
156 | task_infos = [{'task_id': 0, 'slope': slope[i], 'intersect': intersect[i]}
157 | for i in range(len(slope))]
158 | return inputs, outputs, task_infos
159 |
160 | def generate_quadratic_batch(center_range, bias_range, sign_range, input_range,
161 | num_samples, batch_size, oracle):
162 | center = np.random.uniform(center_range[0], center_range[1], [batch_size])
163 | bias = np.random.uniform(bias_range[0], bias_range[1], [batch_size])
164 |
165 |     # coefficient magnitude; a random sign makes the parabola open up or down
166 | alpha = np.random.uniform(sign_range[0], sign_range[1], [batch_size])
167 | sign = np.random.randint(2, size=[batch_size])
168 | sign[sign == 0] = -1
169 | sign = alpha * sign
170 |
171 | outputs = np.zeros([batch_size, num_samples, 1])
172 | inputs = np.zeros([batch_size, num_samples, 1])
173 | for i in range(batch_size):
174 | inputs[i] = np.random.uniform(input_range[0], input_range[1],
175 | [num_samples, 1])
176 | outputs[i] = sign[i] * (inputs[i] - center[i])**2 + bias[i]
177 |
178 | if oracle:
179 | centers = np.ones_like(inputs) * center.reshape(-1, 1, 1)
180 | biases = np.ones_like(inputs) * bias.reshape(-1, 1, 1)
181 | inputs = np.concatenate((inputs, centers, biases), axis=2)
182 |
183 | return inputs, outputs, sign, center, bias
184 |
185 |
186 | class QuadraticMetaDataset(SimpleFunctionDataset):
187 | """ Quadratic function like: sign * (x - center)^2 + bias
188 | """
189 | def __init__(self, center_range=[-3.0, 3.0], bias_range=[-3, 3], sign_range=[0.02, 0.15],
190 | input_range=[-5.0, 5.0], **kwargs):
191 | super(QuadraticMetaDataset, self).__init__(**kwargs)
192 | self._center_range = center_range
193 | self._bias_range = bias_range
194 | self._input_range = input_range
195 | self._sign_range = sign_range
196 |
197 | if self._oracle:
198 | self.input_size = 3
199 | else:
200 | self.input_size = 1
201 | self.output_size = 1
202 |
203 | def _generate_batch(self):
204 | inputs, outputs, sign, center, bias = generate_quadratic_batch(
205 | center_range=self._center_range,
206 | bias_range=self._bias_range,
207 | sign_range=self._sign_range,
208 | input_range=self._input_range,
209 | num_samples=self._num_total_samples,
210 | batch_size=self._meta_batch_size, oracle=self._oracle)
211 | task_infos = [{'task_id': 2, 'sign': sign[i], 'center': center[i], 'bias': bias[i]}
212 | for i in range(len(sign))]
213 | return inputs, outputs, task_infos
214 |
215 |
216 | class MixedFunctionsMetaDataset(SimpleFunctionDataset):
217 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi],
218 | input_range=[-5.0, 5.0], slope_range=[-3.0, 3.0],
219 | intersect_range=[-3.0, 3.0], task_oracle=False,
220 | noise_std=0, **kwargs):
221 | super(MixedFunctionsMetaDataset, self).__init__(**kwargs)
222 | self._amp_range = amp_range
223 | self._phase_range = phase_range
224 | self._slope_range = slope_range
225 | self._intersect_range = intersect_range
226 | self._input_range = input_range
227 | self._task_oracle = task_oracle
228 | self._noise_std = noise_std
229 |
230 | if not self._oracle:
231 | if not self._task_oracle:
232 | self.input_size = 1
233 | else:
234 | self.input_size = 2
235 | else:
236 | if not self._task_oracle:
237 | self.input_size = 3
238 | else:
239 | self.input_size = 4
240 |
241 | self.output_size = 1
242 | self.num_tasks = 2
243 |
244 | def _generate_batch(self):
245 | half_batch_size = self._meta_batch_size // 2
246 | sin_inputs, sin_outputs, amp, phase = generate_sinusoid_batch(
247 | amp_range=self._amp_range, phase_range=self._phase_range,
248 | input_range=self._input_range,
249 | num_samples=self._num_total_samples,
250 | batch_size=half_batch_size, oracle=self._oracle)
251 | sin_task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]}
252 | for i in range(len(amp))]
253 | if self._task_oracle:
254 | sin_inputs = np.concatenate(
255 | (sin_inputs, np.zeros(sin_inputs.shape[:2] + (1,))), axis=2)
256 |
257 | lin_inputs, lin_outputs, slope, intersect = generate_linear_batch(
258 | slope_range=self._slope_range,
259 | intersect_range=self._intersect_range,
260 | input_range=self._input_range,
261 | num_samples=self._num_total_samples,
262 | batch_size=half_batch_size, oracle=self._oracle)
263 | lin_task_infos = [{'task_id': 1, 'slope': slope[i], 'intersect': intersect[i]}
264 | for i in range(len(slope))]
265 | if self._task_oracle:
266 | lin_inputs = np.concatenate(
267 | (lin_inputs, np.ones(lin_inputs.shape[:2] + (1,))), axis=2)
268 | inputs = np.concatenate((sin_inputs, lin_inputs))
269 | outputs = np.concatenate((sin_outputs, lin_outputs))
270 |
271 | if self._noise_std > 0:
272 | outputs = outputs + np.random.normal(scale=self._noise_std, size=outputs.shape)
273 | task_infos = sin_task_infos + lin_task_infos
274 | return inputs, outputs, task_infos
275 |
276 | class ManyFunctionsMetaDataset(SimpleFunctionDataset):
277 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi],
278 | input_range=[-5.0, 5.0], slope_range=[-3.0, 3.0],
279 | intersect_range=[-3.0, 3.0], center_range=[-3.0, 3.0],
280 | bias_range=[-3.0, 3.0], sign_range=[0.02, 0.15], task_oracle=False,
281 | noise_std=0, **kwargs):
282 | super(ManyFunctionsMetaDataset, self).__init__(**kwargs)
283 | self._amp_range = amp_range
284 | self._phase_range = phase_range
285 | self._slope_range = slope_range
286 | self._intersect_range = intersect_range
287 | self._input_range = input_range
288 | self._center_range = center_range
289 | self._bias_range = bias_range
290 | self._sign_range = sign_range
291 | self._task_oracle = task_oracle
292 | self._noise_std = noise_std
293 |
294 | if not self._oracle:
295 | if not self._task_oracle:
296 | self.input_size = 1
297 | else:
298 | self.input_size = 2
299 | else:
300 | if not self._task_oracle:
301 | self.input_size = 3
302 | else:
303 | self.input_size = 4
304 |
305 | self.output_size = 1
306 |         self.num_tasks = 3
307 |
308 | def _generate_batch(self):
309 |         task_batch_size = self._meta_batch_size // 3  # meta_batch_size should be divisible by 3
310 | sin_inputs, sin_outputs, amp, phase = generate_sinusoid_batch(
311 | amp_range=self._amp_range, phase_range=self._phase_range,
312 | input_range=self._input_range,
313 | num_samples=self._num_total_samples,
314 |             batch_size=task_batch_size, oracle=self._oracle)
315 | sin_task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]}
316 | for i in range(len(amp))]
317 | if self._task_oracle:
318 | sin_inputs = np.concatenate(
319 | (sin_inputs, np.zeros(sin_inputs.shape[:2] + (1,))), axis=2)
320 |
321 | lin_inputs, lin_outputs, slope, intersect = generate_linear_batch(
322 | slope_range=self._slope_range,
323 | intersect_range=self._intersect_range,
324 | input_range=self._input_range,
325 | num_samples=self._num_total_samples,
326 |             batch_size=task_batch_size, oracle=self._oracle)
327 | lin_task_infos = [{'task_id': 1, 'slope': slope[i], 'intersect': intersect[i]}
328 | for i in range(len(slope))]
329 | if self._task_oracle:
330 | lin_inputs = np.concatenate(
331 | (lin_inputs, np.ones(lin_inputs.shape[:2] + (1,))), axis=2)
332 |
333 | qua_inputs, qua_outputs, sign, center, bias = generate_quadratic_batch(
334 | center_range=self._center_range,
335 | bias_range=self._bias_range,
336 | sign_range=self._sign_range,
337 | input_range=self._input_range,
338 | num_samples=self._num_total_samples,
339 |             batch_size=task_batch_size, oracle=self._oracle)
340 | qua_task_infos = [{'task_id': 2, 'sign': sign[i], 'center': center[i], 'bias': bias[i]}
341 | for i in range(len(sign))]
342 |
343 | if self._task_oracle:
344 | qua_inputs = np.concatenate(
345 | (qua_inputs, np.ones(qua_inputs.shape[:2] + (1,))), axis=2)
346 |
347 | inputs = np.concatenate((sin_inputs, lin_inputs, qua_inputs))
348 | outputs = np.concatenate((sin_outputs, lin_outputs, qua_outputs))
349 |
350 | if self._noise_std > 0:
351 | outputs = outputs + np.random.normal(scale=self._noise_std, size=outputs.shape)
352 | task_infos = sin_task_infos + lin_task_infos + qua_task_infos
353 | return inputs, outputs, task_infos
354 |
355 |
356 | class MultiSinusoidsMetaDataset(SimpleFunctionDataset):
357 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi], biases=(-5, 5),
358 | input_range=[-5.0, 5.0], task_oracle=False,
359 | noise_std=0, **kwargs):
360 | super(MultiSinusoidsMetaDataset, self).__init__(**kwargs)
361 | self._amp_range = amp_range
362 | self._phase_range = phase_range
363 | self._input_range = input_range
364 | self._task_oracle = task_oracle
365 | self._noise_std = noise_std
366 | self._biases = biases
367 |
368 | if not self._oracle:
369 | if not self._task_oracle:
370 | self.input_size = 1
371 | else:
372 | self.input_size = 2
373 | else:
374 | if not self._task_oracle:
375 | self.input_size = 3
376 | else:
377 | self.input_size = 4
378 |
379 | self.output_size = 1
380 | self.num_tasks = 2
381 |
382 | def _generate_batch(self):
383 | half_batch_size = self._meta_batch_size // 2
384 | sin1_inputs, sin1_outputs, amp1, phase1 = generate_sinusoid_batch(
385 | amp_range=self._amp_range, phase_range=self._phase_range,
386 | input_range=self._input_range,
387 | num_samples=self._num_total_samples,
388 | batch_size=half_batch_size, oracle=self._oracle, bias=self._biases[0])
389 | sin1_task_infos = [{'task_id': 0, 'amp': amp1[i], 'phase': phase1[i]}
390 | for i in range(len(amp1))]
391 | if self._task_oracle:
392 | sin1_inputs = np.concatenate(
393 | (sin1_inputs, np.zeros(sin1_inputs.shape[:2] + (1,))), axis=2)
394 |
395 | sin2_inputs, sin2_outputs, amp2, phase2 = generate_sinusoid_batch(
396 | amp_range=self._amp_range, phase_range=self._phase_range,
397 | input_range=self._input_range,
398 | num_samples=self._num_total_samples,
399 |             batch_size=half_batch_size, oracle=self._oracle, bias=self._biases[1])
400 | sin2_task_infos = [{'task_id': 1, 'amp': amp2[i], 'phase': phase2[i]}
401 | for i in range(len(amp2))]
402 | if self._task_oracle:
403 | sin2_inputs = np.concatenate(
404 | (sin2_inputs, np.zeros(sin2_inputs.shape[:2] + (1,))), axis=2)
405 |
406 | inputs = np.concatenate((sin1_inputs, sin2_inputs))
407 |         outputs = np.concatenate((sin1_outputs, sin2_outputs))
408 |
409 | if self._noise_std > 0:
410 | outputs = outputs + np.random.normal(scale=self._noise_std,
411 | size=outputs.shape)
412 | task_infos = sin1_task_infos + sin2_task_infos
413 | return inputs, outputs, task_infos
414 |
415 | def generate_tanh_batch(center_range, bias_range, slope_range, input_range,
416 | num_samples, batch_size, oracle):
417 | center = np.random.uniform(center_range[0], center_range[1], [batch_size])
418 | bias = np.random.uniform(bias_range[0], bias_range[1], [batch_size])
419 |
420 |     # slope of the tanh
421 | slope = np.random.uniform(slope_range[0], slope_range[1], [batch_size])
422 |
423 | outputs = np.zeros([batch_size, num_samples, 1])
424 | inputs = np.zeros([batch_size, num_samples, 1])
425 | for i in range(batch_size):
426 | inputs[i] = np.random.uniform(input_range[0], input_range[1],
427 | [num_samples, 1])
428 | outputs[i] = slope[i] * np.tanh(inputs[i] - center[i]) + bias[i]
429 |
430 | if oracle:
431 | centers = np.ones_like(inputs) * center.reshape(-1, 1, 1)
432 | biases = np.ones_like(inputs) * bias.reshape(-1, 1, 1)
433 | inputs = np.concatenate((inputs, centers, biases), axis=2)
434 |
435 | return inputs, outputs, slope, center, bias
436 |
437 | def generate_abs_batch(slope_range, center_range, bias_range, input_range,
438 | num_samples, batch_size, oracle):
439 | slope = np.random.uniform(slope_range[0], slope_range[1], [batch_size])
440 | bias = np.random.uniform(bias_range[0], bias_range[1], [batch_size])
441 | center = np.random.uniform(center_range[0], center_range[1], [batch_size])
442 |
443 | outputs = np.zeros([batch_size, num_samples, 1])
444 | inputs = np.zeros([batch_size, num_samples, 1])
445 | for i in range(batch_size):
446 | inputs[i] = np.random.uniform(input_range[0], input_range[1],
447 | [num_samples, 1])
448 | outputs[i] = np.abs(inputs[i] - center[i]) * slope[i] + bias[i]
449 |
450 | if oracle:
451 | slopes = np.ones_like(inputs) * slope.reshape(-1, 1, 1)
452 |         centers = np.ones_like(inputs) * center.reshape(-1, 1, 1)
453 |         inputs = np.concatenate((inputs, slopes, centers), axis=2)
454 |
455 | return inputs, outputs, slope, center, bias
456 |
457 |
458 | class FiveFunctionsMetaDataset(SimpleFunctionDataset):
459 | def __init__(self, amp_range=[0.1, 5.0], phase_range=[0, np.pi],
460 | input_range=[-5.0, 5.0], slope_range=[-3.0, 3.0],
461 | intersect_range=[-3.0, 3.0], center_range=[-3.0, 3.0],
462 | bias_range=[-3.0, 3.0], sign_range=[0.02, 0.15], task_oracle=False,
463 | noise_std=0, **kwargs):
464 | super(FiveFunctionsMetaDataset, self).__init__(**kwargs)
465 | self._amp_range = amp_range
466 | self._phase_range = phase_range
467 | self._slope_range = slope_range
468 | self._intersect_range = intersect_range
469 | self._input_range = input_range
470 | self._center_range = center_range
471 | self._bias_range = bias_range
472 | self._sign_range = sign_range
473 | self._task_oracle = task_oracle
474 | self._noise_std = noise_std
475 |
476 |         assert not self._task_oracle
477 | if not self._oracle:
478 | if not self._task_oracle:
479 | self.input_size = 1
480 | else:
481 | self.input_size = 2
482 | else:
483 | if not self._task_oracle:
484 | self.input_size = 3
485 | else:
486 | self.input_size = 4
487 |
488 | self.output_size = 1
489 |         self.num_tasks = 5
490 |
491 | def _generate_batch(self):
492 |         assert self._meta_batch_size % 5 == 0, 'meta_batch_size must be divisible by 5.'
493 |         task_batch_size = self._meta_batch_size // 5
494 | sin_inputs, sin_outputs, amp, phase = generate_sinusoid_batch(
495 | amp_range=self._amp_range, phase_range=self._phase_range,
496 | input_range=self._input_range,
497 | num_samples=self._num_total_samples,
498 |             batch_size=task_batch_size, oracle=self._oracle)
499 | sin_task_infos = [{'task_id': 0, 'amp': amp[i], 'phase': phase[i]}
500 | for i in range(len(amp))]
501 |
502 | lin_inputs, lin_outputs, slope, intersect = generate_linear_batch(
503 | slope_range=self._slope_range,
504 | intersect_range=self._intersect_range,
505 | input_range=self._input_range,
506 | num_samples=self._num_total_samples,
507 |             batch_size=task_batch_size, oracle=self._oracle)
508 | lin_task_infos = [{'task_id': 1, 'slope': slope[i], 'intersect': intersect[i]}
509 | for i in range(len(slope))]
510 |
511 | qua_inputs, qua_outputs, sign, center, bias = generate_quadratic_batch(
512 | center_range=self._center_range,
513 | bias_range=self._bias_range,
514 | sign_range=self._sign_range,
515 | input_range=self._input_range,
516 | num_samples=self._num_total_samples,
517 |             batch_size=task_batch_size, oracle=self._oracle)
518 | qua_task_infos = [{'task_id': 2, 'sign': sign[i], 'center': center[i], 'bias': bias[i]}
519 | for i in range(len(sign))]
520 |
521 | tanh_inputs, tanh_outputs, slope, center, bias = generate_tanh_batch(
522 | center_range=self._center_range,
523 | bias_range=self._bias_range,
524 | slope_range=self._slope_range,
525 | input_range=self._input_range,
526 | num_samples=self._num_total_samples,
527 |             batch_size=task_batch_size, oracle=self._oracle)
528 | tanh_task_infos = [{'task_id': 3, 'slope': slope[i], 'center': center[i], 'bias': bias[i]}
529 |                            for i in range(len(slope))]
530 |
531 | abs_inputs, abs_outputs, slope, center, bias = generate_abs_batch(
532 | slope_range=self._slope_range,
533 | center_range=self._center_range,
534 | bias_range=self._bias_range,
535 | input_range=self._input_range,
536 | num_samples=self._num_total_samples,
537 |             batch_size=task_batch_size, oracle=self._oracle)
538 | abs_task_infos = [{'task_id': 4, 'slope': slope[i], 'center': center[i], 'bias': bias[i]}
539 |                           for i in range(len(slope))]
540 |
541 | inputs = np.concatenate((sin_inputs, lin_inputs, qua_inputs, tanh_inputs, abs_inputs))
542 | outputs = np.concatenate((sin_outputs, lin_outputs, qua_outputs, tanh_outputs, abs_outputs))
543 |
544 | if self._noise_std > 0:
545 | outputs = outputs + np.random.normal(scale=self._noise_std, size=outputs.shape)
546 | task_infos = sin_task_infos + lin_task_infos + qua_task_infos + tanh_task_infos + abs_task_infos
547 | return inputs, outputs, task_infos
548 |
--------------------------------------------------------------------------------
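A short sketch of consuming one of these datasets directly, using constructor arguments taken from `main.py`; each iteration yields one meta-batch of train/val `Task` pairs, with the first `num_val_samples` points of every function held out for validation:

```python
from maml.datasets.simple_functions import MixedFunctionsMetaDataset

dataset = MixedFunctionsMetaDataset(
    num_total_batches=2, num_samples_per_function=10, num_val_samples=5,
    meta_batch_size=4, noise_std=0.3, device='cpu')

for train_tasks, val_tasks in dataset:
    # 10 samples per function: the first 5 go to validation, the rest to training.
    print(len(train_tasks), train_tasks[0].x.shape, val_tasks[0].x.shape)
```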
/maml/metalearner.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch.nn.utils.clip_grad import clip_grad_norm_, clip_grad_value_
4 | from maml.utils import accuracy
5 |
6 |
7 | def get_grad_norm(parameters, norm_type=2):
8 | if isinstance(parameters, torch.Tensor):
9 | parameters = [parameters]
10 | parameters = list(filter(lambda p: p.grad is not None, parameters))
11 | norm_type = float(norm_type)
12 | total_norm = 0
13 | for p in parameters:
14 | param_norm = p.grad.data.norm(norm_type)
15 | total_norm += param_norm.item() ** norm_type
16 | total_norm = total_norm ** (1. / norm_type)
17 |
18 | return total_norm
19 |
20 |
21 | class MetaLearner(object):
22 | def __init__(self, model, embedding_model, optimizers, fast_lr, loss_func,
23 | first_order, num_updates, inner_loop_grad_clip,
24 | collect_accuracies, device, embedding_grad_clip=0,
25 | model_grad_clip=0):
26 | self._model = model
27 | self._embedding_model = embedding_model
28 | self._fast_lr = fast_lr
29 | self._optimizers = optimizers
30 | self._loss_func = loss_func
31 | self._first_order = first_order
32 | self._num_updates = num_updates
33 | self._inner_loop_grad_clip = inner_loop_grad_clip
34 | self._collect_accuracies = collect_accuracies
35 | self._device = device
36 | self._embedding_grad_clip = embedding_grad_clip
37 | self._model_grad_clip = model_grad_clip
38 | self._grads_mean = []
39 |
40 | self.to(device)
41 |
42 | self._reset_measurements()
43 |
44 | def _reset_measurements(self):
45 | self._count_iters = 0.0
46 | self._cum_loss = 0.0
47 | self._cum_accuracy = 0.0
48 |
49 | def _update_measurements(self, task, loss, preds):
50 | self._count_iters += 1.0
51 | self._cum_loss += loss.data.cpu().numpy()
52 | if self._collect_accuracies:
53 | self._cum_accuracy += accuracy(
54 | preds, task.y).data.cpu().numpy()
55 |
56 | def _pop_measurements(self):
57 | measurements = {}
58 | loss = self._cum_loss / (self._count_iters + 1e-32)
59 | measurements['loss'] = loss
60 | if self._collect_accuracies:
61 | accuracy = self._cum_accuracy / (self._count_iters + 1e-32)
62 | measurements['accuracy'] = accuracy
63 | self._reset_measurements()
64 | return measurements
65 |
66 | def measure(self, tasks, train_tasks=None, adapted_params_list=None,
67 | embeddings_list=None):
68 | """Measures performance on tasks. Either train_tasks has to be a list
69 |         of training tasks for computing embeddings, or adapted_params_list and
70 | embeddings_list have to contain adapted_params and embeddings"""
71 | if adapted_params_list is None:
72 | adapted_params_list = [None] * len(tasks)
73 | if embeddings_list is None:
74 | embeddings_list = [None] * len(tasks)
75 | for i in range(len(tasks)):
76 | params = adapted_params_list[i]
77 | if params is None:
78 | params = self._model.param_dict
79 | embeddings = embeddings_list[i]
80 | task = tasks[i]
81 | preds = self._model(task, params=params, embeddings=embeddings)
82 | loss = self._loss_func(preds, task.y)
83 | self._update_measurements(task, loss, preds)
84 |
85 | measurements = self._pop_measurements()
86 | return measurements
87 |
88 | def update_params(self, loss, params):
89 |         """Applies one step of gradient descent on the loss function `loss`
90 |         with step size `self._fast_lr` and returns the updated parameters.
91 | """
92 | create_graph = not self._first_order
93 | grads = torch.autograd.grad(loss, params.values(),
94 | create_graph=create_graph, allow_unused=True)
95 | for (name, param), grad in zip(params.items(), grads):
96 | if self._inner_loop_grad_clip > 0 and grad is not None:
97 | grad = grad.clamp(min=-self._inner_loop_grad_clip,
98 | max=self._inner_loop_grad_clip)
99 | if grad is not None:
100 | params[name] = param - self._fast_lr * grad
101 |
102 | return params
103 |
104 | def adapt(self, train_tasks, return_task_embedding=False):
105 | adapted_params = []
106 | embeddings_list = []
107 | task_embeddings_list = []
108 |
109 | for task in train_tasks:
110 | params = self._model.param_dict
111 | embeddings = None
112 | if self._embedding_model:
113 | if return_task_embedding:
114 | embeddings, task_embedding = self._embedding_model(
115 | task, return_task_embedding=True)
116 | task_embeddings_list.append(task_embedding)
117 | else:
118 | embeddings = self._embedding_model(
119 | task, return_task_embedding=False)
120 | for i in range(self._num_updates):
121 | preds = self._model(task, params=params, embeddings=embeddings)
122 | loss = self._loss_func(preds, task.y)
123 | params = self.update_params(loss, params=params)
124 | if i == 0:
125 | self._update_measurements(task, loss, preds)
126 | adapted_params.append(params)
127 | embeddings_list.append(embeddings)
128 |
129 | measurements = self._pop_measurements()
130 | if return_task_embedding:
131 | return measurements, adapted_params, embeddings_list, task_embeddings_list
132 | else:
133 | return measurements, adapted_params, embeddings_list
134 |
135 | def step(self, adapted_params_list, embeddings_list, val_tasks,
136 | is_training):
137 | for optimizer in self._optimizers:
138 | optimizer.zero_grad()
139 | post_update_losses = []
140 |
141 | for adapted_params, embeddings, task in zip(
142 | adapted_params_list, embeddings_list, val_tasks):
143 | preds = self._model(task, params=adapted_params,
144 | embeddings=embeddings)
145 | loss = self._loss_func(preds, task.y)
146 | post_update_losses.append(loss)
147 | self._update_measurements(task, loss, preds)
148 |
149 | mean_loss = torch.mean(torch.stack(post_update_losses))
150 | if is_training:
151 | self._optimizers[0].zero_grad()
152 | if len(self._optimizers) > 1:
153 | self._optimizers[1].zero_grad()
154 |
155 | mean_loss.backward()
156 | if len(self._optimizers) > 1:
157 | if self._embedding_grad_clip > 0:
158 | _grad_norm = clip_grad_norm_(
159 | self._embedding_model.parameters(), self._embedding_grad_clip)
160 | else:
161 | _grad_norm = get_grad_norm(
162 | self._embedding_model.parameters())
163 |                 # record the embedding-network gradient norm for logging
164 | self._grads_mean.append(_grad_norm)
165 | self._optimizers[1].step()
166 |
167 | if self._model_grad_clip > 0:
168 | _grad_norm = clip_grad_norm_(
169 | self._model.parameters(), self._model_grad_clip)
170 | self._optimizers[0].step()
171 |
172 | measurements = self._pop_measurements()
173 | return measurements
174 |
175 | def to(self, device, **kwargs):
176 | self._device = device
177 | self._model.to(device, **kwargs)
178 | if self._embedding_model:
179 | self._embedding_model.to(device, **kwargs)
180 |
181 | def state_dict(self):
182 | state = {
183 | 'model_state_dict': self._model.state_dict(),
184 | 'optimizers': [optimizer.state_dict() for optimizer in self._optimizers]
185 | }
186 | if self._embedding_model:
187 | state.update(
188 | {'embedding_model_state_dict':
189 | self._embedding_model.state_dict()})
190 | return state
191 |
--------------------------------------------------------------------------------
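
The inner loop above is the heart of MAML: `update_params` differentiates the task loss with respect to the current parameter dictionary and takes one explicit SGD step, keeping the computation graph unless first-order mode is enabled. A minimal self-contained sketch of that pattern on a toy linear model (all names and data below are made up for illustration, not part of the repo):

    from collections import OrderedDict

    import torch
    import torch.nn.functional as F

    # toy "model": a weight and a bias tracked in an OrderedDict,
    # mirroring Model.param_dict
    params = OrderedDict(w=torch.randn(1, 1, requires_grad=True),
                         b=torch.zeros(1, requires_grad=True))
    x, y = torch.randn(8, 1), torch.randn(8, 1)
    fast_lr, first_order = 0.01, False

    preds = x @ params['w'] + params['b']
    loss = F.mse_loss(preds, y)
    # create_graph keeps this step differentiable so the outer loop can
    # backpropagate through it (second-order MAML); first order drops it
    grads = torch.autograd.grad(loss, list(params.values()),
                                create_graph=not first_order)
    adapted = OrderedDict((name, p - fast_lr * g)
                          for (name, p), g in zip(params.items(), grads))

Evaluating the model under `adapted` and backpropagating that loss into the original `params` is exactly what `step` does with the outer-loop optimizers above.
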
/maml/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vuoristo/MMAML-Regression/1a8bea4d60461d8814e3c9f91427ed56378716ce/maml/models/__init__.py
--------------------------------------------------------------------------------
/maml/models/fully_connected.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from maml.models.model import Model
7 |
8 |
9 | def weight_init(module):
10 | if isinstance(module, torch.nn.Linear):
11 | torch.nn.init.normal_(module.weight, mean=0, std=0.01)
12 | module.bias.data.zero_()
13 |
14 |
15 | class FullyConnectedModel(Model):
16 | def __init__(self, input_size, output_size, hidden_sizes=(),
17 | nonlinearity=F.relu, disable_norm=False,
18 | bias_transformation_size=0):
19 | super(FullyConnectedModel, self).__init__()
20 | self.hidden_sizes = hidden_sizes
21 | self.nonlinearity = nonlinearity
22 | self.num_layers = len(hidden_sizes) + 1
23 | self.disable_norm = disable_norm
24 | self.bias_transformation_size = bias_transformation_size
25 |
26 | if bias_transformation_size > 0:
27 | input_size = input_size + bias_transformation_size
28 | self.bias_transformation = torch.nn.Parameter(
29 | torch.zeros(bias_transformation_size))
30 |
31 |         layer_sizes = [input_size] + list(hidden_sizes) + [output_size]
32 | for i in range(1, self.num_layers):
33 | self.add_module(
34 | 'layer{0}_linear'.format(i),
35 | torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
36 | if not self.disable_norm:
37 | self.add_module(
38 | 'layer{0}_bn'.format(i),
39 | torch.nn.BatchNorm1d(layer_sizes[i], momentum=0.001))
40 | self.add_module(
41 | 'output_linear',
42 | torch.nn.Linear(layer_sizes[self.num_layers - 1],
43 | layer_sizes[self.num_layers]))
44 | self.apply(weight_init)
45 |
46 |     def forward(self, task, params=None, training=True, embeddings=None):  # embeddings unused; kept for a uniform model interface
47 | if params is None:
48 | params = OrderedDict(self.named_parameters())
49 | x = task.x.view(task.x.size(0), -1)
50 |
51 | if self.bias_transformation_size > 0:
52 | x = torch.cat((x, params['bias_transformation'].expand(
53 | x.size(0), params['bias_transformation'].size(0))), dim=1)
54 |
55 | for key, module in self.named_modules():
56 | if 'linear' in key:
57 | x = F.linear(x, weight=params[key + '.weight'],
58 | bias=params[key + '.bias'])
59 | if self.disable_norm and 'output' not in key:
60 | x = self.nonlinearity(x)
61 | if 'bn' in key:
62 | x = F.batch_norm(x, weight=params[key + '.weight'],
63 | bias=params[key + '.bias'],
64 | running_mean=module.running_mean,
65 | running_var=module.running_var,
66 | training=training)
67 | x = self.nonlinearity(x)
68 | return x
69 |
70 |
71 | class MultiFullyConnectedModel(Model):
72 | def __init__(self, input_size, output_size, hidden_sizes=(),
73 | nonlinearity=F.relu, disable_norm=False, num_tasks=1,
74 | bias_transformation_size=0):
75 | super(MultiFullyConnectedModel, self).__init__()
76 | self.hidden_sizes = hidden_sizes
77 | self.nonlinearity = nonlinearity
78 | self.num_layers = len(hidden_sizes) + 1
79 | self.disable_norm = disable_norm
80 | self.bias_transformation_size = bias_transformation_size
81 | self.num_tasks = num_tasks
82 |
83 | if bias_transformation_size > 0:
84 | input_size = input_size + bias_transformation_size
85 | self.bias_transformation = torch.nn.Embedding(
86 | self.num_tasks, bias_transformation_size
87 | )
88 |
89 |         layer_sizes = [input_size] + list(hidden_sizes) + [output_size]
90 | for j in range(0, self.num_tasks):
91 | for i in range(1, self.num_layers):
92 | self.add_module(
93 | 'task{0}_layer{1}_linear'.format(j, i),
94 | torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
95 | if not self.disable_norm:
96 | self.add_module(
97 | 'task{0}_layer{1}_bn'.format(j, i),
98 | torch.nn.BatchNorm1d(layer_sizes[i], momentum=0.001))
99 | self.add_module(
100 | 'task{0}_output_linear'.format(j),
101 | torch.nn.Linear(layer_sizes[self.num_layers - 1],
102 | layer_sizes[self.num_layers]))
103 | self.apply(weight_init)
104 |
105 | def forward(self, task, params=None, training=True, embeddings=None):
106 | if params is None:
107 | params = OrderedDict(self.named_parameters())
108 | x = task.x.view(task.x.size(0), -1)
109 | task_id = task.task_info['task_id']
110 |
111 | if self.bias_transformation_size > 0:
112 | bias_trans = self.bias_transformation(
113 | torch.LongTensor([task_id]).to(x.device))
114 | x = torch.cat((
115 | x,
116 | bias_trans.expand(x.size(0), bias_trans.size(1))
117 | ), dim=1)
118 |
119 | for key, module in self.named_modules():
120 | if 'task{0}'.format(task_id) in key:
121 | if 'linear' in key:
122 | x = F.linear(x, weight=params[key + '.weight'],
123 | bias=params[key + '.bias'])
124 | if self.disable_norm and 'output' not in key:
125 | x = self.nonlinearity(x)
126 | if 'bn' in key:
127 | x = F.batch_norm(x, weight=params[key + '.weight'],
128 | bias=params[key + '.bias'],
129 | running_mean=module.running_mean,
130 | running_var=module.running_var,
131 | training=training)
132 | x = self.nonlinearity(x)
133 | return x
134 |
--------------------------------------------------------------------------------
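
The `forward` methods above are written functionally: parameters can be supplied as an explicit dictionary instead of being read from the module, which is what lets the meta-learner evaluate adapted copies without touching the registered weights. A hedged usage sketch, where `Task` is a hypothetical stand-in for the dataset's task object (assumed to expose `.x` and `.y` tensors):

    from collections import OrderedDict, namedtuple

    import torch

    from maml.models.fully_connected import FullyConnectedModel

    Task = namedtuple('Task', ['x', 'y'])
    task = Task(x=torch.randn(4, 1), y=torch.randn(4, 1))

    model = FullyConnectedModel(input_size=1, output_size=1,
                                hidden_sizes=[40, 40], disable_norm=True)
    params = OrderedDict(model.named_parameters())
    # with unmodified params the two calls agree; MAML passes adapted
    # copies of the dictionary instead
    out_default = model(task)
    out_explicit = model(task, params=params)
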
/maml/models/gated_net.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from maml.models.model import Model
7 |
8 |
9 | def weight_init(module):
10 | if isinstance(module, torch.nn.Linear):
11 | torch.nn.init.normal_(module.weight, mean=0, std=0.01)
12 | module.bias.data.zero_()
13 |
14 |
15 | class GatedNet(Model):
16 | def __init__(self, input_size, output_size, hidden_sizes=[40, 40],
17 | nonlinearity=F.relu, condition_type='sigmoid_gate', condition_order='low2high'):
18 | super(GatedNet, self).__init__()
19 | self._nonlinearity = nonlinearity
20 | self._condition_type = condition_type
21 | self._condition_order = condition_order
22 |
23 | self.num_layers = len(hidden_sizes) + 1
24 |
25 |         layer_sizes = [input_size] + list(hidden_sizes) + [output_size]
26 | for i in range(1, self.num_layers):
27 | self.add_module(
28 | 'layer{0}_linear'.format(i),
29 | torch.nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
30 | self.add_module(
31 | 'output_linear',
32 | torch.nn.Linear(layer_sizes[self.num_layers - 1],
33 | layer_sizes[self.num_layers]))
34 | self.apply(weight_init)
35 |
36 | def conditional_layer(self, x, embedding):
37 | if self._condition_type == 'sigmoid_gate':
38 |             x = x * torch.sigmoid(embedding).expand_as(x)
39 | elif self._condition_type == 'affine':
40 | gammas, betas = torch.split(embedding, x.size(1), dim=-1)
41 | gammas = gammas + torch.ones_like(gammas)
42 | x = x * gammas + betas
43 | elif self._condition_type == 'softmax':
44 |             x = x * F.softmax(embedding, dim=-1).expand_as(x)
45 | else:
46 | raise ValueError('Unrecognized conditional layer type {}'.format(
47 | self._condition_type))
48 | return x
49 |
50 | def forward(self, task, params=None, embeddings=None, training=True):
51 | if params is None:
52 | params = OrderedDict(self.named_parameters())
53 |
54 | if embeddings is not None:
55 | if self._condition_order == 'high2low': # High2Low
56 | embeddings = {'layer{}_linear'.format(len(params)-i): embedding
57 | for i, embedding in enumerate(embeddings[::-1])}
58 | elif self._condition_order == 'low2high': # Low2High
59 | embeddings = {'layer{}_linear'.format(i): embedding
60 | for i, embedding in enumerate(embeddings[::-1], start=1)}
61 | else:
62 | raise NotImplementedError(
63 |                     'Unsupported order for using conditional layers')
64 | x = task.x.view(task.x.size(0), -1)
65 |
66 | for key, module in self.named_modules():
67 | if 'linear' in key:
68 | x = F.linear(x, weight=params[key + '.weight'],
69 | bias=params[key + '.bias'])
70 | if 'output' not in key and embeddings is not None: # conditioning and nonlinearity
71 |                     if embeddings.get(key) is not None:
72 |                         x = self.conditional_layer(x, embeddings[key])
73 |
74 | x = self._nonlinearity(x)
75 |
76 | return x
77 |
--------------------------------------------------------------------------------
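
The three branches of `conditional_layer` are a sigmoid gate, a FiLM-style affine modulation, and a softmax reweighting of the activations. A short sketch spelling each one out on plain tensors (shapes chosen arbitrarily for illustration):

    import torch

    x = torch.randn(5, 40)            # a batch of layer activations

    # sigmoid_gate: per-unit multiplicative gate in (0, 1)
    g = torch.randn(40)
    gated = x * torch.sigmoid(g).expand_as(x)

    # affine: FiLM-style scale-and-shift; the +1 keeps the scale near
    # identity when the embedding starts out near zero
    e = torch.randn(80)
    gammas, betas = torch.split(e, x.size(1), dim=-1)
    filmed = x * (gammas + torch.ones_like(gammas)) + betas

    # softmax: reweights the units so the weights sum to one
    s = torch.randn(40)
    reweighted = x * torch.softmax(s, dim=-1).expand_as(x)
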
/maml/models/lstm_embedding_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LSTMEmbeddingModel(torch.nn.Module):
5 | def __init__(self, input_size, output_size, embedding_dims,
6 | hidden_size=40, num_layers=2):
7 | super(LSTMEmbeddingModel, self).__init__()
8 | self._input_size = input_size
9 | self._output_size = output_size
10 | self._hidden_size = hidden_size
11 | self._num_layers = num_layers
12 | self._embedding_dims = embedding_dims
13 | self._bidirectional = True
14 | self._device = 'cpu'
15 |
16 | rnn_input_size = int(input_size + output_size)
17 | self.rnn = torch.nn.LSTM(
18 | rnn_input_size, hidden_size, num_layers, bidirectional=self._bidirectional)
19 |
20 | self._embeddings = torch.nn.ModuleList()
21 | for dim in embedding_dims:
22 | self._embeddings.append(torch.nn.Linear(
23 | hidden_size*(2 if self._bidirectional else 1), dim))
24 |
25 | def forward(self, task, return_task_embedding=True):
26 | batch_size = 1
27 | h0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1),
28 | batch_size, self._hidden_size, device=self._device)
29 | c0 = torch.zeros(self._num_layers*(2 if self._bidirectional else 1),
30 | batch_size, self._hidden_size, device=self._device)
31 |
32 | x = task.x.view(task.x.size(0), -1)
33 | y = task.y.view(task.y.size(0), -1)
34 |
35 | # LSTM input dimensions are seq_len, batch, input_size
36 | inputs = torch.cat((x, y), dim=1).view(x.size(0), 1, -1)
37 | output, (hn, cn) = self.rnn(inputs, (h0, c0))
38 | if self._bidirectional:
39 | N, B, H = output.shape
40 | output = output.view(N, B, 2, H // 2)
41 | embedding_input = torch.cat(
42 | [output[-1, :, 0], output[0, :, 1]], dim=1)
43 |
44 | out_embeddings = []
45 | for embedding in self._embeddings:
46 | out_embeddings.append(embedding(embedding_input))
47 | if return_task_embedding:
48 | return out_embeddings, embedding_input
49 | else:
50 | return out_embeddings
51 |
52 | def to(self, device, **kwargs):
53 | self._device = device
54 | super(LSTMEmbeddingModel, self).to(device, **kwargs)
55 |
--------------------------------------------------------------------------------
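
The bidirectional LSTM consumes a task's concatenated (x, y) pairs as one sequence, summarizes them into a single task code (last forward state and first backward state concatenated), and each linear head maps that code to one per-layer conditioning vector. A forward-pass sketch, with `Task` again a hypothetical stand-in for the dataset's task object:

    from collections import namedtuple

    import torch

    from maml.models.lstm_embedding_model import LSTMEmbeddingModel

    Task = namedtuple('Task', ['x', 'y'])
    task = Task(x=torch.randn(10, 1), y=torch.randn(10, 1))

    embed_model = LSTMEmbeddingModel(input_size=1, output_size=1,
                                     embedding_dims=[40, 40])
    embeddings, task_code = embed_model(task, return_task_embedding=True)
    assert [e.size(-1) for e in embeddings] == [40, 40]
    assert task_code.size(-1) == 80   # 2 * hidden_size for the biLSTM
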
/maml/models/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | class Model(torch.nn.Module):
7 | def __init__(self):
8 | super(Model, self).__init__()
9 |
10 | @property
11 | def param_dict(self):
12 | return OrderedDict(self.named_parameters())
13 |
--------------------------------------------------------------------------------
/maml/models/simple_embedding_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class SimpleEmbeddingModel(torch.nn.Module):
5 | def __init__(self, num_embeddings, embedding_dims):
6 | super(SimpleEmbeddingModel, self).__init__()
7 | self._embeddings = torch.nn.ModuleList()
8 | for dim in embedding_dims:
9 | self._embeddings.append(torch.nn.Embedding(num_embeddings, dim))
10 | self._device = 'cpu'
11 |
12 | def forward(self, task):
13 | task_id = torch.tensor(task.task_id, dtype=torch.long,
14 | device=self._device)
15 | out_embeddings = []
16 | for embedding in self._embeddings:
17 | out_embeddings.append(embedding(task_id))
18 | return out_embeddings
19 |
20 | def to(self, device, **kwargs):
21 | self._device = device
22 | super(SimpleEmbeddingModel, self).to(device, **kwargs)
23 |
--------------------------------------------------------------------------------
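
SimpleEmbeddingModel is the non-adaptive counterpart: it looks up one learned vector per task id for each conditioning layer. A tiny sketch with a hypothetical task object exposing `task_id`:

    import torch

    from maml.models.simple_embedding_model import SimpleEmbeddingModel

    class ToyTask:          # stand-in for the dataset's task object
        task_id = 3

    model = SimpleEmbeddingModel(num_embeddings=10, embedding_dims=[40, 40])
    embeddings = model(ToyTask())
    assert [e.size(-1) for e in embeddings] == [40, 40]
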
/maml/sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 |
4 | from torch.utils.data.sampler import Sampler
5 |
6 |
7 | class ClassBalancedSampler(Sampler):
8 |     """Generates indices for class-balanced batches. Samples are drawn
9 |     without replacement within each batch but may repeat across batches."""
10 |
11 | def __init__(self, dataset_labels, num_classes_per_batch,
12 | num_samples_per_class, num_total_batches, train):
13 | """
14 | Args:
15 | dataset_labels: list of dataset labels
16 | num_classes_per_batch: number of classes to sample for each batch
17 | num_samples_per_class: number of samples to sample for each class
18 | for each batch. For K shot learning this should be K + number
19 | of validation samples
20 | num_total_batches: total number of batches to generate
21 | """
22 | self._dataset_labels = dataset_labels
23 | self._classes = set(self._dataset_labels)
24 | self._class_to_samples = defaultdict(set)
25 | for i, c in enumerate(self._dataset_labels):
26 | self._class_to_samples[c].add(i)
27 |
28 | self._num_classes_per_batch = num_classes_per_batch
29 | self._num_samples_per_class = num_samples_per_class
30 | self._num_total_batches = num_total_batches
31 | self._train = train
32 |
33 | def __iter__(self):
34 | for i in range(self._num_total_batches):
35 | batch_classes = random.sample(
36 |                 list(self._class_to_samples), self._num_classes_per_batch)
37 | batch_samples = []
38 | for c in batch_classes:
39 | class_samples = random.sample(
40 |                     list(self._class_to_samples[c]), self._num_samples_per_class)
41 | for sample in class_samples:
42 | batch_samples.append(sample)
43 | random.shuffle(batch_samples)
44 | for sample in batch_samples:
45 | yield sample
46 |
47 | def __len__(self):
48 | return self._num_total_batches
49 |
--------------------------------------------------------------------------------
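
The sampler emits a flat stream of indices, so the consuming DataLoader must batch with exactly num_classes_per_batch * num_samples_per_class for each batch to stay class-balanced. A sketch on a toy labeled dataset (labels and sizes made up for illustration):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from maml.sampler import ClassBalancedSampler

    labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]
    dataset = TensorDataset(torch.randn(9, 4), torch.tensor(labels))

    sampler = ClassBalancedSampler(
        dataset_labels=labels, num_classes_per_batch=2,
        num_samples_per_class=3, num_total_batches=5, train=True)
    loader = DataLoader(dataset, batch_size=2 * 3, sampler=sampler)
    for x, y in loader:
        # each batch holds 3 samples from each of 2 classes
        print(sorted(y.tolist()))
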
/maml/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 |
8 |
9 | class Trainer(object):
10 | def __init__(self, meta_learner, meta_dataset, writer, log_interval,
11 | save_interval, model_type, save_folder):
12 | self._meta_learner = meta_learner
13 | self._meta_dataset = meta_dataset
14 | self._writer = writer
15 | self._log_interval = log_interval
16 | self._save_interval = save_interval
17 | self._model_type = model_type
18 | self._save_folder = save_folder
19 |
20 | def run(self, is_training):
21 | if not is_training:
22 | all_pre_val_measurements = defaultdict(list)
23 | all_pre_train_measurements = defaultdict(list)
24 | all_post_val_measurements = defaultdict(list)
25 | all_post_train_measurements = defaultdict(list)
26 |
27 | for i, (train_tasks, val_tasks) in enumerate(
28 | iter(self._meta_dataset), start=1):
29 | (pre_train_measurements, adapted_params, embeddings
30 | ) = self._meta_learner.adapt(train_tasks)
31 | post_val_measurements = self._meta_learner.step(
32 | adapted_params, embeddings, val_tasks, is_training)
33 |
34 | # Tensorboard
35 | if (i % self._log_interval == 0 or i == 1):
36 | pre_val_measurements = self._meta_learner.measure(
37 | tasks=val_tasks, embeddings_list=embeddings)
38 | post_train_measurements = self._meta_learner.measure(
39 | tasks=train_tasks, adapted_params_list=adapted_params,
40 | embeddings_list=embeddings)
41 |
42 |                 _grads_mean = np.mean(self._meta_learner._grads_mean) if self._meta_learner._grads_mean else None
43 | self._meta_learner._grads_mean = []
44 |
45 | self.log_output(
46 | pre_val_measurements, pre_train_measurements,
47 | post_val_measurements, post_train_measurements,
48 | i, _grads_mean)
49 |
50 | if is_training:
51 | self.write_tensorboard(
52 | pre_val_measurements, pre_train_measurements,
53 | post_val_measurements, post_train_measurements,
54 | i, _grads_mean)
55 |
56 | # Save model
57 | if i % self._save_interval == 0 and is_training:
58 | save_name = 'maml_{0}_{1}.pt'.format(self._model_type, i)
59 | save_path = os.path.join(self._save_folder, save_name)
60 | with open(save_path, 'wb') as f:
61 | torch.save(self._meta_learner.state_dict(), f)
62 |
63 | # Collect evaluation statistics over full dataset
64 | if not is_training:
65 | for key, value in sorted(pre_val_measurements.items()):
66 | all_pre_val_measurements[key].append(value)
67 | for key, value in sorted(pre_train_measurements.items()):
68 | all_pre_train_measurements[key].append(value)
69 | for key, value in sorted(post_val_measurements.items()):
70 | all_post_val_measurements[key].append(value)
71 | for key, value in sorted(post_train_measurements.items()):
72 | all_post_train_measurements[key].append(value)
73 |
74 | # Compute evaluation statistics assuming all batches were the same size
75 | if not is_training:
76 | results = {'num_batches': i}
77 | for key, value in sorted(all_pre_val_measurements.items()):
78 | results['pre_val_' + key] = value
79 | for key, value in sorted(all_pre_train_measurements.items()):
80 | results['pre_train_' + key] = value
81 | for key, value in sorted(all_post_val_measurements.items()):
82 | results['post_val_' + key] = value
83 | for key, value in sorted(all_post_train_measurements.items()):
84 | results['post_train_' + key] = value
85 |
86 | print('Evaluation results:')
87 | for key, value in sorted(results.items()):
88 | if not isinstance(value, int):
89 | print('{}: {:.6f} +- {:.6e}, std={:.6f}'.format(
90 | key,
91 | float(np.mean(value)),
92 | float(self.compute_confidence_interval(value)),
93 | float(np.std(value)),
94 | ))
95 | else:
96 | print('{}: {}'.format(key, value))
97 |
98 | results_path = os.path.join(self._save_folder, 'results.json')
99 | with open(results_path, 'w') as f:
100 | json.dump(results, f)
101 |
102 | def compute_confidence_interval(self, value):
103 | """
104 |         Compute the half-width of a 95% confidence interval over tasks;
105 |         change 1.960 to 2.576 for a 99% confidence interval.
106 | """
107 | return np.std(value) * 1.960 / np.sqrt(len(value))
108 |
109 | def train(self):
110 | self.run(is_training=True)
111 |
112 | def eval(self):
113 | self.run(is_training=False)
114 |
115 | def write_tensorboard(self, pre_val_measurements, pre_train_measurements,
116 | post_val_measurements, post_train_measurements,
117 | iteration, embedding_grads_mean=None):
118 | for key, value in pre_val_measurements.items():
119 | self._writer.add_scalar(
120 | '{}/before_update/meta_val'.format(key), value, iteration)
121 | for key, value in pre_train_measurements.items():
122 | self._writer.add_scalar(
123 | '{}/before_update/meta_train'.format(key), value, iteration)
124 | for key, value in post_train_measurements.items():
125 | self._writer.add_scalar(
126 | '{}/after_update/meta_train'.format(key), value, iteration)
127 | for key, value in post_val_measurements.items():
128 | self._writer.add_scalar(
129 | '{}/after_update/meta_val'.format(key), value, iteration)
130 | if embedding_grads_mean is not None:
131 | self._writer.add_scalar(
132 | 'embedding_grads_mean', embedding_grads_mean, iteration)
133 |
134 | def log_output(self, pre_val_measurements, pre_train_measurements,
135 | post_val_measurements, post_train_measurements,
136 | iteration, embedding_grads_mean=None):
137 | log_str = 'Iteration: {} '.format(iteration)
138 | for key, value in sorted(pre_val_measurements.items()):
139 | log_str = (log_str + '{} meta_val before: {:.3f} '
140 | ''.format(key, value))
141 | for key, value in sorted(pre_train_measurements.items()):
142 | log_str = (log_str + '{} meta_train before: {:.3f} '
143 | ''.format(key, value))
144 | for key, value in sorted(post_train_measurements.items()):
145 | log_str = (log_str + '{} meta_train after: {:.3f} '
146 | ''.format(key, value))
147 | for key, value in sorted(post_val_measurements.items()):
148 | log_str = (log_str + '{} meta_val after: {:.3f} '
149 | ''.format(key, value))
150 | if embedding_grads_mean is not None:
151 | log_str = (log_str + 'embedding_grad_norm after: {:.3f} '
152 | ''.format(embedding_grads_mean))
153 | print(log_str)
154 |
--------------------------------------------------------------------------------
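
The evaluation printout reports mean +- std * 1.960 / sqrt(n), i.e. the half-width returned by `compute_confidence_interval`. A quick worked sketch on made-up per-batch losses:

    import numpy as np

    losses = [0.41, 0.38, 0.45, 0.40, 0.43]
    half_width = np.std(losses) * 1.960 / np.sqrt(len(losses))
    print('{:.4f} +- {:.4f}'.format(np.mean(losses), half_width))
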
/maml/utils.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 | import torch
4 |
5 |
6 | def accuracy(preds, y):
7 | _, preds = torch.max(preds.data, 1)
8 | total = y.size(0)
9 | correct = (preds == y).sum().float()
10 | return correct / total
11 |
12 |
13 | def optimizer_to_device(optimizer, device):
14 | for state in optimizer.state.values():
15 | for k, v in state.items():
16 | if isinstance(v, torch.Tensor):
17 | state[k] = v.to(device)
18 |
19 |
20 | def get_git_revision_hash():
21 |     return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()
22 |
--------------------------------------------------------------------------------
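
A quick check of `accuracy` on hand-picked logits (values made up for illustration):

    import torch

    from maml.utils import accuracy

    logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0]])
    targets = torch.tensor([0, 1, 1])
    print(accuracy(logits, targets))   # tensor(0.6667): 2 of 3 correct
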