├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── baseline.png
│   ├── derivation.png
│   ├── derivation.svg
│   ├── derivation2.png
│   ├── derivation2.svg
│   ├── eq1.png
│   ├── eq2.png
│   ├── maml.png
│   └── random.png
├── maml.py
├── maml_1hidden.py
└── utils
    ├── data_generator.py
    ├── gradient_check.py
    └── optim.py
/.gitignore:
--------------------------------------------------------------------------------
1 | #
2 | # vim
3 | *.swp
4 | *.swo
5 |
6 | **.pkl
7 | *.pkl
8 |
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # celery beat schedule file
88 | celerybeat-schedule
89 |
90 | # SageMath parsed files
91 | *.sage.py
92 |
93 | # Environments
94 | .env
95 | .venv
96 | env/
97 | venv/
98 | ENV/
99 | env.bak/
100 | venv.bak/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Matthew Wilson
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MAML in raw numpy
2 |
3 | This is an implementation of vanilla Model-Agnostic Meta-Learning ([MAML](https://github.com/cbfinn/maml))
4 | in raw numpy. I made this to better understand the algorithm and what it is doing. I derived
5 | the forward and backward passes following conventions from [CS231n](http://cs231n.github.io/).
6 | This code is just a rough sketch, so it works, but it
7 | is not optimized or well parameterized.
8 | It turned out to be pretty interesting to see the algorithm
9 | logic without the backprop abstracted away by an autograd package like TensorFlow.
10 |
11 | **Table of contents**
12 | - [Results](#results)
13 | - [What is MAML?](#whatismaml)
14 | - [Derivation](#derivation)
15 |
16 |
17 |
18 |
19 | ## Results
20 |
21 | To verify my implementation, I test on the 1D sinusoid regression problem
22 | from [Section 5.1](https://arxiv.org/pdf/1703.03400.pdf) of the MAML paper (see
23 | also the description of the problem in [Section 4](https://arxiv.org/pdf/1803.02999.pdf) of the Reptile paper).
24 |
25 | I train for 10k iterations on a [dataset](utils/data_generator.py) of sine
26 | function input/outputs with randomly sampled amplitude and phase, and then
27 | fine-tune on 10 samples from a fixed amplitude and phase.
28 | After fine-tuning, I predict the value of the fixed sine function
29 | for 50 evenly spaced x values in [-5, 5], and plot the results
30 | compared to the ground truth for pre-trained MAML, pre-trained baseline
31 | (joint training), and a randomly initialized network. I find that
32 | MAML is able to fit the sinusoid much more effectively.
33 |
34 | MAML | Baseline (joint training)| Random init
35 | :-------------------------:|:-------------------------:|:----------:|
36 | ![MAML](assets/maml.png) | ![Baseline](assets/baseline.png) | ![Random init](assets/random.png)
37 |
38 | Here are the commands to run the code:
39 |
40 | - Train for 10k iterations and then save the weights to a file:
41 | ```
42 | python3 maml.py # train both MAML and baseline (joint trained) weights
43 | ```
44 | - After training, fine-tune the network and plot results on sine task:
45 | ```
46 | python3 maml.py --test 1
47 | ```
48 | - Run gradient check on implementation:
49 | ```
50 | python3 maml.py --gradcheck 1
51 | ```
52 |
53 |
54 | ### Notes
55 | These results come from using a neural network with 2 hidden layers. I
56 | originally tried using 1 hidden layer (see [`maml_1hidden.py`](maml_1hidden.py))
57 | because it was easier to derive, but I
58 | found that it did not have enough representational
59 | capacity to solve the sinusoid problem (see [Meta-Learning And Universality](https://arxiv.org/pdf/1710.11622.pdf) for more details on the representational capacity of MAML).
60 |
61 |
62 |
63 | ## What is MAML?
64 |
65 | ### Introduction
66 |
67 | Model-Agnostic Meta-Learning (MAML) is a gradient based meta-learning algorithm. For an
68 | overview of meta-learning, see a blog post from the author [here](https://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/), and a good talk [here](https://youtu.be/i05Fk4ebMY0).
69 | Roughly, meta-learning tries to solve sample-inefficiency problems in
70 | machine learning. It tries to allow models to learn
71 | quickly on new tasks by better incorporating information from previous tasks.
72 |
73 | Unlike [several](https://arxiv.org/abs/1611.02779) [other](https://openreview.net/forum?id=rJY0-Kcll) [meta-learning](https://arxiv.org/abs/1606.04474) [methods](https://arxiv.org/abs/1707.03141) which use RNNs, MAML only uses
74 | feed-forward networks and gradient descent. The interesting piece is how it
75 | sets up the gradient descent scheme to optimize the network for efficient
76 | fine-tuning on the meta-test set.
77 | In standard neural network training, we use gradient-descent and backprop for
78 | training. MAML assumes that you will use this same approach to quickly
79 | fine-tune on your task and it builds this into the meta-training optimization.
80 |
81 | MAML breaks the meta-learning problem into two phases: a **meta-training phase** and a **fine-tuning phase**. The meta-training phase optimizes the network parameters so that the fine-tuning phase is more effective: the parameters end up sensitive to gradients and can
82 | quickly adapt to solve newly sampled tasks in the distribution. The fine-tuning phase will just
83 | run standard gradient descent using the weights that were produced in the meta-training phase, just like you would fine-tune a network for a task using e.g., pre-trained
84 | ImageNet weights. This process looks somewhat similar to transfer learning, but
85 | is more general and produces better results on meta-learning problems like one-shot learning (where you are given a single instance of a new
86 | object class, like an electric scooter, and your model must quickly adapt so that
87 | it can effectively distinguish new images of electric scooters from other objects).
88 |
89 |
90 | ### Meta-training
91 | During meta-training, MAML draws several samples from a **task**, and splits them
92 | into **A** and **B** examples. For example you could draw 10 (x,y) pairs from a sinusoid
93 | problem and split them into 5 A and 5 B examples. In this case each task is
94 | defined by a fixed amplitude and phase of the sinusoid, but tasks can represent
95 | more interesting variations, like what objects the robot should interact
96 | with in [imitating a human demonstration](https://sites.google.com/view/daml).
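For the sinusoid case, sampling a task and splitting it into A and B examples takes just a few lines with the [data generator](utils/data_generator.py) in this repo. A minimal sketch (the 5/5 split matches the example above):

```python
from utils.data_generator import SinusoidGenerator

# one task (batch_size=1) with 10 samples from a single amplitude/phase
gen = SinusoidGenerator(num_samples_per_class=10, batch_size=1)
x, y, amp, phase = gen.generate()   # x, y: (1, 10, 1); amp, phase: (1,)
x, y = x[0], y[0]                   # drop the meta-batch dimension

# split into A examples (inner update) and B examples (outer/meta loss)
x_a, x_b = x[:5], x[5:]
y_a, y_b = y[:5], y[5:]
```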
97 |
98 | Once we have sampled the A and B examples from the task, we will use the A
99 | examples for an **inner optimization** (standard gradient descent),
100 | and the B examples for **outer optimization** (gradient descent back through
101 | the inner optimization). At a high level: we will inner optimize on the A
102 | examples, test the generalization performance on the B examples, and
103 | meta-optimize on that loss (using gradient descent through the whole
104 | computation) in order to place the parameters in a good initialization
105 | for quickly fine-tuning to many varied tasks.
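In equation form (following the MAML paper), the inner update on the A examples and the meta-update on the B examples look roughly like this, with inner learning rate alpha and meta learning rate beta:

```latex
% inner (task-specific) update, computed on the A examples:
\theta'_i = \theta - \alpha \, \nabla_\theta \mathcal{L}^{A}_{T_i}(f_\theta)

% outer (meta) update, computed on the B examples and backpropagated
% through the inner update above:
\theta \leftarrow \theta - \beta \, \nabla_\theta \sum_{T_i \sim p(T)} \mathcal{L}^{B}_{T_i}(f_{\theta'_i})
```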
106 |
107 | To see concretely how this is done, here are the algorithm logic and
108 | pseudocode that closely match the [TensorFlow
109 | implementation](https://github.com/cbfinn/maml):
110 |
111 | ### MAML algorithm
112 |
113 |
114 | **Algorithm logic (do this for many loops)**
115 |
116 | 1. Sample task T from distribution of possible tasks
117 | 1. Sample examples from T and split into A and B examples
118 | 1. Network forward pass with weights W, using A examples
119 | 1. Backward pass to compute gradients dWa
120 | 1. Apply gradients dWa using SGD: W' <-- W - alpha\*dWa
121 | 1. Forward pass with temp weights W', using B examples this time
122 | 1. Backward pass through the whole thing to compute gradients dWb (NOTE: this gradient is with respect to input weights W, not W'. This is a second order derivative and backprops through the B forward, the gradient update step, the A backward, and the A forward computations. Everything in the below [derivation diagrams](#derivation) is just the meta-forward pass. This is backpropping through the whole thing, starting at pred_b)
123 | 1. Apply gradients dWb to the original weights W (using Adam: W <-- W - beta\*dWb)
124 |
125 |
126 | NOTE: You could also do batches of tasks at a time and sum the lossBs.
127 |
128 | **Pseudocode that roughly matches Finn's implementation of [MAML in TensorFlow](https://github.com/cbfinn/maml):**
129 |
130 | ```python
131 | weights = init_NN_weights() # neural network weights and biases
132 |
133 | task_data = sample_task()
134 |
135 | inputA, labelA, inputB, labelB = task_data.meta_split()
136 |
137 | # forward pass of network using weights and A examples
138 | netoutA = forward(inputA, weights)
139 | lossA = loss_func(netoutA, labelA)
140 |
141 | gradients = get_gradients(lossA) # w.r.t. weights
142 |
143 | fast_weights = weights + -learning_rate * gradients # gradient descent step on weights
144 |
145 | netoutB = forward(inputB, fast_weights)
146 | lossB = loss_func(netoutB, labelB)
147 |
148 | # then you would plug this lossB into an optimizer like Adam to optimize
149 | # w.r.t. the original weights. fast_weights are basically just a temporary
150 | # thing to see where a gradient descent step on the inner lossA leads.
151 | # The only state retained between iterations of MAML is the weights (not the fast_weights).
152 | ```
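In this repo, the outer update is applied with the small Adam implementation in [`utils/optim.py`](utils/optim.py). Below is a minimal sketch of one meta-training step using the helpers from [`maml.py`](maml.py); treat it as a sketch of what `train()` does for a single task, not a replacement for the real loop (learning rates match the defaults in `maml.py`):

```python
from maml import Network, build_weights, GradDict
from utils.optim import AdamOptimizer
from utils.data_generator import SinusoidGenerator

nn = Network(inner_lr=1e-2)
weights = build_weights()
optimizer = AdamOptimizer(weights, learning_rate=1e-3)

# one task: 10 samples from a random sinusoid, split 5/5 into A and B examples
gen = SinusoidGenerator(10, 1)
x, y, amp, phase = map(lambda v: v[0], gen.generate())
x_a, x_b, y_a, y_b = x[:5], x[5:], y[:5], y[5:]

grads = GradDict()   # meta-gradients accumulate here (summed over a meta-batch of tasks)
cache = {}
pred_b = nn.meta_forward(x_a, x_b, y_a, weights, cache=cache)
dout_b = 2 * (pred_b - y_b)            # gradient of the squared-error lossB w.r.t. pred_b
nn.meta_backward(dout_b, weights, cache, grads)

optimizer.apply_gradients(weights, grads)   # Adam step on the original (slow) weights
```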
153 |
154 | ### Fine-tuning
155 |
156 | At the fine-tuning stage, you now have a set of meta-trained weights. Given a
157 | new task, you can just run the inner optimization, keep track of the
158 | fast_weights, and then use them to predict new examples.
159 |
160 | **Pseudocode to illustrate how fine-tuning works and relates to training**
161 | ```
162 | inputA, labelA = test_data()
163 |
164 | netoutA = forward(inputA, weights)
165 | lossA = loss_func(netoutA, labelA)
166 |
167 | gradients = get_gradients(lossA) # w.r.t. weights
168 |
169 | fast_weights = weights + -learning_rate * gradients # gradient descent step on weights
170 |
171 |
172 | prediction = forward(new_input_to_predict_label_for, fast_weights)
173 | ```
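Concretely, with the two-hidden-layer `Network` in [`maml.py`](maml.py), fine-tuning looks roughly like the sketch below. It assumes you already ran `python3 maml.py`, so the default `trained_maml_weights.pkl` exists:

```python
import numpy as np
from maml import Network, load_weights
from utils.data_generator import SinusoidGenerator

nn = Network(inner_lr=1e-2)
weights = load_weights('trained_maml_weights.pkl')   # default --weight_path in maml.py

# a handful of labeled samples from the new (fixed amplitude/phase) task
gen = SinusoidGenerator(10, 1)
x, y, amp, phase = map(lambda v: v[0], gen.generate())

# a few plain SGD steps on the new task; inner_backward returns the updated weights
fast_weights = weights
for _ in range(10):
    cache = {}
    pred = nn.inner_forward(x, fast_weights, cache)
    dout = 2 * (pred - y)                            # d(squared error)/d(pred)
    fast_weights = nn.inner_backward(dout, fast_weights, cache)

# predict the fine-tuned sinusoid on a dense grid of inputs
x_grid = np.linspace(-5, 5, 50).reshape(-1, 1)
y_pred = nn.inner_forward(x_grid, fast_weights)
```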
174 |
175 |
176 |
177 |
178 | ## Derivation
179 |
180 |
181 | The diagrams below show the meta-forward pass for MAML with a single inner
182 | update step. By computing the gradients through this computational graph,
183 | I derived the computations required for the meta-backward pass. I show
184 | the computation for a single hidden-layer neural network for simplicity, but
185 | in the code I use a two hidden-layer neural network.
186 |
187 | NOTE: (dW2, db2, dW1, db1) are computed in the upper figure and passed to the lower
188 | figure. In the meta-backward pass, gradients are backpropagated from the output all the way back
189 | through both figures, ending in the upper figure. I use the approach from [CS231n](http://cs231n.github.io/).
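For reference, a compact way to see where the second-order term comes from: with a single inner SGD step on lossA (treating the weights as one flat vector theta), the chain rule through the update step gives, roughly,

```latex
\theta' = \theta - \alpha \nabla_\theta \mathcal{L}_A(\theta)
\quad\Longrightarrow\quad
\frac{d \mathcal{L}_B(\theta')}{d \theta}
  = \left( I - \alpha \nabla^2_\theta \mathcal{L}_A(\theta) \right)
    \nabla_{\theta'} \mathcal{L}_B(\theta')
```

The Hessian-vector term is what the meta-backward pass picks up by backpropagating through the inner backward pass; the diagrams only show the meta-forward computation.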
190 |
191 | **Inner forward and backward:**
192 | ![Inner forward and backward derivation](assets/derivation.png)
193 |
194 | **Inner gradient (SGD) update and second (outer) forward pass:**
195 | ![Inner gradient update and outer forward derivation](assets/derivation2.png)
196 |
197 |
--------------------------------------------------------------------------------
/assets/baseline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/baseline.png
--------------------------------------------------------------------------------
/assets/derivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/derivation.png
--------------------------------------------------------------------------------
/assets/derivation.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
1512 |
--------------------------------------------------------------------------------
/assets/derivation2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/derivation2.png
--------------------------------------------------------------------------------
/assets/derivation2.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
1600 |
--------------------------------------------------------------------------------
/assets/eq1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/eq1.png
--------------------------------------------------------------------------------
/assets/eq2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/eq2.png
--------------------------------------------------------------------------------
/assets/maml.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/maml.png
--------------------------------------------------------------------------------
/assets/random.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/random.png
--------------------------------------------------------------------------------
/maml.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import pickle
4 | import copy
5 | import random
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import matplotlib as mpl
9 | mpl.rcParams["savefig.directory"] = '~/Desktop'  # os.chdir(os.path.dirname(__file__))
10 | import argparse
11 | from collections import defaultdict
12 |
13 | from utils.optim import AdamOptimizer
14 | from utils.gradient_check import eval_numerical_gradient, eval_numerical_gradient_array, rel_error
15 | from utils.data_generator import SinusoidGenerator
16 |
17 | """
18 | This file contains logic for training a fully-connected neural network with
19 | 2 hidden layers using the Model-Agnostic Meta-Learning (MAML) algorithm.
20 |
21 | It is designed to solve the toy sinusoid meta-learning problem presented in the MAML paper,
22 | and uses the same architecture as presented in the paper.
23 |
24 | Passing the `--gradcheck=1` flag will run a finite-differences gradient check
25 | on the meta forward and backward to ensure correct implementation.
26 |
27 | After training a network, you can pass the `--test=1` flag to compare against
28 | a joint-trained and random network baseline.
29 | """
30 |
31 |
32 | # special dictionary to return 0 if element does not exist (makes gradient code simpler)
33 | GradDict = lambda: defaultdict(lambda: 0)
34 | normalize = lambda x: (x - x.mean()) / (x.std() + 1e-8)
35 |
36 | # weight util functions
37 | def build_weights(hidden_dims=(64, 64)):
38 |     """Return dictionary of neural network weights"""
39 | # Initialize all weights (model params) with "Xavier Initialization"
40 | # weight matrix init = uniform(-1, 1) / sqrt(layer_input)
41 | # bias init = zeros()
42 | H1, H2 = hidden_dims
43 | w = {}
44 | w['W1'] = (-1 + 2*np.random.rand(1, H1)) / np.sqrt(1)
45 | w['b1'] = np.zeros(H1)
46 | w['W2'] = (-1 + 2*np.random.rand(H1, H2)) / np.sqrt(H1)
47 | w['b2'] = np.zeros(H2)
48 | w['W3'] = (-1 + 2*np.random.rand(H2, 1)) / np.sqrt(H2)
49 | w['b3'] = np.zeros(1)
50 |
51 | # Cast all parameters to the correct datatype
52 | for k, v in w.items():
53 | w[k] = v.astype(np.float32)
54 | return w
55 |
56 | def save_weights(weights, filename, quiet=False):
57 | with open(filename, 'wb') as f:
58 | pickle.dump(weights, f)
59 | if not quiet:
60 | print('weights saved to {}'.format(filename))
61 |
62 | def load_weights(filename, quiet=False):
63 | with open(filename, 'rb') as f:
64 | weights = pickle.load(f)
65 | if not quiet:
66 | print('weights loaded from {}'.format(filename))
67 | return weights
68 |
69 | class Network(object):
70 | """
71 | Forward and backward pass logic for 3 layer neural network
72 | (see https://github.com/matwilso/maml_numpy#derivation for derivation)
73 | """
74 |
75 | def __init__(self, inner_lr=0.01, normalize=normalize):
76 | self.inner_lr = inner_lr # alpha in the paper
77 | self.normalize = normalize # function to normalize gradients before applying them to weights (helps with stability)
78 |
79 | def inner_forward(self, x_a, weights, cache={}):
80 | """Submodule for meta_forward. This is just a standard forward pass for a neural net.
81 |
82 | Args:
83 | x_a (ndarray): Example or examples of sinusoid from given phase, amplitude.
84 | weights (dict): Dictionary of weights and biases for neural net
85 | cache (dict): Pass in dictionary to be updated with values needed in meta_backward
86 |
87 | Returns:
88 | pred_a (ndarray): Predicted values for example(s) x_a
89 | """
90 | w = weights
91 | W1, b1, W2, b2, W3, b3 = w['W1'], w['b1'], w['W2'], w['b2'], w['W3'], w['b3']
92 | # layer 1
93 | affine1_a = x_a.dot(W1) + b1
94 | relu1_a = np.maximum(0, affine1_a)
95 | # layer 2
96 | affine2_a = relu1_a.dot(W2) + b2
97 | relu2_a = np.maximum(0, affine2_a)
98 | # layer 3
99 | pred_a = relu2_a.dot(W3) + b3
100 |
101 | cache.update(dict(x_a=x_a, affine1_a=affine1_a, relu1_a=relu1_a, affine2_a=affine2_a, relu2_a=relu2_a))
102 | return pred_a
103 |
104 | def inner_backward(self, dout_a, weights, cache, grads=GradDict(), lr=None):
105 | """For fine-tuning network at meta-test time
106 |
107 | (Although this has some repeated code from meta_backward, it was hard to
108 |         use as a subroutine for meta_backward. It required several changes in
109 |         the code and made things more confusing.)
110 |
111 | Args:
112 | dout_a (ndarray): Gradient of output (usually loss)
113 | weights (dict): Dictionary of weights and biases for neural net
114 | cache (dict): Dictionary of relevant values from forward pass
115 |
116 | Returns:
117 | dict: New dictionary, with updated weights
118 | """
119 | w = weights; c = cache
120 | W1, b1, W2, b2, W3, b3 = w['W1'], w['b1'], w['W2'], w['b2'], w['W3'], w['b3']
121 | lr = lr or self.inner_lr
122 |
123 | drelu2_a = dout_a.dot(W3.T)
124 | dW3 = c['relu2_a'].T.dot(dout_a)
125 | db3 = np.sum(dout_a, axis=0)
126 |
127 | daffine2_a = np.where(c['affine2_a'] > 0, drelu2_a, 0)
128 |
129 | drelu1_a = daffine2_a.dot(W2.T)
130 |         dW2 = c['relu1_a'].T.dot(daffine2_a)
131 |         db2 = np.sum(daffine2_a, axis=0)
132 |
133 | daffine1_a = np.where(c['affine1_a'] > 0, drelu1_a, 0)
134 |
135 | dW1 = c['x_a'].T.dot(daffine1_a)
136 | db1 = np.sum(daffine1_a, axis=0)
137 |
138 | grads['W1'] += dW1
139 | grads['b1'] += db1
140 | grads['W2'] += dW2
141 | grads['b2'] += db2
142 | grads['W3'] += dW3
143 | grads['b3'] += db3
144 |
145 | # Return new weights (for fine-tuning)
146 | new_weights = {}
147 | new_weights['W1'] = W1 - lr*self.normalize(dW1)
148 | new_weights['b1'] = b1 - lr*self.normalize(db1)
149 | new_weights['W2'] = W2 - lr*self.normalize(dW2)
150 | new_weights['b2'] = b2 - lr*self.normalize(db2)
151 | new_weights['W3'] = W3 - lr*self.normalize(dW3)
152 | new_weights['b3'] = b3 - lr*self.normalize(db3)
153 | return new_weights
154 |
155 |
156 | def meta_forward(self, x_a, x_b, label_a, weights, cache={}):
157 |         """Full forward pass for MAML. Does an inner_forward, backprop, and gradient
158 | update. This will all be backpropped through w.r.t. weights in meta_backward
159 |
160 | Args:
161 | x_a (ndarray): Example or examples of sinusoid from given phase, amplitude.
162 | x_b (ndarray): Independent example(s) from same phase, amplitude as x_a's
163 | label_a (ndarray): Ground truth labels for x_a
164 | weights (dict): Dictionary of weights and biases for neural net
165 | cache (dict): Pass in dictionary to be updated with values needed in meta_backward
166 |
167 | Returns:
168 | pred_b (ndarray): Predicted values for example(s) x_b
169 | """
170 | w = weights
171 | W1, b1, W2, b2, W3, b3 = w['W1'], w['b1'], w['W2'], w['b2'], w['W3'], w['b3']
172 |
173 | # A: inner
174 | # standard forward and backward computations
175 | inner_cache = {}
176 | pred_a = self.inner_forward(x_a, w, inner_cache)
177 |
178 | # inner loss
179 | dout_a = 2*(pred_a - label_a)
180 |
181 | # d 3rd layer
182 | dW3 = inner_cache['relu2_a'].T.dot(dout_a)
183 | db3 = np.sum(dout_a, axis=0)
184 | drelu2_a = dout_a.dot(W3.T)
185 |
186 | daffine2_a = np.where(inner_cache['affine2_a'] > 0, drelu2_a, 0)
187 |
188 | # d 2nd layer
189 | dW2 = inner_cache['relu1_a'].T.dot(daffine2_a)
190 | db2 = np.sum(daffine2_a, axis=0)
191 | drelu1_a = daffine2_a.dot(W2.T)
192 |
193 | daffine1_a = np.where(inner_cache['affine1_a'] > 0, drelu1_a, 0)
194 |
195 | # d 1st layer
196 | dW1 = x_a.T.dot(daffine1_a)
197 | db1 = np.sum(daffine1_a, axis=0)
198 |
199 | # Forward on fast weights
200 | # B: meta/outer
201 | # SGD step is baked into forward pass, representing optimizing through fine-tuning
202 | # Theta prime in the paper. Also called fast_weights in Finn's TF implementation
203 | W1_prime = W1 - self.inner_lr*dW1
204 | b1_prime = b1 - self.inner_lr*db1
205 | W2_prime = W2 - self.inner_lr*dW2
206 | b2_prime = b2 - self.inner_lr*db2
207 | W3_prime = W3 - self.inner_lr*dW3
208 | b3_prime = b3 - self.inner_lr*db3
209 |
210 | # Do another forward pass with the fast weights, to predict B example
211 | affine1_b = x_b.dot(W1_prime) + b1_prime
212 | relu1_b = np.maximum(0, affine1_b)
213 | affine2_b = relu1_b.dot(W2_prime) + b2_prime
214 | relu2_b = np.maximum(0, affine2_b)
215 | pred_b = relu2_b.dot(W3_prime) + b3_prime
216 |
217 | # Cache relevant values for meta backpropping
218 | outer_cache = dict(dout_a=dout_a, x_b=x_b, affine1_b=affine1_b, relu1_b=relu1_b, affine2_b=affine2_b, relu2_b=relu2_b, daffine2_a=daffine2_a, W2_prime=W2_prime, W3_prime=W3_prime)
219 | cache.update(inner_cache)
220 | cache.update(outer_cache)
221 |
222 | return pred_b
223 |
224 | def meta_backward(self, dout_b, weights, cache, grads=GradDict()):
225 | """Full backward pass for MAML. Through all operations from forward pass
226 |
227 | Args:
228 | dout_b (ndarray): Gradient signal of network output (usually loss gradient)
229 | weights (dict): Dictionary of weights and biases used in forward pass
230 | cache (dict): Dictionary of relevant values from forward pass
231 | grads (dict): Pass in dictionary to be updated with weight gradients
232 | """
233 | c = cache; w = weights
234 | W1, b1, W2, b2, W3, b3 = w['W1'], w['b1'], w['W2'], w['b2'], w['W3'], w['b3']
235 |
236 | # First, backprop through the B network pass
237 | # d 3rd layer
238 | drelu2_b = dout_b.dot(c['W3_prime'].T)
239 | dW3_prime = c['relu2_b'].T.dot(dout_b)
240 | db3_prime = np.sum(dout_b, axis=0)
241 |
242 | daffine2_b = np.where(c['affine2_b'] > 0, drelu2_b, 0)
243 |
244 | # d 2nd layer
245 | drelu1_b = daffine2_b.dot(c['W2_prime'].T)
246 | dW2_prime = c['relu1_b'].T.dot(daffine2_b)
247 | db2_prime = np.sum(daffine2_b, axis=0)
248 |
249 | daffine1_b = np.where(c['affine1_b'] > 0, drelu1_b, 0)
250 |
251 | # d 1st layer
252 | dW1_prime = c['x_b'].T.dot(daffine1_b)
253 | db1_prime = np.sum(daffine1_b, axis=0)
254 |
255 | # Next, backprop through the gradient descent step
256 | dW1 = dW1_prime
257 | db1 = db1_prime
258 | dW2 = dW2_prime
259 | db2 = db2_prime
260 | dW3 = dW3_prime
261 | db3 = db3_prime
262 |
263 | ddW1 = dW1_prime * -self.inner_lr
264 | ddb1 = db1_prime * -self.inner_lr
265 | ddW2 = dW2_prime * -self.inner_lr
266 | ddb2 = db2_prime * -self.inner_lr
267 | ddW3 = dW3_prime * -self.inner_lr
268 | ddb3 = db3_prime * -self.inner_lr
269 |
270 | # Then, backprop through the first backprop
271 | # start with dW1's
272 | ddaffine1_a = c['x_a'].dot(ddW1)
273 | ddaffine1_a += ddb1
274 |
275 | ddrelu1_a = np.where(c['affine1_a'] > 0, ddaffine1_a, 0)
276 |
277 | ddaffine2_a = ddrelu1_a.dot(W2)
278 | dW2 += ddrelu1_a.T.dot(c['daffine2_a'])
279 |
280 | # dW2's
281 | drelu1_a = c['daffine2_a'].dot(ddW2.T) # shortcut back because of the grad dependency
282 | ddaffine2_a += ddb2
283 | ddaffine2_a += c['relu1_a'].dot(ddW2)
284 |
285 | ddrelu2_a = np.where(c['affine2_a'] > 0, ddaffine2_a, 0)
286 |
287 | ddout_a = ddrelu2_a.dot(W3)
288 | dW3 += ddrelu2_a.T.dot(c['dout_a'])
289 |
290 | # dW3's
291 | drelu2_a = c['dout_a'].dot(ddW3.T) # shortcut back because of the grad dependency
292 | ddout_a += ddb3
293 | ddout_a += c['relu2_a'].dot(ddW3)
294 |
295 | # Finally, backprop through the first forward
296 | dpred_a = ddout_a * 2
297 |
298 | drelu2_a += dpred_a.dot(W3.T)
299 | db3 += np.sum(dpred_a, axis=0)
300 | dW3 += c['relu2_a'].T.dot(dpred_a)
301 |
302 | daffine2_a = np.where(c['affine2_a'] > 0, drelu2_a, 0)
303 |
304 | drelu1_a += daffine2_a.dot(W2.T)
305 | dW2 += c['relu1_a'].T.dot(daffine2_a)
306 | db2 += np.sum(daffine2_a, axis=0)
307 |
308 | daffine1_a = np.where(c['affine1_a'] > 0, drelu1_a, 0)
309 |
310 | dW1 += c['x_a'].T.dot(daffine1_a)
311 | db1 += np.sum(daffine1_a, axis=0)
312 |
313 | # update gradients
314 | grads['W1'] += self.normalize(dW1)
315 | grads['b1'] += self.normalize(db1)
316 | grads['W2'] += self.normalize(dW2)
317 | grads['b2'] += self.normalize(db2)
318 | grads['W3'] += self.normalize(dW3)
319 | grads['b3'] += self.normalize(db3)
320 |
321 |
322 | def gradcheck():
323 | # Test the network gradient
324 | nn = Network(normalize=lambda x: x) # don't normalize gradients so we can check validity
325 | grads = GradDict() # initialize grads to 0
326 | # dummy inputs, labels, and fake backwards gradient signal
327 | x_a = np.random.randn(15, 1)
328 | x_b = np.random.randn(15, 1)
329 | label = np.random.randn(15, 1)
330 | dout = np.random.randn(15, 1)
331 | # make weights. don't use build_weights here because this is more stable
332 | W1 = np.random.randn(1, 40)
333 | b1 = np.random.randn(40)
334 | W2 = np.random.randn(40, 40)
335 | b2 = np.random.randn(40)
336 | W3 = np.random.randn(40, 1)
337 | b3 = np.random.randn(1)
338 | weights = dict(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3)
339 |
340 | # helper function to only change a single key of interest for independent finite differences
341 | def rep_param(weights, name, val):
342 | clean_params = copy.deepcopy(weights)
343 | clean_params[name] = val
344 | return clean_params
345 |
346 | # Evaluate gradients numerically, using finite differences
347 | numerical_grads = {}
348 | for key in weights:
349 | num_grad = eval_numerical_gradient_array(lambda w: nn.meta_forward(x_a, x_b, label, rep_param(weights, key, w)), weights[key], dout, h=1e-5)
350 | numerical_grads[key] = num_grad
351 |
352 | # Compute neural network gradients
353 | cache = {}
354 | out = nn.meta_forward(x_a, x_b, label, weights, cache=cache)
355 | nn.meta_backward(dout, weights, cache, grads)
356 |
357 | # The error should be around 1e-10
358 | print()
359 | for key in weights:
360 | print('d{} error: {}'.format(key, rel_error(numerical_grads[key], grads[key])))
361 | print()
362 |
363 | def test():
364 |     """Fine-tune on a few minibatches from a fixed sinusoid and see how well it works
365 |
366 | Basically what they show in Figure 2 of the paper
367 | """
368 | nn = Network(inner_lr=FLAGS.inner_lr)
369 |
370 | pre_weights = {}
371 | pre_weights['maml'] = load_weights(FLAGS.weight_path)
372 | if FLAGS.use_baseline:
373 | pre_weights['baseline'] = load_weights('baseline_'+FLAGS.weight_path)
374 | pre_weights['random'] = build_weights()
375 |
376 | # Generate N batches of data, with same shape as training, but that all have the same amplitude and phase
377 | N = 2
378 | #sinegen = SinusoidGenerator(FLAGS.inner_bs*N, 1, config={'input_range':[1.0,5.0]})
379 | sinegen = SinusoidGenerator(FLAGS.inner_bs*N, 1)
380 | x, y, amp, phase = map(lambda x: x[0], sinegen.generate()) # grab all the first elems
381 | xs = np.split(x, N)
382 | ys = np.split(y, N)
383 |
384 | # Copy pre-update weights for later comparison
385 | deepcopy = lambda weights: {key: weights[key].copy() for key in weights}
386 | post_weights = {}
387 | for key in pre_weights:
388 | post_weights[key] = deepcopy(pre_weights[key])
389 |
390 | T = 10
391 | # Run fine-tuning
392 | for key in post_weights:
393 | for t in range(T):
394 | for i in range(len(xs)):
395 | x = xs[i]
396 | y = ys[i]
397 | grads = GradDict()
398 | cache = {}
399 | pred = nn.inner_forward(x, post_weights[key], cache)
400 | loss = (pred - y)**2
401 | dout = 2*(pred - y)
402 | post_weights[key] = nn.inner_backward(dout, post_weights[key], cache)
403 |
404 |
405 | colors = {'maml': 'r', 'baseline': 'b', 'random': 'g'}
406 | name = {'maml': 'MAML', 'baseline': 'joint training', 'random': 'random initialization'}
407 |
408 | sine_ground = lambda x: amp*np.sin(x - phase)
409 | sine_pre_pred = lambda x, key: nn.inner_forward(x, pre_weights[key])[0]
410 | sine_post_pred = lambda x, key: nn.inner_forward(x, post_weights[key])[0]
411 |
412 | x_vals = np.linspace(-5, 5)
413 | y_ground = np.apply_along_axis(sine_ground, 0, x_vals)
414 |
415 |
416 | for key in post_weights:
417 | y_pre = np.array([sine_pre_pred(np.array(x), key) for x in x_vals]).squeeze()
418 | y_nn = np.array([sine_post_pred(np.array(x), key) for x in x_vals]).squeeze()
419 | plt.plot(x_vals, y_ground, 'k', label='{:.2f}sin(x - {:.2f})'.format(amp, phase))
420 | plt.plot(np.concatenate(xs), np.concatenate(ys), 'ok', label='samples')
421 | plt.plot(x_vals, y_pre, colors[key]+'--', label='pre-update')
422 | plt.plot(x_vals, y_nn, colors[key]+'-', label='post-update')
423 |
424 | plt.legend()
425 | plt.title('Fine-tuning performance {}'.format(name[key]))
426 | plt.savefig(key+'.png')
427 | plt.show()
428 |
429 | def train():
430 | nn = Network(inner_lr=FLAGS.inner_lr)
431 | weights = build_weights()
432 | optimizer = AdamOptimizer(weights, learning_rate=FLAGS.meta_lr)
433 | if FLAGS.use_baseline:
434 | baseline_weights = build_weights()
435 | baseline_optimizer = AdamOptimizer(baseline_weights, learning_rate=FLAGS.meta_lr)
436 |
437 | sinegen = SinusoidGenerator(2*FLAGS.inner_bs, 25) # update_batch * 2, meta batch size
438 |
439 | try:
440 | nitr = int(FLAGS.num_iter)
441 | for itr in range(int(nitr)):
442 |             # sample a meta-batch of 25 tasks, each with 2*inner_bs points (10 by default)
443 | batch_x, batch_y, amp, phase = sinegen.generate()
444 |
445 |             inputa = batch_x[:, :FLAGS.inner_bs, :]
446 |             labela = batch_y[:, :FLAGS.inner_bs, :]
447 |             inputb = batch_x[:, FLAGS.inner_bs:, :] # b used for testing
448 |             labelb = batch_y[:, FLAGS.inner_bs:, :]
449 |
450 | # META BATCH
451 | grads = GradDict() # zero grads
452 | baseline_grads = GradDict() # zero grads
453 | losses = []
454 | baseline_losses = []
455 | for batch_i in range(len(inputa)):
456 | ia, la, ib, lb = inputa[batch_i], labela[batch_i], inputb[batch_i], labelb[batch_i]
457 | cache = {}
458 | pred_b = nn.meta_forward(ia, ib, la, weights, cache=cache)
459 | losses.append((pred_b - lb)**2)
460 | dout_b = 2*(pred_b - lb)
461 | nn.meta_backward(dout_b, weights, cache, grads)
462 |
463 |
464 | if FLAGS.use_baseline:
465 | baseline_cache = {}
466 | baseline_i = np.concatenate([ia,ib])
467 | baseline_l = np.concatenate([la,lb])
468 | baseline_pred = nn.inner_forward(baseline_i, baseline_weights, cache=baseline_cache)
469 | baseline_losses.append((baseline_pred - baseline_l)**2)
470 | dout_b = 2*(baseline_pred - baseline_l)
471 | nn.inner_backward(dout_b, baseline_weights, baseline_cache, baseline_grads)
472 |
473 | optimizer.apply_gradients(weights, grads, learning_rate=FLAGS.meta_lr)
474 | if FLAGS.use_baseline:
475 | baseline_optimizer.apply_gradients(baseline_weights, baseline_grads, learning_rate=FLAGS.meta_lr)
476 | if itr % 100 == 0:
477 | if FLAGS.use_baseline:
478 | print("[itr: {}] MAML loss = {} Baseline loss = {}".format(itr, np.sum(losses), np.sum(baseline_losses)))
479 | else:
480 | print("[itr: {}] Loss = {}".format(itr, np.sum(losses)))
481 | except KeyboardInterrupt:
482 | pass
483 | save_weights(weights, FLAGS.weight_path)
484 | if FLAGS.use_baseline:
485 | save_weights(baseline_weights, "baseline_"+FLAGS.weight_path)
486 |
487 |
488 | if __name__ == '__main__':
489 | parser = argparse.ArgumentParser(description='MAML')
490 | parser.add_argument('--seed', type=int, default=2, help='')
491 | parser.add_argument('--gradcheck', type=int, default=0, help='Run gradient check and other tests')
492 | parser.add_argument('--test', type=int, default=0, help='Run test on trained network to see if it works')
493 | parser.add_argument('--meta_lr', type=float, default=1e-3, help='Meta learning rate')
494 | parser.add_argument('--inner_lr', type=float, default=1e-2, help='Inner learning rate')
495 | parser.add_argument('--inner_bs', type=int, default=5, help='Inner batch size')
496 | parser.add_argument('--weight_path', type=str, default='trained_maml_weights.pkl', help='File name to save and load weights')
497 | parser.add_argument('--use_baseline', type=int, default=1, help='Whether to train a baseline network')
498 | parser.add_argument('--num_iter', type=float, default=1e4, help='Number of iterations')
499 | FLAGS = parser.parse_args()
500 | np.random.seed(FLAGS.seed)
501 |
502 | if FLAGS.gradcheck:
503 | gradcheck()
504 | elif FLAGS.test:
505 | test()
506 | else:
507 | train()
508 |
--------------------------------------------------------------------------------
/maml_1hidden.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import pickle
3 | import copy
4 | import random
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import argparse
8 | from collections import defaultdict
9 |
10 | from utils.optim import AdamOptimizer
11 | from utils.gradient_check import eval_numerical_gradient, eval_numerical_gradient_array, rel_error
12 | from utils.data_generator import SinusoidGenerator
13 |
14 |
15 | # this will create a special dictionary that returns 0 if the element is not set, instead of error
16 | # (it makes the code for updating gradients simpler)
17 | GradDict = lambda: defaultdict(lambda: 0)
18 |
19 | normalize = lambda x: (x - x.mean()) / (x.std() + 1e-8)
20 |
21 | def build_weights(hidden_dim=200):
22 | """Return weights to be used in forward pass"""
23 | # Initialize all weights (model params) with "Xavier Initialization"
24 | # weight matrix init = uniform(-1, 1) / sqrt(layer_input)
25 | # bias init = zeros()
26 | H = hidden_dim
27 | d = {}
28 | d['W1'] = (-1 + 2*np.random.rand(1, H)) / np.sqrt(1)
29 | d['b1'] = np.zeros(H)
30 | d['W2'] = (-1 + 2*np.random.rand(H, 1)) / np.sqrt(H)
31 | d['b2'] = np.zeros(1)
32 |
33 | # Cast all parameters to the correct datatype
34 | for k, v in d.items():
35 | d[k] = v.astype(np.float32)
36 | return d
37 |
38 | def save_weights(weights, filename, quiet=False):
39 | with open(filename, 'wb') as f:
40 | pickle.dump(weights, f)
41 | if not quiet:
42 | print('weights saved to {}'.format(filename))
43 |
44 | def load_weights(filename, quiet=False):
45 | with open(filename, 'rb') as f:
46 | weights = pickle.load(f)
47 | if not quiet:
48 | print('weights loaded from {}'.format(filename))
49 | return weights
50 |
51 |
52 | class Network(object):
53 | """BYOW: Bring Your Own Weights
54 |
55 | Hard-code operations for a 2 layer neural network
56 | """
57 | def __init__(self, alpha=0.01, normalized=normalize):
58 | self.ALPHA = alpha
59 | self.normalized = normalized
60 |
61 | def inner_forward(self, x_a, w):
62 | """submodule for forward pass"""
63 | W1, b1, W2, b2 = w['W1'], w['b1'], w['W2'], w['b2']
64 |
65 | affine1_a = x_a.dot(W1) + b1
66 | relu1_a = np.maximum(0, affine1_a)
67 | pred_a = relu1_a.dot(W2) + b2
68 |
69 | cache = dict(x_a=x_a, affine1_a=affine1_a, relu1_a=relu1_a)
70 | return pred_a, cache
71 |
72 | def inner_backward(self, dout_a, weights, cache):
73 | """just for fine-tuning at the end"""
74 | w = weights; c = cache
75 | W1, b1, W2, b2 = w['W1'], w['b1'], w['W2'], w['b2']
76 |
77 | drelu1_a = dout_a.dot(W2.T)
78 | dW2 = cache['relu1_a'].T.dot(dout_a)
79 | db2 = np.sum(dout_a, axis=0)
80 |
81 | daffine1_a = np.where(cache['affine1_a'] > 0, drelu1_a, 0)
82 |
83 | dW1 = c['x_a'].T.dot(daffine1_a)
84 | db1 = np.sum(daffine1_a, axis=0)
85 |
86 | # grad steps
87 | new_weights = {}
88 | new_weights['W1'] = W1 - self.ALPHA*self.normalized(dW1)
89 | new_weights['b1'] = b1 - self.ALPHA*self.normalized(db1)
90 | new_weights['W2'] = W2 - self.ALPHA*self.normalized(dW2)
91 | new_weights['b2'] = b2 - self.ALPHA*self.normalized(db2)
92 | return new_weights
93 |
94 |
95 | def meta_forward(self, x_a, x_b, label_a, weights, cache=None):
96 | w = weights
97 | W1, b1, W2, b2 = w['W1'], w['b1'], w['W2'], w['b2']
98 |
99 | # standard forward and backward computations
100 | # (a)
101 | pred_a, inner_cache = self.inner_forward(x_a, w)
102 |
103 | dout_a = 2*(pred_a - label_a)
104 |
105 | drelu1_a = dout_a.dot(W2.T)
106 | dW2 = inner_cache['relu1_a'].T.dot(dout_a)
107 | db2 = np.sum(dout_a, axis=0)
108 |
109 | daffine1_a = np.where(inner_cache['affine1_a'] > 0, drelu1_a, 0)
110 |
111 | dW1 = x_a.T.dot(daffine1_a)
112 | db1 = np.sum(daffine1_a, axis=0)
113 |
114 | # Forward on fast weights
115 | # (b)
116 |
117 | # grad steps
118 | W1_prime = W1 - self.ALPHA*dW1
119 | b1_prime = b1 - self.ALPHA*db1
120 | W2_prime = W2 - self.ALPHA*dW2
121 | b2_prime = b2 - self.ALPHA*db2
122 |
123 | affine1_b = x_b.dot(W1_prime) + b1_prime
124 | relu1_b = np.maximum(0, affine1_b)
125 | pred_b = relu1_b.dot(W2_prime) + b2_prime
126 |
127 | if cache:
128 | outer_cache = dict(dout_a=dout_a, x_b=x_b, affine1_b=affine1_b, relu1_b=relu1_b, W2_prime=W2_prime)
129 | return pred_b, {**inner_cache, **outer_cache}
130 | else:
131 | return pred_b
132 |
133 | def meta_backward(self, dout_b, weights, cache, grads=None):
134 | c = cache; w = weights # short
135 | W1, b1, W2, b2 = w['W1'], w['b1'], w['W2'], w['b2']
136 |
137 | # deriv w.r.t b (lower half)
138 |         # d 2nd layer
139 | dW2_prime = c['relu1_b'].T.dot(dout_b)
140 | db2_prime = np.sum(dout_b, axis=0)
141 | drelu1_b = dout_b.dot(c['W2_prime'].T)
142 |
143 | daffine1_b = np.where(c['affine1_b'] > 0, drelu1_b, 0)
144 |         # d 1st layer
145 | dW1_prime = c['x_b'].T.dot(daffine1_b)
146 | db1_prime = np.sum(daffine1_b, axis=0)
147 |
148 | # deriv w.r.t a (upper half)
149 |
150 | # going back through the gradient descent step
151 | dW1 = dW1_prime
152 | db1 = db1_prime
153 | dW2 = dW2_prime
154 | db2 = db2_prime
155 |
156 | ddW1 = dW1_prime * -self.ALPHA
157 | ddb1 = db1_prime * -self.ALPHA
158 | ddW2 = dW2_prime * -self.ALPHA
159 | ddb2 = db2_prime * -self.ALPHA
160 |
161 | # backpropping through the first backprop
162 | ddout_a = c['relu1_a'].dot(ddW2)
163 | ddout_a += ddb2
164 | drelu1_a = c['dout_a'].dot(ddW2.T) # shortcut back because of the grad dependency
165 |
166 | ddaffine1_a = c['x_a'].dot(ddW1)
167 | ddaffine1_a += ddb1
168 | ddrelu1_a = np.where(c['affine1_a'] > 0, ddaffine1_a, 0)
169 |
170 | dW2 += ddrelu1_a.T.dot(c['dout_a'])
171 |
172 | ddout_a += ddrelu1_a.dot(W2)
173 |
174 | dpred_a = ddout_a * 2 # = dout_a
175 |
176 | dW2 += c['relu1_a'].T.dot(dpred_a)
177 | db2 += np.sum(dpred_a, axis=0)
178 |
179 | drelu1_a += dpred_a.dot(W2.T)
180 |
181 | daffine1_a = np.where(c['affine1_a'] > 0, drelu1_a, 0)
182 |
183 | dW1 += c['x_a'].T.dot(daffine1_a)
184 | db1 += np.sum(daffine1_a, axis=0)
185 |
186 | if grads is not None:
187 | # update gradients
188 | grads['W1'] += self.normalized(dW1)
189 | grads['b1'] += self.normalized(db1)
190 | grads['W2'] += self.normalized(dW2)
191 | grads['b2'] += self.normalized(db2)
192 |
193 |
194 | def gradcheck():
195 | # Test the network gradient
196 | nn = Network(normalized=lambda x: x)
197 | grads = GradDict()
198 |
199 | np.random.seed(231)
200 | x_a = np.random.randn(15, 1)
201 | x_b = np.random.randn(15, 1)
202 | label = np.random.randn(15, 1)
203 | W1 = np.random.randn(1, 40)
204 | b1 = np.random.randn(40)
205 | W2 = np.random.randn(40, 1)
206 | b2 = np.random.randn(1)
207 |
208 | dout = np.random.randn(15, 1)
209 |
210 | weights = w = {}
211 | w['W1'] = W1
212 | w['b1'] = b1
213 | w['W2'] = W2
214 | w['b2'] = b2
215 |
216 | def rep_param(weights, name, val):
217 | clean_params = copy.deepcopy(weights)
218 | clean_params[name] = val
219 | return clean_params
220 |
221 | dW1_num = eval_numerical_gradient_array(lambda w: nn.meta_forward(x_a, x_b, label, rep_param(weights, 'W1', w)), W1, dout)
222 | db1_num = eval_numerical_gradient_array(lambda b: nn.meta_forward(x_a, x_b, label, rep_param(weights, 'b1', b)), b1, dout)
223 | dW2_num = eval_numerical_gradient_array(lambda w: nn.meta_forward(x_a, x_b, label, rep_param(weights, 'W2', w)), W2, dout)
224 | db2_num = eval_numerical_gradient_array(lambda b: nn.meta_forward(x_a, x_b, label, rep_param(weights, 'b2', b)), b2, dout)
225 |
226 | out, cache = nn.meta_forward(x_a, x_b, label, weights, cache=True)
227 | nn.meta_backward(dout, weights, cache, grads)
228 |
229 | # The error should be around 1e-10
230 | print()
231 | print('Testing meta-learning NN backward function:')
232 | print('dW1 error: ', rel_error(dW1_num, grads['W1']))
233 | print('db1 error: ', rel_error(db1_num, grads['b1']))
234 | print('dW2 error: ', rel_error(dW2_num, grads['W2']))
235 | print('db2 error: ', rel_error(db2_num, grads['b2']))
236 | print()
237 |
238 | def test():
239 |     """fine-tune with a few gradient steps on minibatches of size 5 and see how well it works
240 |
241 | basically what they show in Figure 2 of:
242 | https://arxiv.org/pdf/1703.03400.pdf
243 | """
244 | nn = Network()
245 | pre_weights = load_weights(FLAGS.weight_path)
246 | random_weights = build_weights()
247 |
248 | # values for fine-tuning step
249 | N = 10
250 | sin_gen = SinusoidGenerator(5*N, 1)
251 | x, y, amp, phase = map(lambda x: x[0], sin_gen.generate()) # grab all the first elems
252 | xs = np.split(x, N)
253 | ys = np.split(y, N)
254 |
255 | new_weights = pre_weights.copy()
256 | new_random_weights = random_weights.copy()
257 | for i in range(len(xs)):
258 | x = xs[i]
259 | y = ys[i]
260 | grads = GradDict()
261 | pred, cache = nn.inner_forward(x, new_weights)
262 | loss = (pred - y)**2
263 | dout = 2*(pred - y)
264 | new_weights = nn.inner_backward(dout, new_weights, cache)
265 |
266 | for i in range(len(xs)):
267 | x = xs[i]
268 | y = ys[i]
269 | grads = GradDict()
270 | pred, cache = nn.inner_forward(x, new_random_weights)
271 | loss = (pred - y)**2
272 | dout = 2*(pred - y)
273 | new_random_weights = nn.inner_backward(dout, new_random_weights, cache)
274 |
275 |
276 | sine_true = lambda x: amp*np.sin(x - phase)
277 | sine_nn = lambda x: nn.inner_forward(x, new_weights)[0]
278 | sine_pre = lambda x: nn.inner_forward(x, pre_weights)[0]
279 | sine_random = lambda x: nn.inner_forward(x, random_weights)[0]
280 | sine_new_random = lambda x: nn.inner_forward(x, new_random_weights)[0]
281 |
282 | x_vals = np.linspace(-5, 5)
283 |
284 | y_true = np.apply_along_axis(sine_true, 0, x_vals)
285 | y_nn = np.array([sine_nn(np.array(x)) for x in x_vals]).squeeze()
286 | y_pre = np.array([sine_pre(np.array(x)) for x in x_vals]).squeeze()
287 | y_random = np.array([sine_random(np.array(x)) for x in x_vals]).squeeze()
288 | y_new_random = np.array([sine_new_random(np.array(x)) for x in x_vals]).squeeze()
289 |
290 | plt.plot(x_vals, y_true, 'k', label='{:.2f}sin(x - {:.2f})'.format(amp, phase))
291 | plt.plot(x_vals, y_pre, 'r--', label='pre-update')
292 | plt.plot(x_vals, y_nn, 'r-', label='post-update')
293 | plt.plot(x_vals, y_random, 'g--', label='random')
294 | plt.plot(x_vals, y_new_random, 'g-', label='new_random')
295 | plt.legend()
296 | plt.show()
297 |
298 |
299 | def main():
300 | nn = Network()
301 | weights = build_weights()
302 | optimizer = AdamOptimizer(weights, learning_rate=FLAGS.learning_rate)
303 |
304 | sin_gen = SinusoidGenerator(10, 25) # update_batch * 2, meta batch size
305 |
306 |
307 | lr = lambda x: x * FLAGS.learning_rate
308 |
309 | nitr = 1e4
310 | for itr in range(int(nitr)):
311 | frac = 1.0 - (itr / nitr)
312 |
313 | # create a minibatch of size 25, with 10 points
314 | batch_x, batch_y, amp, phase = sin_gen.generate()
315 |
316 | inputa = batch_x[:, :5, :]
317 | labela = batch_y[:, :5, :]
318 | inputb = batch_x[:, 5:, :] # b used for testing
319 | labelb = batch_y[:, 5:, :]
320 |
321 | # META BATCH
322 | grads = GradDict() # zero grads
323 | losses = []
324 | for batch_i in range(len(inputa)):
325 | ia, la, ib, lb = inputa[batch_i], labela[batch_i], inputb[batch_i], labelb[batch_i]
326 | pred_b, cache = nn.meta_forward(ia, ib, la, weights, cache=True)
327 | losses.append((pred_b - lb)**2)
328 | dout_b = 2*(pred_b - lb)
329 | nn.meta_backward(dout_b, weights, cache, grads)
330 | optimizer.apply_gradients(weights, grads, learning_rate=lr(frac))
331 | if itr % 100 == 0:
332 | print("[itr: {}] Loss = {}".format(itr, np.sum(losses)))
333 |
334 | save_weights(weights, FLAGS.weight_path)
335 |
336 | if __name__ == '__main__':
337 | parser = argparse.ArgumentParser(description='MAML')
338 | parser.add_argument('--gradcheck', type=int, default=0, help='Run gradient check and other tests')
339 | parser.add_argument('--test', type=int, default=0, help='Run test on trained network to see if it works')
340 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate')
341 | parser.add_argument('--weight_path', type=str, default='trained_maml_weights.pkl', help='File name to save and load weights')
342 | FLAGS = parser.parse_args()
343 |
344 | if FLAGS.gradcheck:
345 | gradcheck()
346 | exit(0)
347 |
348 | if FLAGS.test:
349 | test()
350 | exit(0)
351 |
352 | main()
353 |
354 |
355 |
--------------------------------------------------------------------------------
/utils/data_generator.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Taken (and modified/hacked) from Chelsea Finn's MAML implementation
4 | https://github.com/cbfinn/maml
5 |
6 | Code for loading data.
7 |
8 | """
9 | import numpy as np
10 | import random
11 |
12 | class SinusoidGenerator(object):
13 | """
14 |     SinusoidGenerator is capable of generating batches of sinusoid data.
15 | A "class" is considered a particular sinusoid function.
16 | """
17 | def __init__(self, num_samples_per_class, batch_size, config={}):
18 | """
19 | Args:
20 | num_samples_per_class: num samples to generate per class in one batch
21 |             batch_size: meta batch size (e.g. number of sinusoid functions per batch)
22 | """
23 | self.batch_size = batch_size
24 | self.num_samples_per_class = num_samples_per_class
25 | self.num_classes = 1 # by default 1 (only relevant for classification problems)
26 |
27 | self.generate = self.generate_sinusoid_batch
28 | self.amp_range = config.get('amp_range', [0.1, 5.0])
29 | self.phase_range = config.get('phase_range', [0, np.pi])
30 | self.input_range = config.get('input_range', [-5.0, 5.0])
31 | self.dim_input = 1
32 | self.dim_output = 1
33 |
34 | def generate_sinusoid_batch(self, train=True, input_idx=None):
35 |         # Note: the train arg is not used here (but it is used for the omniglot method).
36 |         # input_idx is used during qualitative testing -- it is the number of examples used for the grad update
37 | amp = np.random.uniform(self.amp_range[0], self.amp_range[1], [self.batch_size])
38 | phase = np.random.uniform(self.phase_range[0], self.phase_range[1], [self.batch_size])
39 | outputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_output])
40 | init_inputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_input])
41 | for func in range(self.batch_size):
42 | init_inputs[func] = np.random.uniform(self.input_range[0], self.input_range[1], [self.num_samples_per_class, 1])
43 | if input_idx is not None:
44 | init_inputs[:,input_idx:,0] = np.linspace(self.input_range[0], self.input_range[1], num=self.num_samples_per_class-input_idx, retstep=False)
45 | outputs[func] = amp[func] * np.sin(init_inputs[func]-phase[func])
46 | return init_inputs, outputs, amp, phase
47 |
--------------------------------------------------------------------------------
/utils/gradient_check.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from builtins import range
3 | from past.builtins import xrange
4 |
5 | import numpy as np
6 | from random import randrange
7 |
8 |
9 | """
10 | THIS IS TAKEN FROM STANFORD'S CS231N COURSE, WHICH I HIGHLY RECOMMEND
11 | http://cs231n.github.io/
12 |
13 | It does numerical gradient checking
14 | """
15 |
16 | def eval_numerical_gradient(f, x, verbose=True, h=0.00001):
17 | """
18 | a naive implementation of numerical gradient of f at x
19 | - f should be a function that takes a single argument
20 | - x is the point (numpy array) to evaluate the gradient at
21 | """
22 |
23 | fx = f(x) # evaluate function value at original point
24 | grad = np.zeros_like(x)
25 | # iterate over all indexes in x
26 | it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
27 | while not it.finished:
28 |
29 | # evaluate function at x+h
30 | ix = it.multi_index
31 | oldval = x[ix]
32 | x[ix] = oldval + h # increment by h
33 |         fxph = f(x) # evaluate f(x + h)
34 | x[ix] = oldval - h
35 | fxmh = f(x) # evaluate f(x - h)
36 | x[ix] = oldval # restore
37 |
38 | # compute the partial derivative with centered formula
39 | grad[ix] = (fxph - fxmh) / (2 * h) # the slope
40 | if verbose:
41 | print(ix, grad[ix])
42 | it.iternext() # step to next dimension
43 |
44 | return grad
45 |
46 |
47 | def eval_numerical_gradient_array(f, x, df, h=1e-5):
48 | """
49 | Evaluate a numeric gradient for a function that accepts a numpy
50 | array and returns a numpy array.
51 | """
52 | grad = np.zeros_like(x)
53 | it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
54 | while not it.finished:
55 | ix = it.multi_index
56 |
57 | oldval = x[ix]
58 | x[ix] = oldval + h
59 | pos = f(x).copy()
60 | x[ix] = oldval - h
61 | neg = f(x).copy()
62 | x[ix] = oldval
63 |
64 | grad[ix] = np.sum((pos - neg) * df) / (2 * h)
65 | it.iternext()
66 | return grad
67 |
68 |
69 | def eval_numerical_gradient_blobs(f, inputs, output, h=1e-5):
70 | """
71 | Compute numeric gradients for a function that operates on input
72 | and output blobs.
73 |
74 | We assume that f accepts several input blobs as arguments, followed by a
75 | blob where outputs will be written. For example, f might be called like:
76 |
77 | f(x, w, out)
78 |
79 | where x and w are input Blobs, and the result of f will be written to out.
80 |
81 | Inputs:
82 | - f: function
83 | - inputs: tuple of input blobs
84 | - output: output blob
85 | - h: step size
86 | """
87 | numeric_diffs = []
88 | for input_blob in inputs:
89 | diff = np.zeros_like(input_blob.diffs)
90 | it = np.nditer(input_blob.vals, flags=['multi_index'],
91 | op_flags=['readwrite'])
92 | while not it.finished:
93 | idx = it.multi_index
94 | orig = input_blob.vals[idx]
95 |
96 | input_blob.vals[idx] = orig + h
97 | f(*(inputs + (output,)))
98 | pos = np.copy(output.vals)
99 | input_blob.vals[idx] = orig - h
100 | f(*(inputs + (output,)))
101 | neg = np.copy(output.vals)
102 | input_blob.vals[idx] = orig
103 |
104 | diff[idx] = np.sum((pos - neg) * output.diffs) / (2.0 * h)
105 |
106 | it.iternext()
107 | numeric_diffs.append(diff)
108 | return numeric_diffs
109 |
110 |
111 | def eval_numerical_gradient_net(net, inputs, output, h=1e-5):
112 | return eval_numerical_gradient_blobs(lambda *args: net.forward(),
113 | inputs, output, h=h)
114 |
115 |
116 | def grad_check_sparse(f, x, analytic_grad, num_checks=10, h=1e-5):
117 | """
118 |     sample a few random elements and only return the numerical gradient
119 |     in these dimensions.
120 | """
121 |
122 | for i in range(num_checks):
123 | ix = tuple([randrange(m) for m in x.shape])
124 |
125 | oldval = x[ix]
126 | x[ix] = oldval + h # increment by h
127 | fxph = f(x) # evaluate f(x + h)
128 |         x[ix] = oldval - h # decrement by h
129 | fxmh = f(x) # evaluate f(x - h)
130 | x[ix] = oldval # reset
131 |
132 | grad_numerical = (fxph - fxmh) / (2 * h)
133 | grad_analytic = analytic_grad[ix]
134 | rel_error = (abs(grad_numerical - grad_analytic) /
135 | (abs(grad_numerical) + abs(grad_analytic)))
136 | print('numerical: %f analytic: %f, relative error: %e'
137 | %(grad_numerical, grad_analytic, rel_error))
138 |
139 |
140 | def rel_error(x, y):
141 | """ returns relative error """
142 | return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))
143 |
--------------------------------------------------------------------------------
/utils/optim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | """
3 | THIS IS TAKEN FROM STANFORD'S CS231N COURSE, WHICH I HIGHLY RECOMMEND
4 | http://cs231n.github.io/
5 |
6 | This file implements various first-order update rules that are commonly used for
7 | training neural networks. Each update rule accepts current weights and the
8 | gradient of the loss with respect to those weights and produces the next set of
9 | weights. Each update rule has the same interface:
10 |
11 | def update(w, dw, config=None):
12 |
13 | Inputs:
14 | - w: A numpy array giving the current weights.
15 | - dw: A numpy array of the same shape as w giving the gradient of the
16 | loss with respect to w.
17 | - config: A dictionary containing hyperparameter values such as learning rate,
18 | momentum, etc. If the update rule requires caching values over many
19 | iterations, then config will also hold these cached values.
20 |
21 | Returns:
22 | - next_w: The next point after the update.
23 | - config: The config dictionary to be passed to the next iteration of the
24 | update rule.
25 |
26 | NOTE: For most update rules, the default learning rate will probably not perform
27 | well; however the default values of the other hyperparameters should work well
28 | for a variety of different problems.
29 |
30 | For efficiency, update rules may perform in-place updates, mutating w and
31 | setting next_w equal to w.
32 | """
33 |
34 |
35 |
36 |
37 |
38 | def sgd(w, dw, config=None):
39 | """
40 | Performs vanilla stochastic gradient descent.
41 |
42 | config format:
43 | - learning_rate: Scalar learning rate.
44 | """
45 | if config is None: config = {}
46 | config.setdefault('learning_rate', 1e-2)
47 |
48 | w -= config['learning_rate'] * dw
49 | return w, config
50 |
51 |
52 | def adam(x, dx, config=None):
53 | """
54 | Uses the Adam update rule, which incorporates moving averages of both the
55 | gradient and its square and a bias correction term.
56 |
57 | config format:
58 | - learning_rate: Scalar learning rate.
59 | - beta1: Decay rate for moving average of first moment of gradient.
60 | - beta2: Decay rate for moving average of second moment of gradient.
61 | - epsilon: Small scalar used for smoothing to avoid dividing by zero.
62 | - m: Moving average of gradient.
63 | - v: Moving average of squared gradient.
64 | - t: Iteration number.
65 | """
66 | if config is None: config = {}
67 | config.setdefault('learning_rate', 1e-3)
68 | config.setdefault('beta1', 0.9)
69 | config.setdefault('beta2', 0.999)
70 | config.setdefault('epsilon', 1e-8)
71 | config.setdefault('m', np.zeros_like(x))
72 | config.setdefault('v', np.zeros_like(x))
73 | config.setdefault('t', 0)
74 |
75 | #print(config['learning_rate'])
76 |
77 | next_x = None
78 | beta1, beta2, eps = config['beta1'], config['beta2'], config['epsilon']
79 | t, m, v = config['t'], config['m'], config['v']
80 | m = beta1 * m + (1 - beta1) * dx
81 | v = beta2 * v + (1 - beta2) * (dx * dx)
82 | t += 1
83 | alpha = config['learning_rate'] * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
84 | x -= alpha * (m / (np.sqrt(v) + eps))
85 | config['t'] = t
86 | config['m'] = m
87 | config['v'] = v
88 | next_x = x
89 |
90 | return next_x, config
91 |
92 |
93 | class AdamOptimizer():
94 | def __init__(self, params, learning_rate=1e-3):
95 | # Configuration for Adam optimization
96 | self.optimization_config = {'learning_rate': learning_rate}
97 | self.adam_configs = {}
98 | for p in params:
99 | d = {k: v for k, v in self.optimization_config.items()}
100 | self.adam_configs[p] = d
101 |
102 | def apply_gradients(self, params, grads, learning_rate=None):
103 | for p in params:
104 | if learning_rate is not None:
105 | self.adam_configs[p]['learning_rate'] = learning_rate
106 | next_w, self.adam_configs[p] = adam(params[p], grads[p], config=self.adam_configs[p])
107 | params[p] = next_w
108 |
109 |
110 |
111 |
112 |
--------------------------------------------------------------------------------