├── .gitignore ├── LICENSE ├── README.md ├── assets ├── baseline.png ├── derivation.png ├── derivation.svg ├── derivation2.png ├── derivation2.svg ├── eq1.png ├── eq2.png ├── maml.png └── random.png ├── maml.py ├── maml_1hidden.py └── utils ├── data_generator.py ├── gradient_check.py └── optim.py /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | # vim 3 | *.swp 4 | *.swo 5 | 6 | **.pkl 7 | *.pkl 8 | 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Matthew Wilson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAML in raw numpy 2 | 3 | This is an implementation of vanilla Model-Agnostic Meta-Learning ([MAML](https://github.com/cbfinn/maml)) 4 | in raw numpy. I made this to better understand the algorithm and what it is doing. I derived 5 | the forward and backward passes following conventions from [CS231n](http://cs231n.github.io/). 6 | This code is just a rough sketch to understand the algorithm better, so it works, but 7 | it is not optimized or well parameterized. 8 | It turned out to be pretty interesting to see the algorithm 9 | logic without the backprop abstracted away by an autograd package like TensorFlow. 10 | 11 | **Table of contents** 12 | - [Results](#results) 13 | - [What is MAML?](#what-is-maml) 14 | - [Derivation](#derivation) 15 | 16 | 17 | 18 | 19 | ## Results 20 | 21 | To verify my implementation, I test on the 1D sinusoid regression problem 22 | from [Section 5.1](https://arxiv.org/pdf/1703.03400.pdf) of the MAML paper (see 23 | also the description of the problem in [Section 4](https://arxiv.org/pdf/1803.02999.pdf) of the Reptile paper). 24 | 25 | I train for 10k iterations on a [dataset](utils/data_generator.py) of sine 26 | function inputs/outputs with randomly sampled amplitude and phase, and then 27 | fine-tune on 10 samples from a fixed amplitude and phase. 28 | After fine-tuning, I predict the value of the fixed sine function 29 | for 50 evenly spaced x values in (-5, 5), and plot the results 30 | compared to the ground truth for pre-trained MAML, a pre-trained baseline 31 | (joint training), and a randomly initialized network. I find that 32 | MAML is able to fit the sinusoid much more effectively. 33 | 34 | MAML | Baseline (joint training) | Random init 35 | :-------------------------:|:-------------------------:|:----------:| 36 | ![](/assets/maml.png) | ![](/assets/baseline.png) | ![](/assets/random.png) 37 | 38 | Here are the commands to run the code: 39 | 40 | - Train for 10k iterations and then save the weights to a file:<br>
41 | ``` 42 | python3 maml.py # train both MAML and baseline (joint trained) weights 43 | ``` 44 | - After training, fine-tune the network and plot the results on the sine task:<br>
45 | ``` 46 | python3 maml.py --test 1 47 | ``` 48 | - Run a gradient check on the implementation: 49 | ``` 50 | python3 maml.py --gradcheck 1 51 | ``` 52 | 53 | 54 | ### Notes 55 | These results come from using a neural network with 2 hidden layers. I 56 | originally tried using 1 hidden layer (see [`maml_1hidden.py`](maml_1hidden.py)) 57 | because it was easier to derive, but I 58 | found that it did not have enough representational 59 | capacity to solve the sinusoid problem (see [Meta-Learning And Universality](https://arxiv.org/pdf/1710.11622.pdf) for more details on the representational capacity of MAML). 60 | 61 | <br>
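For reference, here is a minimal numpy sketch of the forward pass of that 2-hidden-layer network. It mirrors `inner_forward` in [`maml.py`](maml.py); the 64-unit hidden sizes are the `build_weights` defaults, and the random weight values below are placeholders just to make the snippet runnable on its own.

```python
import numpy as np

def forward(x, w):
    # two ReLU hidden layers followed by a linear output, as in maml.py's inner_forward
    h1 = np.maximum(0, x.dot(w['W1']) + w['b1'])
    h2 = np.maximum(0, h1.dot(w['W2']) + w['b2'])
    return h2.dot(w['W3']) + w['b3']

# placeholder weights with the same shapes that build_weights() produces
rng = np.random.RandomState(0)
w = {'W1': rng.randn(1, 64),  'b1': np.zeros(64),
     'W2': rng.randn(64, 64), 'b2': np.zeros(64),
     'W3': rng.randn(64, 1),  'b3': np.zeros(1)}

x = np.linspace(-5, 5, 10).reshape(-1, 1)  # a batch of 10 scalar inputs
print(forward(x, w).shape)                 # (10, 1): one prediction per input
```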
62 | 63 | ## What is MAML? 64 | 65 | ### Introduction 66 | 67 | Model-Agnostic Meta-Learning (MAML) is a gradient-based meta-learning algorithm. For an 68 | overview of meta-learning, see a blog post from the author [here](https://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/), and a good talk [here](https://youtu.be/i05Fk4ebMY0). 69 | Roughly, meta-learning tries to address the sample inefficiency of 70 | machine learning. It aims to let models learn 71 | quickly on new tasks by better incorporating information from previous tasks. 72 | 73 | Unlike [several](https://arxiv.org/abs/1611.02779) [other](https://openreview.net/forum?id=rJY0-Kcll) [meta-learning](https://arxiv.org/abs/1606.04474) [methods](https://arxiv.org/abs/1707.03141) which use RNNs, MAML only uses 74 | feed-forward networks and gradient descent. The interesting piece is how it 75 | sets up the gradient descent scheme to optimize the network for efficient 76 | fine-tuning on the meta-test set. 77 | In standard neural network training, we use gradient descent and backprop for 78 | training. MAML assumes that you will use this same approach to quickly 79 | fine-tune on your task, and it builds this assumption into the meta-training optimization. 80 | 81 | MAML breaks the meta-learning problem into two phases: a **meta-training phase** and a **fine-tuning phase**. The meta-training phase optimizes the network parameters so that the fine-tuning phase is more effective: the parameters become sensitive to gradients and can 82 | quickly adapt to solve newly sampled tasks from the distribution. The fine-tuning phase then just 83 | runs standard gradient descent using the weights that were produced in the meta-training phase, just like you would fine-tune a network for a task using, e.g., pre-trained 84 | ImageNet weights. This process looks somewhat similar to transfer learning, but 85 | it is more general and produces better results on meta-learning problems like one-shot learning (where you are given a single instance of a new 86 | object class, like an electric scooter, and your model must quickly adapt so that 87 | it can effectively distinguish new images of electric scooters from other objects). 88 | 89 | 90 | ### Meta-training 91 | During meta-training, MAML draws several samples from a **task** and splits them 92 | into **A** and **B** examples. For example, you could draw 10 (x,y) pairs from a sinusoid 93 | problem and split them into 5 A and 5 B examples. In this case, each task is 94 | defined by a fixed amplitude and phase of the sinusoid, but tasks can represent 95 | more interesting variations, like which objects a robot should interact 96 | with when [imitating a human demonstration](https://sites.google.com/view/daml). 97 | 98 | Once we have sampled the A and B examples from the task, we use the A 99 | examples for an **inner optimization** (standard gradient descent), 100 | and the B examples for an **outer optimization** (gradient descent back through 101 | the inner optimization). At a high level: we inner-optimize on the A 102 | examples, test the generalization performance on the B examples, and 103 | meta-optimize on that loss (using gradient descent through the whole 104 | computation) in order to place the parameters in a good initialization 105 | for quickly fine-tuning to many varied tasks. 
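Written as update equations, this is the standard MAML formulation from the paper with a single inner step (alpha is the inner learning rate, beta the meta learning rate; the A/B superscripts are just shorthand for which split of a task's examples each loss is evaluated on):

```latex
% inner (fine-tuning) step on the A examples of task T_i
\theta'_i = \theta - \alpha \nabla_\theta \mathcal{L}^{A}_{T_i}(f_\theta)

% outer (meta) step on the B examples, differentiating through the inner step
\theta \leftarrow \theta - \beta \nabla_\theta \sum_{T_i} \mathcal{L}^{B}_{T_i}(f_{\theta'_i})
```

The outer gradient is taken through the inner update, which is what makes the meta-gradient second order.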
106 | 107 | To see concretely how that is done, here are the algorithm logic and 108 | pseudocode, which closely match the [TensorFlow 109 | implementation](https://github.com/cbfinn/maml): 110 | 111 | ### MAML algorithm 112 | 113 | 114 | **Algorithm logic (do this for many loops)** 115 | 116 | 1. Sample task T from distribution of possible tasks 117 | 1. Sample examples from T and split into A and B examples 118 | 1. Network forward pass with weights W, using A examples 119 | 1. Backward pass to compute gradients dWa 120 | 1. Apply gradients dWa using SGD: W' <-- W - alpha\*dWa 121 | 1. Forward pass with temp weights W', using B examples this time 122 | 1. Backward pass through the whole thing to compute gradients dWb (NOTE: this gradient is with respect to the input weights W, not W'. This is a second-order derivative and backprops through the B forward, the gradient update step, the A backward, and the A forward computations. Everything in the [derivation diagrams](#derivation) below is just the meta-forward pass; this step backprops through that whole computation, starting at pred_b) 123 | 1. Apply gradients dWb using Adam: W <-- W - beta\*dWb (where beta is the meta learning rate, separate from the inner alpha) 124 | 125 | 126 | NOTE: You could also do batches of tasks at a time and sum the lossBs. 127 | 128 | **Pseudocode that roughly matches Finn's implementation of [MAML in TensorFlow](https://github.com/cbfinn/maml):** 129 | 130 | ```python 131 | weights = init_NN_weights() # neural network weights and biases 132 | 133 | task_data = sample_task() 134 | 135 | inputA, labelA, inputB, labelB = task_data.meta_split() 136 | 137 | # forward pass of network using weights and A examples 138 | netoutA = forward(inputA, weights) 139 | lossA = loss_func(netoutA, labelA) 140 | 141 | gradients = get_gradients(lossA) # w.r.t. weights 142 | 143 | fast_weights = weights + -learning_rate * gradients # gradient descent step on weights 144 | 145 | netoutB = forward(inputB, fast_weights) 146 | lossB = loss_func(netoutB, labelB) 147 | 148 | # then you would plug this lossB into an optimizer like Adam to optimize 149 | # w.r.t. the original weights. fast_weights are just temporary values showing 150 | # where a gradient step on the inner lossA moves the weights. 151 | # The only state retained between iterations of MAML is weights (not fast_weights). 152 | ``` 153 | 154 | ### Fine-tuning 155 | 156 | At the fine-tuning stage, you have a set of meta-trained weights. Given a 157 | new task, you just run the inner optimization, keep the resulting 158 | fast_weights, and use them to predict on new examples. 159 | 160 | **Pseudocode to illustrate how fine-tuning works and relates to training** 161 | ``` 162 | inputA, labelA = test_data() 163 | 164 | netoutA = forward(inputA, weights) 165 | lossA = loss_func(netoutA, labelA) 166 | 167 | gradients = get_gradients(lossA) # w.r.t. weights 168 | 169 | fast_weights = weights + -learning_rate * gradients # gradient descent step on weights 170 | 171 | 172 | prediction = forward(new_input_to_predict_label_for, fast_weights) 173 | ``` 174 | 175 | 176 | 177 | 178 | ## Derivation 179 | 180 | 181 | The diagrams below show the meta-forward pass for MAML with a single inner 182 | update step. By computing the gradients through this computational graph, 183 | I derived the computations required for the meta-backward pass. I show 184 | the computation for a single-hidden-layer neural network for simplicity, but 185 | in the code I use a two-hidden-layer neural network. 
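Since it is easy to make a mistake in a hand-derived backward pass like this, the repo checks the meta-gradients numerically (`python3 maml.py --gradcheck 1`, built on [`utils/gradient_check.py`](utils/gradient_check.py)). Below is a minimal, self-contained sketch of that idea on a toy least-squares layer; it is not the repo's actual check, just the central-difference recipe the check is built on.

```python
import numpy as np

def loss(W, x, y):
    # least-squares loss of a linear layer: sum((xW - y)^2)
    return np.sum((x.dot(W) - y) ** 2)

def analytic_grad(W, x, y):
    # hand-derived gradient of the loss w.r.t. W
    return 2 * x.T.dot(x.dot(W) - y)

def numerical_grad(f, W, h=1e-5):
    # central differences, one element of W at a time (same recipe as utils/gradient_check.py)
    grad = np.zeros_like(W)
    it = np.nditer(W, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        i = it.multi_index
        old = W[i]
        W[i] = old + h; f_plus = f(W)
        W[i] = old - h; f_minus = f(W)
        W[i] = old  # restore the original value
        grad[i] = (f_plus - f_minus) / (2 * h)
        it.iternext()
    return grad

rng = np.random.RandomState(0)
x, y, W = rng.randn(10, 3), rng.randn(10, 1), rng.randn(3, 1)
num = numerical_grad(lambda w: loss(w, x, y), W)
ana = analytic_grad(W, x, y)
# relative error should be tiny (around 1e-8 or smaller) if the derivation is right
print(np.max(np.abs(num - ana) / np.maximum(1e-8, np.abs(num) + np.abs(ana))))
```

The real `gradcheck()` in `maml.py` does the same thing for every weight matrix, comparing the numerical gradient of `meta_forward` against the analytic one from `meta_backward`.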
186 | 187 | NOTE: (dW2, db2, dW1, db1) are computed in the upper figure nd passed to the lower 188 | figure. Gradients are backpropagated from the output all the way back through 189 | both through to the upper figure. I use the approach from [CS231n](http://cs231n.github.io/). 190 | 191 | **Inner forward and backward:** 192 | ![derivation](/assets/derivation.png) 193 | 194 | **Inner gradient (SGD) update and second (outer) forward pass:** 195 | ![derivation2](/assets/derivation2.png) 196 | 197 | -------------------------------------------------------------------------------- /assets/baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/baseline.png -------------------------------------------------------------------------------- /assets/derivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/derivation.png -------------------------------------------------------------------------------- /assets/derivation.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 22 | 24 | 32 | 38 | 39 | 47 | 53 | 54 | 62 | 68 | 69 | 77 | 83 | 84 | 92 | 98 | 99 | 107 | 113 | 114 | 122 | 128 | 129 | 138 | 144 | 145 | 153 | 159 | 160 | 168 | 174 | 175 | 183 | 189 | 190 | 198 | 204 | 205 | 213 | 219 | 220 | 228 | 234 | 235 | 243 | 249 | 250 | 258 | 264 | 265 | 273 | 279 | 280 | 288 | 294 | 295 | 303 | 309 | 310 | 318 | 324 | 325 | 333 | 339 | 340 | 348 | 354 | 355 | 363 | 369 | 370 | 378 | 384 | 385 | 393 | 399 | 400 | 408 | 414 | 415 | 423 | 429 | 430 | 438 | 444 | 445 | 453 | 459 | 460 | 468 | 474 | 475 | 483 | 489 | 490 | 498 | 504 | 505 | 513 | 519 | 520 | 528 | 534 | 535 | 543 | 549 | 550 | 558 | 564 | 565 | 566 | 599 | 604 | 605 | 607 | 608 | 610 | image/svg+xml 611 | 613 | 614 | 615 | 616 | 617 | 622 | 626 | 632 | * 644 | 645 | 656 | W1 669 | xa 682 | 686 | 692 | * 704 | 705 | b1 718 | labela 731 | 735 | 741 | + 752 | 753 | W2 766 | 770 | 776 | relu 788 | 789 | b2 802 | 806 | 812 | + 823 | 824 | 828 | 834 | - 845 | 846 | 850 | 854 | 860 | *2 871 | 872 | 873 | 877 | 883 | * 895 | 896 | 900 | 906 | sum 918 | 919 | 923 | 929 | drelu1affine1_a > 0 950 | 951 | 955 | 961 | sum 973 | 974 | 982 | 990 | 998 | 1006 | 1014 | 1022 | 1030 | 1037 | 1044 | 1052 | 1060 | 1068 | 1076 | 1084 | 1091 | 1097 | 1103 | 1108 | 1113 | 1119 | 1123 | 1129 | * 1141 | 1142 | 1150 | 1158 | 1165 | 1172 | 1178 | 1185 | 1190 | 1195 | 1201 | 1208 | 1214 | 1220 | 1224 | 1230 | * 1242 | 1243 | 1248 | 1253 | 1258 | 1263 | 1268 | 1273 | 1279 | 1285 | 1291 | 1297 | dW2 1310 | db2 1323 | dW1 1336 | db1 1349 | affine1_a 1371 | relu1_a 1383 | affine2_a = pred_a 1395 | 1401 | 1407 | dout_a 1419 | 1425 | drelu1_a 1437 | 1443 | daffine1_a 1455 | 1466 | 1474 | 1480 | 1486 | 1492 | 1498 | 1504 | 1510 | 1511 | 1512 | -------------------------------------------------------------------------------- /assets/derivation2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/derivation2.png -------------------------------------------------------------------------------- /assets/derivation2.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 22 | 24 | 32 
| 38 | 39 | 47 | 53 | 54 | 62 | 68 | 69 | 77 | 83 | 84 | 92 | 98 | 99 | 107 | 113 | 114 | 122 | 128 | 129 | 137 | 143 | 144 | 152 | 158 | 159 | 167 | 173 | 174 | 182 | 188 | 189 | 197 | 203 | 204 | 212 | 218 | 219 | 227 | 233 | 234 | 242 | 248 | 249 | 257 | 263 | 264 | 272 | 278 | 279 | 287 | 293 | 294 | 302 | 308 | 309 | 317 | 323 | 324 | 332 | 338 | 339 | 347 | 353 | 354 | 362 | 368 | 369 | 377 | 383 | 384 | 392 | 398 | 399 | 407 | 413 | 414 | 422 | 428 | 429 | 437 | 443 | 444 | 452 | 458 | 459 | 467 | 473 | 474 | 482 | 488 | 489 | 497 | 503 | 504 | 512 | 518 | 519 | 527 | 533 | 534 | 542 | 548 | 549 | 550 | 586 | 591 | 592 | 594 | 595 | 597 | image/svg+xml 598 | 600 | 601 | 602 | 603 | 604 | 609 | 620 | 636 | 647 | 658 | 669 | b2 682 | 688 | db2 701 | 707 | 711 | 717 | - 728 | 729 | 733 | 737 | 743 | *2 754 | 755 | 756 | 760 | 766 | + 777 | 778 | W2 791 | 797 | dW2 810 | 816 | 820 | 826 | - 837 | 838 | 842 | 846 | 852 | *2 863 | 864 | 865 | 869 | 875 | + 886 | 887 | 891 | 897 | + 908 | 909 | xb 922 | 928 | 932 | 938 | - 949 | 950 | W1 963 | 969 | dW1 982 | 988 | 992 | 996 | 1002 | *2 1013 | 1014 | 1015 | b1 1028 | 1034 | db1 1047 | 1053 | 1057 | 1063 | - 1074 | 1075 | 1079 | 1083 | 1089 | *2 1100 | 1101 | 1102 | 1106 | 1112 | + 1123 | 1124 | 1128 | 1134 | * 1146 | 1147 | 1151 | 1157 | + 1168 | 1169 | 1173 | 1179 | relu 1191 | 1192 | 1196 | 1202 | * 1214 | 1215 | 1219 | 1225 | + 1236 | 1237 | 1245 | 1253 | 1261 | 1269 | 1277 | 1285 | 1293 | 1301 | 1309 | 1316 | 1324 | 1332 | 1340 | 1347 | 1355 | 1363 | 1371 | 1379 | 1387 | 1395 | 1403 | 1411 | 1419 | 1427 | 1436 | 1444 | 1449 | 1455 | 1461 | 1467 | 1473 | 1479 | 1485 | W1' 1501 | affine1_b 1513 | dout_a 1525 | pred_b 1537 | dout_b 1549 | 1554 | b1' 1568 | W2' 1584 | b2' 1598 | 1599 | 1600 | -------------------------------------------------------------------------------- /assets/eq1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/eq1.png -------------------------------------------------------------------------------- /assets/eq2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/eq2.png -------------------------------------------------------------------------------- /assets/maml.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/maml.png -------------------------------------------------------------------------------- /assets/random.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matwilso/maml_numpy/a78d2d17d138c1bd8fe562660c93117b4851548c/assets/random.png -------------------------------------------------------------------------------- /maml.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import pickle 4 | import copy 5 | import random 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import matplotlib as mpl; 9 | mpl.rcParams["savefig.directory"] = '~/Desktop'#$os.chdir(os.path.dirname(__file__)) 10 | import argparse 11 | from collections import defaultdict 12 | 13 | from utils.optim import AdamOptimizer 14 | from utils.gradient_check import eval_numerical_gradient, 
eval_numerical_gradient_array, rel_error 15 | from utils.data_generator import SinusoidGenerator 16 | 17 | """ 18 | This file contains logic for training a fully-connected neural network with 19 | 2 hidden layers using the Model-Agnostic Meta-Learning (MAML) algorithm. 20 | 21 | It is designed to solve the toy sinusoid meta-learning problem presented in the MAML paper, 22 | and uses the same architecture as presented in the paper. 23 | 24 | Passing the `--gradcheck=1` flag, will run finite differences gradient check 25 | on the meta forward and backward to ensure correct implementation. 26 | 27 | After training a network, you can pass the `--test=1` flag to compare against 28 | a joint-trained and random network baseline. 29 | """ 30 | 31 | 32 | # special dictionary to return 0 if element does not exist (makes gradient code simpler) 33 | GradDict = lambda: defaultdict(lambda: 0) 34 | normalize = lambda x: (x - x.mean()) / (x.std() + 1e-8) 35 | 36 | # weight util functions 37 | def build_weights(hidden_dims=(64, 64)): 38 | """Return dictionary on neural network weights""" 39 | # Initialize all weights (model params) with "Xavier Initialization" 40 | # weight matrix init = uniform(-1, 1) / sqrt(layer_input) 41 | # bias init = zeros() 42 | H1, H2 = hidden_dims 43 | w = {} 44 | w['W1'] = (-1 + 2*np.random.rand(1, H1)) / np.sqrt(1) 45 | w['b1'] = np.zeros(H1) 46 | w['W2'] = (-1 + 2*np.random.rand(H1, H2)) / np.sqrt(H1) 47 | w['b2'] = np.zeros(H2) 48 | w['W3'] = (-1 + 2*np.random.rand(H2, 1)) / np.sqrt(H2) 49 | w['b3'] = np.zeros(1) 50 | 51 | # Cast all parameters to the correct datatype 52 | for k, v in w.items(): 53 | w[k] = v.astype(np.float32) 54 | return w 55 | 56 | def save_weights(weights, filename, quiet=False): 57 | with open(filename, 'wb') as f: 58 | pickle.dump(weights, f) 59 | if not quiet: 60 | print('weights saved to {}'.format(filename)) 61 | 62 | def load_weights(filename, quiet=False): 63 | with open(filename, 'rb') as f: 64 | weights = pickle.load(f) 65 | if not quiet: 66 | print('weights loaded from {}'.format(filename)) 67 | return weights 68 | 69 | class Network(object): 70 | """ 71 | Forward and backward pass logic for 3 layer neural network 72 | (see https://github.com/matwilso/maml_numpy#derivation for derivation) 73 | """ 74 | 75 | def __init__(self, inner_lr=0.01, normalize=normalize): 76 | self.inner_lr = inner_lr # alpha in the paper 77 | self.normalize = normalize # function to normalize gradients before applying them to weights (helps with stability) 78 | 79 | def inner_forward(self, x_a, weights, cache={}): 80 | """Submodule for meta_forward. This is just a standard forward pass for a neural net. 81 | 82 | Args: 83 | x_a (ndarray): Example or examples of sinusoid from given phase, amplitude. 
84 | weights (dict): Dictionary of weights and biases for neural net 85 | cache (dict): Pass in dictionary to be updated with values needed in meta_backward 86 | 87 | Returns: 88 | pred_a (ndarray): Predicted values for example(s) x_a 89 | """ 90 | w = weights 91 | W1, b1, W2, b2, W3, b3 = w['W1'], w['b1'], w['W2'], w['b2'], w['W3'], w['b3'] 92 | # layer 1 93 | affine1_a = x_a.dot(W1) + b1 94 | relu1_a = np.maximum(0, affine1_a) 95 | # layer 2 96 | affine2_a = relu1_a.dot(W2) + b2 97 | relu2_a = np.maximum(0, affine2_a) 98 | # layer 3 99 | pred_a = relu2_a.dot(W3) + b3 100 | 101 | cache.update(dict(x_a=x_a, affine1_a=affine1_a, relu1_a=relu1_a, affine2_a=affine2_a, relu2_a=relu2_a)) 102 | return pred_a 103 | 104 | def inner_backward(self, dout_a, weights, cache, grads=GradDict(), lr=None): 105 | """For fine-tuning network at meta-test time 106 | 107 | (Although this has some repeated code from meta_backward, it was hard to 108 | use as a subprocess for meta_backward. It required several changes in 109 | code and made things more confusing.) 110 | 111 | Args: 112 | dout_a (ndarray): Gradient of output (usually loss) 113 | weights (dict): Dictionary of weights and biases for neural net 114 | cache (dict): Dictionary of relevant values from forward pass 115 | 116 | Returns: 117 | dict: New dictionary, with updated weights 118 | """ 119 | w = weights; c = cache 120 | W1, b1, W2, b2, W3, b3 = w['W1'], w['b1'], w['W2'], w['b2'], w['W3'], w['b3'] 121 | lr = lr or self.inner_lr 122 | 123 | drelu2_a = dout_a.dot(W3.T) 124 | dW3 = c['relu2_a'].T.dot(dout_a) 125 | db3 = np.sum(dout_a, axis=0) 126 | 127 | daffine2_a = np.where(c['affine2_a'] > 0, drelu2_a, 0) 128 | 129 | drelu1_a = daffine2_a.dot(W2.T) 130 | dW2 = c['relu1_a'].T.dot(dout_a) 131 | db2 = np.sum(dout_a, axis=0) 132 | 133 | daffine1_a = np.where(c['affine1_a'] > 0, drelu1_a, 0) 134 | 135 | dW1 = c['x_a'].T.dot(daffine1_a) 136 | db1 = np.sum(daffine1_a, axis=0) 137 | 138 | grads['W1'] += dW1 139 | grads['b1'] += db1 140 | grads['W2'] += dW2 141 | grads['b2'] += db2 142 | grads['W3'] += dW3 143 | grads['b3'] += db3 144 | 145 | # Return new weights (for fine-tuning) 146 | new_weights = {} 147 | new_weights['W1'] = W1 - lr*self.normalize(dW1) 148 | new_weights['b1'] = b1 - lr*self.normalize(db1) 149 | new_weights['W2'] = W2 - lr*self.normalize(dW2) 150 | new_weights['b2'] = b2 - lr*self.normalize(db2) 151 | new_weights['W3'] = W3 - lr*self.normalize(dW3) 152 | new_weights['b3'] = b3 - lr*self.normalize(db3) 153 | return new_weights 154 | 155 | 156 | def meta_forward(self, x_a, x_b, label_a, weights, cache={}): 157 | """Full forward pass for MAML. Does a inner_forward, backprop, and gradient 158 | update. This will all be backpropped through w.r.t. weights in meta_backward 159 | 160 | Args: 161 | x_a (ndarray): Example or examples of sinusoid from given phase, amplitude. 
162 | x_b (ndarray): Independent example(s) from same phase, amplitude as x_a's 163 | label_a (ndarray): Ground truth labels for x_a 164 | weights (dict): Dictionary of weights and biases for neural net 165 | cache (dict): Pass in dictionary to be updated with values needed in meta_backward 166 | 167 | Returns: 168 | pred_b (ndarray): Predicted values for example(s) x_b 169 | """ 170 | w = weights 171 | W1, b1, W2, b2, W3, b3 = w['W1'], w['b1'], w['W2'], w['b2'], w['W3'], w['b3'] 172 | 173 | # A: inner 174 | # standard forward and backward computations 175 | inner_cache = {} 176 | pred_a = self.inner_forward(x_a, w, inner_cache) 177 | 178 | # inner loss 179 | dout_a = 2*(pred_a - label_a) 180 | 181 | # d 3rd layer 182 | dW3 = inner_cache['relu2_a'].T.dot(dout_a) 183 | db3 = np.sum(dout_a, axis=0) 184 | drelu2_a = dout_a.dot(W3.T) 185 | 186 | daffine2_a = np.where(inner_cache['affine2_a'] > 0, drelu2_a, 0) 187 | 188 | # d 2nd layer 189 | dW2 = inner_cache['relu1_a'].T.dot(daffine2_a) 190 | db2 = np.sum(daffine2_a, axis=0) 191 | drelu1_a = daffine2_a.dot(W2.T) 192 | 193 | daffine1_a = np.where(inner_cache['affine1_a'] > 0, drelu1_a, 0) 194 | 195 | # d 1st layer 196 | dW1 = x_a.T.dot(daffine1_a) 197 | db1 = np.sum(daffine1_a, axis=0) 198 | 199 | # Forward on fast weights 200 | # B: meta/outer 201 | # SGD step is baked into forward pass, representing optimizing through fine-tuning 202 | # Theta prime in the paper. Also called fast_weights in Finn's TF implementation 203 | W1_prime = W1 - self.inner_lr*dW1 204 | b1_prime = b1 - self.inner_lr*db1 205 | W2_prime = W2 - self.inner_lr*dW2 206 | b2_prime = b2 - self.inner_lr*db2 207 | W3_prime = W3 - self.inner_lr*dW3 208 | b3_prime = b3 - self.inner_lr*db3 209 | 210 | # Do another forward pass with the fast weights, to predict B example 211 | affine1_b = x_b.dot(W1_prime) + b1_prime 212 | relu1_b = np.maximum(0, affine1_b) 213 | affine2_b = relu1_b.dot(W2_prime) + b2_prime 214 | relu2_b = np.maximum(0, affine2_b) 215 | pred_b = relu2_b.dot(W3_prime) + b3_prime 216 | 217 | # Cache relevant values for meta backpropping 218 | outer_cache = dict(dout_a=dout_a, x_b=x_b, affine1_b=affine1_b, relu1_b=relu1_b, affine2_b=affine2_b, relu2_b=relu2_b, daffine2_a=daffine2_a, W2_prime=W2_prime, W3_prime=W3_prime) 219 | cache.update(inner_cache) 220 | cache.update(outer_cache) 221 | 222 | return pred_b 223 | 224 | def meta_backward(self, dout_b, weights, cache, grads=GradDict()): 225 | """Full backward pass for MAML. 
Through all operations from forward pass 226 | 227 | Args: 228 | dout_b (ndarray): Gradient signal of network output (usually loss gradient) 229 | weights (dict): Dictionary of weights and biases used in forward pass 230 | cache (dict): Dictionary of relevant values from forward pass 231 | grads (dict): Pass in dictionary to be updated with weight gradients 232 | """ 233 | c = cache; w = weights 234 | W1, b1, W2, b2, W3, b3 = w['W1'], w['b1'], w['W2'], w['b2'], w['W3'], w['b3'] 235 | 236 | # First, backprop through the B network pass 237 | # d 3rd layer 238 | drelu2_b = dout_b.dot(c['W3_prime'].T) 239 | dW3_prime = c['relu2_b'].T.dot(dout_b) 240 | db3_prime = np.sum(dout_b, axis=0) 241 | 242 | daffine2_b = np.where(c['affine2_b'] > 0, drelu2_b, 0) 243 | 244 | # d 2nd layer 245 | drelu1_b = daffine2_b.dot(c['W2_prime'].T) 246 | dW2_prime = c['relu1_b'].T.dot(daffine2_b) 247 | db2_prime = np.sum(daffine2_b, axis=0) 248 | 249 | daffine1_b = np.where(c['affine1_b'] > 0, drelu1_b, 0) 250 | 251 | # d 1st layer 252 | dW1_prime = c['x_b'].T.dot(daffine1_b) 253 | db1_prime = np.sum(daffine1_b, axis=0) 254 | 255 | # Next, backprop through the gradient descent step 256 | dW1 = dW1_prime 257 | db1 = db1_prime 258 | dW2 = dW2_prime 259 | db2 = db2_prime 260 | dW3 = dW3_prime 261 | db3 = db3_prime 262 | 263 | ddW1 = dW1_prime * -self.inner_lr 264 | ddb1 = db1_prime * -self.inner_lr 265 | ddW2 = dW2_prime * -self.inner_lr 266 | ddb2 = db2_prime * -self.inner_lr 267 | ddW3 = dW3_prime * -self.inner_lr 268 | ddb3 = db3_prime * -self.inner_lr 269 | 270 | # Then, backprop through the first backprop 271 | # start with dW1's 272 | ddaffine1_a = c['x_a'].dot(ddW1) 273 | ddaffine1_a += ddb1 274 | 275 | ddrelu1_a = np.where(c['affine1_a'] > 0, ddaffine1_a, 0) 276 | 277 | ddaffine2_a = ddrelu1_a.dot(W2) 278 | dW2 += ddrelu1_a.T.dot(c['daffine2_a']) 279 | 280 | # dW2's 281 | drelu1_a = c['daffine2_a'].dot(ddW2.T) # shortcut back because of the grad dependency 282 | ddaffine2_a += ddb2 283 | ddaffine2_a += c['relu1_a'].dot(ddW2) 284 | 285 | ddrelu2_a = np.where(c['affine2_a'] > 0, ddaffine2_a, 0) 286 | 287 | ddout_a = ddrelu2_a.dot(W3) 288 | dW3 += ddrelu2_a.T.dot(c['dout_a']) 289 | 290 | # dW3's 291 | drelu2_a = c['dout_a'].dot(ddW3.T) # shortcut back because of the grad dependency 292 | ddout_a += ddb3 293 | ddout_a += c['relu2_a'].dot(ddW3) 294 | 295 | # Finally, backprop through the first forward 296 | dpred_a = ddout_a * 2 297 | 298 | drelu2_a += dpred_a.dot(W3.T) 299 | db3 += np.sum(dpred_a, axis=0) 300 | dW3 += c['relu2_a'].T.dot(dpred_a) 301 | 302 | daffine2_a = np.where(c['affine2_a'] > 0, drelu2_a, 0) 303 | 304 | drelu1_a += daffine2_a.dot(W2.T) 305 | dW2 += c['relu1_a'].T.dot(daffine2_a) 306 | db2 += np.sum(daffine2_a, axis=0) 307 | 308 | daffine1_a = np.where(c['affine1_a'] > 0, drelu1_a, 0) 309 | 310 | dW1 += c['x_a'].T.dot(daffine1_a) 311 | db1 += np.sum(daffine1_a, axis=0) 312 | 313 | # update gradients 314 | grads['W1'] += self.normalize(dW1) 315 | grads['b1'] += self.normalize(db1) 316 | grads['W2'] += self.normalize(dW2) 317 | grads['b2'] += self.normalize(db2) 318 | grads['W3'] += self.normalize(dW3) 319 | grads['b3'] += self.normalize(db3) 320 | 321 | 322 | def gradcheck(): 323 | # Test the network gradient 324 | nn = Network(normalize=lambda x: x) # don't normalize gradients so we can check validity 325 | grads = GradDict() # initialize grads to 0 326 | # dummy inputs, labels, and fake backwards gradient signal 327 | x_a = np.random.randn(15, 1) 328 | x_b = np.random.randn(15, 1) 329 | label = 
np.random.randn(15, 1) 330 | dout = np.random.randn(15, 1) 331 | # make weights. don't use build_weights here because this is more stable 332 | W1 = np.random.randn(1, 40) 333 | b1 = np.random.randn(40) 334 | W2 = np.random.randn(40, 40) 335 | b2 = np.random.randn(40) 336 | W3 = np.random.randn(40, 1) 337 | b3 = np.random.randn(1) 338 | weights = dict(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3) 339 | 340 | # helper function to only change a single key of interest for independent finite differences 341 | def rep_param(weights, name, val): 342 | clean_params = copy.deepcopy(weights) 343 | clean_params[name] = val 344 | return clean_params 345 | 346 | # Evaluate gradients numerically, using finite differences 347 | numerical_grads = {} 348 | for key in weights: 349 | num_grad = eval_numerical_gradient_array(lambda w: nn.meta_forward(x_a, x_b, label, rep_param(weights, key, w)), weights[key], dout, h=1e-5) 350 | numerical_grads[key] = num_grad 351 | 352 | # Compute neural network gradients 353 | cache = {} 354 | out = nn.meta_forward(x_a, x_b, label, weights, cache=cache) 355 | nn.meta_backward(dout, weights, cache, grads) 356 | 357 | # The error should be around 1e-10 358 | print() 359 | for key in weights: 360 | print('d{} error: {}'.format(key, rel_error(numerical_grads[key], grads[key]))) 361 | print() 362 | 363 | def test(): 364 | """Take one grad step using a minibatch of size 5 and see how well it works 365 | 366 | Basically what they show in Figure 2 of the paper 367 | """ 368 | nn = Network(inner_lr=FLAGS.inner_lr) 369 | 370 | pre_weights = {} 371 | pre_weights['maml'] = load_weights(FLAGS.weight_path) 372 | if FLAGS.use_baseline: 373 | pre_weights['baseline'] = load_weights('baseline_'+FLAGS.weight_path) 374 | pre_weights['random'] = build_weights() 375 | 376 | # Generate N batches of data, with same shape as training, but that all have the same amplitude and phase 377 | N = 2 378 | #sinegen = SinusoidGenerator(FLAGS.inner_bs*N, 1, config={'input_range':[1.0,5.0]}) 379 | sinegen = SinusoidGenerator(FLAGS.inner_bs*N, 1) 380 | x, y, amp, phase = map(lambda x: x[0], sinegen.generate()) # grab all the first elems 381 | xs = np.split(x, N) 382 | ys = np.split(y, N) 383 | 384 | # Copy pre-update weights for later comparison 385 | deepcopy = lambda weights: {key: weights[key].copy() for key in weights} 386 | post_weights = {} 387 | for key in pre_weights: 388 | post_weights[key] = deepcopy(pre_weights[key]) 389 | 390 | T = 10 391 | # Run fine-tuning 392 | for key in post_weights: 393 | for t in range(T): 394 | for i in range(len(xs)): 395 | x = xs[i] 396 | y = ys[i] 397 | grads = GradDict() 398 | cache = {} 399 | pred = nn.inner_forward(x, post_weights[key], cache) 400 | loss = (pred - y)**2 401 | dout = 2*(pred - y) 402 | post_weights[key] = nn.inner_backward(dout, post_weights[key], cache) 403 | 404 | 405 | colors = {'maml': 'r', 'baseline': 'b', 'random': 'g'} 406 | name = {'maml': 'MAML', 'baseline': 'joint training', 'random': 'random initialization'} 407 | 408 | sine_ground = lambda x: amp*np.sin(x - phase) 409 | sine_pre_pred = lambda x, key: nn.inner_forward(x, pre_weights[key])[0] 410 | sine_post_pred = lambda x, key: nn.inner_forward(x, post_weights[key])[0] 411 | 412 | x_vals = np.linspace(-5, 5) 413 | y_ground = np.apply_along_axis(sine_ground, 0, x_vals) 414 | 415 | 416 | for key in post_weights: 417 | y_pre = np.array([sine_pre_pred(np.array(x), key) for x in x_vals]).squeeze() 418 | y_nn = np.array([sine_post_pred(np.array(x), key) for x in x_vals]).squeeze() 419 | 
plt.plot(x_vals, y_ground, 'k', label='{:.2f}sin(x - {:.2f})'.format(amp, phase)) 420 | plt.plot(np.concatenate(xs), np.concatenate(ys), 'ok', label='samples') 421 | plt.plot(x_vals, y_pre, colors[key]+'--', label='pre-update') 422 | plt.plot(x_vals, y_nn, colors[key]+'-', label='post-update') 423 | 424 | plt.legend() 425 | plt.title('Fine-tuning performance {}'.format(name[key])) 426 | plt.savefig(key+'.png') 427 | plt.show() 428 | 429 | def train(): 430 | nn = Network(inner_lr=FLAGS.inner_lr) 431 | weights = build_weights() 432 | optimizer = AdamOptimizer(weights, learning_rate=FLAGS.meta_lr) 433 | if FLAGS.use_baseline: 434 | baseline_weights = build_weights() 435 | baseline_optimizer = AdamOptimizer(baseline_weights, learning_rate=FLAGS.meta_lr) 436 | 437 | sinegen = SinusoidGenerator(2*FLAGS.inner_bs, 25) # update_batch * 2, meta batch size 438 | 439 | try: 440 | nitr = int(FLAGS.num_iter) 441 | for itr in range(int(nitr)): 442 | # create a minibatch of size 25, with 10 points 443 | batch_x, batch_y, amp, phase = sinegen.generate() 444 | 445 | inputa = batch_x[:, :FLAGS.inner_bs :] 446 | labela = batch_y[:, :FLAGS.inner_bs :] 447 | inputb = batch_x[:, FLAGS.inner_bs :] # b used for testing 448 | labelb = batch_y[:, FLAGS.inner_bs :] 449 | 450 | # META BATCH 451 | grads = GradDict() # zero grads 452 | baseline_grads = GradDict() # zero grads 453 | losses = [] 454 | baseline_losses = [] 455 | for batch_i in range(len(inputa)): 456 | ia, la, ib, lb = inputa[batch_i], labela[batch_i], inputb[batch_i], labelb[batch_i] 457 | cache = {} 458 | pred_b = nn.meta_forward(ia, ib, la, weights, cache=cache) 459 | losses.append((pred_b - lb)**2) 460 | dout_b = 2*(pred_b - lb) 461 | nn.meta_backward(dout_b, weights, cache, grads) 462 | 463 | 464 | if FLAGS.use_baseline: 465 | baseline_cache = {} 466 | baseline_i = np.concatenate([ia,ib]) 467 | baseline_l = np.concatenate([la,lb]) 468 | baseline_pred = nn.inner_forward(baseline_i, baseline_weights, cache=baseline_cache) 469 | baseline_losses.append((baseline_pred - baseline_l)**2) 470 | dout_b = 2*(baseline_pred - baseline_l) 471 | nn.inner_backward(dout_b, baseline_weights, baseline_cache, baseline_grads) 472 | 473 | optimizer.apply_gradients(weights, grads, learning_rate=FLAGS.meta_lr) 474 | if FLAGS.use_baseline: 475 | baseline_optimizer.apply_gradients(baseline_weights, baseline_grads, learning_rate=FLAGS.meta_lr) 476 | if itr % 100 == 0: 477 | if FLAGS.use_baseline: 478 | print("[itr: {}] MAML loss = {} Baseline loss = {}".format(itr, np.sum(losses), np.sum(baseline_losses))) 479 | else: 480 | print("[itr: {}] Loss = {}".format(itr, np.sum(losses))) 481 | except KeyboardInterrupt: 482 | pass 483 | save_weights(weights, FLAGS.weight_path) 484 | if FLAGS.use_baseline: 485 | save_weights(baseline_weights, "baseline_"+FLAGS.weight_path) 486 | 487 | 488 | if __name__ == '__main__': 489 | parser = argparse.ArgumentParser(description='MAML') 490 | parser.add_argument('--seed', type=int, default=2, help='') 491 | parser.add_argument('--gradcheck', type=int, default=0, help='Run gradient check and other tests') 492 | parser.add_argument('--test', type=int, default=0, help='Run test on trained network to see if it works') 493 | parser.add_argument('--meta_lr', type=float, default=1e-3, help='Meta learning rate') 494 | parser.add_argument('--inner_lr', type=float, default=1e-2, help='Inner learning rate') 495 | parser.add_argument('--inner_bs', type=int, default=5, help='Inner batch size') 496 | parser.add_argument('--weight_path', type=str, 
default='trained_maml_weights.pkl', help='File name to save and load weights') 497 | parser.add_argument('--use_baseline', type=int, default=1, help='Whether to train a baseline network') 498 | parser.add_argument('--num_iter', type=float, default=1e4, help='Number of iterations') 499 | FLAGS = parser.parse_args() 500 | np.random.seed(FLAGS.seed) 501 | 502 | if FLAGS.gradcheck: 503 | gradcheck() 504 | elif FLAGS.test: 505 | test() 506 | else: 507 | train() 508 | -------------------------------------------------------------------------------- /maml_1hidden.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import pickle 3 | import copy 4 | import random 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import argparse 8 | from collections import defaultdict 9 | 10 | from utils.optim import AdamOptimizer 11 | from utils.gradient_check import eval_numerical_gradient, eval_numerical_gradient_array, rel_error 12 | from utils.data_generator import SinusoidGenerator 13 | 14 | 15 | # this will create a special dictionary that returns 0 if the element is not set, instead of error 16 | # (it makes the code for updating gradients simpler) 17 | GradDict = lambda: defaultdict(lambda: 0) 18 | 19 | normalize = lambda x: (x - x.mean()) / (x.std() + 1e-8) 20 | 21 | def build_weights(hidden_dim=200): 22 | """Return weights to be used in forward pass""" 23 | # Initialize all weights (model params) with "Xavier Initialization" 24 | # weight matrix init = uniform(-1, 1) / sqrt(layer_input) 25 | # bias init = zeros() 26 | H = hidden_dim 27 | d = {} 28 | d['W1'] = (-1 + 2*np.random.rand(1, H)) / np.sqrt(1) 29 | d['b1'] = np.zeros(H) 30 | d['W2'] = (-1 + 2*np.random.rand(H, 1)) / np.sqrt(H) 31 | d['b2'] = np.zeros(1) 32 | 33 | # Cast all parameters to the correct datatype 34 | for k, v in d.items(): 35 | d[k] = v.astype(np.float32) 36 | return d 37 | 38 | def save_weights(weights, filename, quiet=False): 39 | with open(filename, 'wb') as f: 40 | pickle.dump(weights, f) 41 | if not quiet: 42 | print('weights saved to {}'.format(filename)) 43 | 44 | def load_weights(filename, quiet=False): 45 | with open(filename, 'rb') as f: 46 | weights = pickle.load(f) 47 | if not quiet: 48 | print('weights loaded from {}'.format(filename)) 49 | return weights 50 | 51 | 52 | class Network(object): 53 | """BYOW: Bring Your Own Weights 54 | 55 | Hard-code operations for a 2 layer neural network 56 | """ 57 | def __init__(self, alpha=0.01, normalized=normalize): 58 | self.ALPHA = alpha 59 | self.normalized = normalized 60 | 61 | def inner_forward(self, x_a, w): 62 | """submodule for forward pass""" 63 | W1, b1, W2, b2 = w['W1'], w['b1'], w['W2'], w['b2'] 64 | 65 | affine1_a = x_a.dot(W1) + b1 66 | relu1_a = np.maximum(0, affine1_a) 67 | pred_a = relu1_a.dot(W2) + b2 68 | 69 | cache = dict(x_a=x_a, affine1_a=affine1_a, relu1_a=relu1_a) 70 | return pred_a, cache 71 | 72 | def inner_backward(self, dout_a, weights, cache): 73 | """just for fine-tuning at the end""" 74 | w = weights; c = cache 75 | W1, b1, W2, b2 = w['W1'], w['b1'], w['W2'], w['b2'] 76 | 77 | drelu1_a = dout_a.dot(W2.T) 78 | dW2 = cache['relu1_a'].T.dot(dout_a) 79 | db2 = np.sum(dout_a, axis=0) 80 | 81 | daffine1_a = np.where(cache['affine1_a'] > 0, drelu1_a, 0) 82 | 83 | dW1 = c['x_a'].T.dot(daffine1_a) 84 | db1 = np.sum(daffine1_a, axis=0) 85 | 86 | # grad steps 87 | new_weights = {} 88 | new_weights['W1'] = W1 - self.ALPHA*self.normalized(dW1) 89 | new_weights['b1'] = b1 - 
self.ALPHA*self.normalized(db1) 90 | new_weights['W2'] = W2 - self.ALPHA*self.normalized(dW2) 91 | new_weights['b2'] = b2 - self.ALPHA*self.normalized(db2) 92 | return new_weights 93 | 94 | 95 | def meta_forward(self, x_a, x_b, label_a, weights, cache=None): 96 | w = weights 97 | W1, b1, W2, b2 = w['W1'], w['b1'], w['W2'], w['b2'] 98 | 99 | # standard forward and backward computations 100 | # (a) 101 | pred_a, inner_cache = self.inner_forward(x_a, w) 102 | 103 | dout_a = 2*(pred_a - label_a) 104 | 105 | drelu1_a = dout_a.dot(W2.T) 106 | dW2 = inner_cache['relu1_a'].T.dot(dout_a) 107 | db2 = np.sum(dout_a, axis=0) 108 | 109 | daffine1_a = np.where(inner_cache['affine1_a'] > 0, drelu1_a, 0) 110 | 111 | dW1 = x_a.T.dot(daffine1_a) 112 | db1 = np.sum(daffine1_a, axis=0) 113 | 114 | # Forward on fast weights 115 | # (b) 116 | 117 | # grad steps 118 | W1_prime = W1 - self.ALPHA*dW1 119 | b1_prime = b1 - self.ALPHA*db1 120 | W2_prime = W2 - self.ALPHA*dW2 121 | b2_prime = b2 - self.ALPHA*db2 122 | 123 | affine1_b = x_b.dot(W1_prime) + b1_prime 124 | relu1_b = np.maximum(0, affine1_b) 125 | pred_b = relu1_b.dot(W2_prime) + b2_prime 126 | 127 | if cache: 128 | outer_cache = dict(dout_a=dout_a, x_b=x_b, affine1_b=affine1_b, relu1_b=relu1_b, W2_prime=W2_prime) 129 | return pred_b, {**inner_cache, **outer_cache} 130 | else: 131 | return pred_b 132 | 133 | def meta_backward(self, dout_b, weights, cache, grads=None): 134 | c = cache; w = weights # short 135 | W1, b1, W2, b2 = w['W1'], w['b1'], w['W2'], w['b2'] 136 | 137 | # deriv w.r.t b (lower half) 138 | # d 1st layer 139 | dW2_prime = c['relu1_b'].T.dot(dout_b) 140 | db2_prime = np.sum(dout_b, axis=0) 141 | drelu1_b = dout_b.dot(c['W2_prime'].T) 142 | 143 | daffine1_b = np.where(c['affine1_b'] > 0, drelu1_b, 0) 144 | # d 2nd layer 145 | dW1_prime = c['x_b'].T.dot(daffine1_b) 146 | db1_prime = np.sum(daffine1_b, axis=0) 147 | 148 | # deriv w.r.t a (upper half) 149 | 150 | # going back through the gradient descent step 151 | dW1 = dW1_prime 152 | db1 = db1_prime 153 | dW2 = dW2_prime 154 | db2 = db2_prime 155 | 156 | ddW1 = dW1_prime * -self.ALPHA 157 | ddb1 = db1_prime * -self.ALPHA 158 | ddW2 = dW2_prime * -self.ALPHA 159 | ddb2 = db2_prime * -self.ALPHA 160 | 161 | # backpropping through the first backprop 162 | ddout_a = c['relu1_a'].dot(ddW2) 163 | ddout_a += ddb2 164 | drelu1_a = c['dout_a'].dot(ddW2.T) # shortcut back because of the grad dependency 165 | 166 | ddaffine1_a = c['x_a'].dot(ddW1) 167 | ddaffine1_a += ddb1 168 | ddrelu1_a = np.where(c['affine1_a'] > 0, ddaffine1_a, 0) 169 | 170 | dW2 += ddrelu1_a.T.dot(c['dout_a']) 171 | 172 | ddout_a += ddrelu1_a.dot(W2) 173 | 174 | dpred_a = ddout_a * 2 # = dout_a 175 | 176 | dW2 += c['relu1_a'].T.dot(dpred_a) 177 | db2 += np.sum(dpred_a, axis=0) 178 | 179 | drelu1_a += dpred_a.dot(W2.T) 180 | 181 | daffine1_a = np.where(c['affine1_a'] > 0, drelu1_a, 0) 182 | 183 | dW1 += c['x_a'].T.dot(daffine1_a) 184 | db1 += np.sum(daffine1_a, axis=0) 185 | 186 | if grads is not None: 187 | # update gradients 188 | grads['W1'] += self.normalized(dW1) 189 | grads['b1'] += self.normalized(db1) 190 | grads['W2'] += self.normalized(dW2) 191 | grads['b2'] += self.normalized(db2) 192 | 193 | 194 | def gradcheck(): 195 | # Test the network gradient 196 | nn = Network(normalized=lambda x: x) 197 | grads = GradDict() 198 | 199 | np.random.seed(231) 200 | x_a = np.random.randn(15, 1) 201 | x_b = np.random.randn(15, 1) 202 | label = np.random.randn(15, 1) 203 | W1 = np.random.randn(1, 40) 204 | b1 = np.random.randn(40) 205 
| W2 = np.random.randn(40, 1) 206 | b2 = np.random.randn(1) 207 | 208 | dout = np.random.randn(15, 1) 209 | 210 | weights = w = {} 211 | w['W1'] = W1 212 | w['b1'] = b1 213 | w['W2'] = W2 214 | w['b2'] = b2 215 | 216 | def rep_param(weights, name, val): 217 | clean_params = copy.deepcopy(weights) 218 | clean_params[name] = val 219 | return clean_params 220 | 221 | dW1_num = eval_numerical_gradient_array(lambda w: nn.meta_forward(x_a, x_b, label, rep_param(weights, 'W1', w)), W1, dout) 222 | db1_num = eval_numerical_gradient_array(lambda b: nn.meta_forward(x_a, x_b, label, rep_param(weights, 'b1', b)), b1, dout) 223 | dW2_num = eval_numerical_gradient_array(lambda w: nn.meta_forward(x_a, x_b, label, rep_param(weights, 'W2', w)), W2, dout) 224 | db2_num = eval_numerical_gradient_array(lambda b: nn.meta_forward(x_a, x_b, label, rep_param(weights, 'b2', b)), b2, dout) 225 | 226 | out, cache = nn.meta_forward(x_a, x_b, label, weights, cache=True) 227 | nn.meta_backward(dout, weights, cache, grads) 228 | 229 | # The error should be around 1e-10 230 | print() 231 | print('Testing meta-learning NN backward function:') 232 | print('dW1 error: ', rel_error(dW1_num, grads['W1'])) 233 | print('db1 error: ', rel_error(db1_num, grads['b1'])) 234 | print('dW2 error: ', rel_error(dW2_num, grads['W2'])) 235 | print('db2 error: ', rel_error(db2_num, grads['b2'])) 236 | print() 237 | 238 | def test(): 239 | """take one grad step using a minibatch of size 5 and see how well it works 240 | 241 | basically what they show in Figure 2 of: 242 | https://arxiv.org/pdf/1703.03400.pdf 243 | """ 244 | nn = Network() 245 | pre_weights = load_weights(FLAGS.weight_path) 246 | random_weights = build_weights() 247 | 248 | # values for fine-tuning step 249 | N = 10 250 | sin_gen = SinusoidGenerator(5*N, 1) 251 | x, y, amp, phase = map(lambda x: x[0], sin_gen.generate()) # grab all the first elems 252 | xs = np.split(x, N) 253 | ys = np.split(y, N) 254 | 255 | new_weights = pre_weights.copy() 256 | new_random_weights = random_weights.copy() 257 | for i in range(len(xs)): 258 | x = xs[i] 259 | y = ys[i] 260 | grads = GradDict() 261 | pred, cache = nn.inner_forward(x, new_weights) 262 | loss = (pred - y)**2 263 | dout = 2*(pred - y) 264 | new_weights = nn.inner_backward(dout, new_weights, cache) 265 | 266 | for i in range(len(xs)): 267 | x = xs[i] 268 | y = ys[i] 269 | grads = GradDict() 270 | pred, cache = nn.inner_forward(x, new_random_weights) 271 | loss = (pred - y)**2 272 | dout = 2*(pred - y) 273 | new_random_weights = nn.inner_backward(dout, new_random_weights, cache) 274 | 275 | 276 | sine_true = lambda x: amp*np.sin(x - phase) 277 | sine_nn = lambda x: nn.inner_forward(x, new_weights)[0] 278 | sine_pre = lambda x: nn.inner_forward(x, pre_weights)[0] 279 | sine_random = lambda x: nn.inner_forward(x, random_weights)[0] 280 | sine_new_random = lambda x: nn.inner_forward(x, new_random_weights)[0] 281 | 282 | x_vals = np.linspace(-5, 5) 283 | 284 | y_true = np.apply_along_axis(sine_true, 0, x_vals) 285 | y_nn = np.array([sine_nn(np.array(x)) for x in x_vals]).squeeze() 286 | y_pre = np.array([sine_pre(np.array(x)) for x in x_vals]).squeeze() 287 | y_random = np.array([sine_random(np.array(x)) for x in x_vals]).squeeze() 288 | y_new_random = np.array([sine_new_random(np.array(x)) for x in x_vals]).squeeze() 289 | 290 | plt.plot(x_vals, y_true, 'k', label='{:.2f}sin(x - {:.2f})'.format(amp, phase)) 291 | plt.plot(x_vals, y_pre, 'r--', label='pre-update') 292 | plt.plot(x_vals, y_nn, 'r-', label='post-update') 293 | 
plt.plot(x_vals, y_random, 'g--', label='random') 294 | plt.plot(x_vals, y_new_random, 'g-', label='new_random') 295 | plt.legend() 296 | plt.show() 297 | 298 | 299 | def main(): 300 | nn = Network() 301 | weights = build_weights() 302 | optimizer = AdamOptimizer(weights, learning_rate=FLAGS.learning_rate) 303 | 304 | sin_gen = SinusoidGenerator(10, 25) # update_batch * 2, meta batch size 305 | 306 | 307 | lr = lambda x: x * FLAGS.learning_rate 308 | 309 | nitr = 1e4 310 | for itr in range(int(nitr)): 311 | frac = 1.0 - (itr / nitr) 312 | 313 | # create a minibatch of size 25, with 10 points 314 | batch_x, batch_y, amp, phase = sin_gen.generate() 315 | 316 | inputa = batch_x[:, :5, :] 317 | labela = batch_y[:, :5, :] 318 | inputb = batch_x[:, 5:, :] # b used for testing 319 | labelb = batch_y[:, 5:, :] 320 | 321 | # META BATCH 322 | grads = GradDict() # zero grads 323 | losses = [] 324 | for batch_i in range(len(inputa)): 325 | ia, la, ib, lb = inputa[batch_i], labela[batch_i], inputb[batch_i], labelb[batch_i] 326 | pred_b, cache = nn.meta_forward(ia, ib, la, weights, cache=True) 327 | losses.append((pred_b - lb)**2) 328 | dout_b = 2*(pred_b - lb) 329 | nn.meta_backward(dout_b, weights, cache, grads) 330 | optimizer.apply_gradients(weights, grads, learning_rate=lr(frac)) 331 | if itr % 100 == 0: 332 | print("[itr: {}] Loss = {}".format(itr, np.sum(losses))) 333 | 334 | save_weights(weights, FLAGS.weight_path) 335 | 336 | if __name__ == '__main__': 337 | parser = argparse.ArgumentParser(description='MAML') 338 | parser.add_argument('--gradcheck', type=int, default=0, help='Run gradient check and other tests') 339 | parser.add_argument('--test', type=int, default=0, help='Run test on trained network to see if it works') 340 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate') 341 | parser.add_argument('--weight_path', type=str, default='trained_maml_weights.pkl', help='File name to save and load weights') 342 | FLAGS = parser.parse_args() 343 | 344 | if FLAGS.gradcheck: 345 | gradcheck() 346 | exit(0) 347 | 348 | if FLAGS.test: 349 | test() 350 | exit(0) 351 | 352 | main() 353 | 354 | 355 | -------------------------------------------------------------------------------- /utils/data_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Taken (and modified/hacked) from Chelsea Finn's MAML implementation 4 | https://github.com/cbfinn/maml 5 | 6 | Code for loading data. 7 | 8 | """ 9 | import numpy as np 10 | import random 11 | 12 | class SinusoidGenerator(object): 13 | """ 14 | SinusoidGenerator capable of generating batches of sinusoid 15 | A "class" is considered a particular sinusoid function. 16 | """ 17 | def __init__(self, num_samples_per_class, batch_size, config={}): 18 | """ 19 | Args: 20 | num_samples_per_class: num samples to generate per class in one batch 21 | batch_size: size of meta batch size (e.g. 
number of functions) 22 | """ 23 | self.batch_size = batch_size 24 | self.num_samples_per_class = num_samples_per_class 25 | self.num_classes = 1 # by default 1 (only relevant for classification problems) 26 | 27 | self.generate = self.generate_sinusoid_batch 28 | self.amp_range = config.get('amp_range', [0.1, 5.0]) 29 | self.phase_range = config.get('phase_range', [0, np.pi]) 30 | self.input_range = config.get('input_range', [-5.0, 5.0]) 31 | self.dim_input = 1 32 | self.dim_output = 1 33 | 34 | def generate_sinusoid_batch(self, train=True, input_idx=None): 35 | # Note train arg is not used (but it is used for omniglot method. 36 | # input_idx is used during qualitative testing --the number of examples used for the grad update 37 | amp = np.random.uniform(self.amp_range[0], self.amp_range[1], [self.batch_size]) 38 | phase = np.random.uniform(self.phase_range[0], self.phase_range[1], [self.batch_size]) 39 | outputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_output]) 40 | init_inputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_input]) 41 | for func in range(self.batch_size): 42 | init_inputs[func] = np.random.uniform(self.input_range[0], self.input_range[1], [self.num_samples_per_class, 1]) 43 | if input_idx is not None: 44 | init_inputs[:,input_idx:,0] = np.linspace(self.input_range[0], self.input_range[1], num=self.num_samples_per_class-input_idx, retstep=False) 45 | outputs[func] = amp[func] * np.sin(init_inputs[func]-phase[func]) 46 | return init_inputs, outputs, amp, phase 47 | -------------------------------------------------------------------------------- /utils/gradient_check.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from builtins import range 3 | from past.builtins import xrange 4 | 5 | import numpy as np 6 | from random import randrange 7 | 8 | 9 | """ 10 | THIS IS FROM TAKEN FROM STANFORD'S CS231N COURSE, WHICH I HIGHLY RECOMMEND 11 | http://cs231n.github.io/ 12 | 13 | It does numerical gradient checking 14 | """ 15 | 16 | def eval_numerical_gradient(f, x, verbose=True, h=0.00001): 17 | """ 18 | a naive implementation of numerical gradient of f at x 19 | - f should be a function that takes a single argument 20 | - x is the point (numpy array) to evaluate the gradient at 21 | """ 22 | 23 | fx = f(x) # evaluate function value at original point 24 | grad = np.zeros_like(x) 25 | # iterate over all indexes in x 26 | it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite']) 27 | while not it.finished: 28 | 29 | # evaluate function at x+h 30 | ix = it.multi_index 31 | oldval = x[ix] 32 | x[ix] = oldval + h # increment by h 33 | fxph = f(x) # evalute f(x + h) 34 | x[ix] = oldval - h 35 | fxmh = f(x) # evaluate f(x - h) 36 | x[ix] = oldval # restore 37 | 38 | # compute the partial derivative with centered formula 39 | grad[ix] = (fxph - fxmh) / (2 * h) # the slope 40 | if verbose: 41 | print(ix, grad[ix]) 42 | it.iternext() # step to next dimension 43 | 44 | return grad 45 | 46 | 47 | def eval_numerical_gradient_array(f, x, df, h=1e-5): 48 | """ 49 | Evaluate a numeric gradient for a function that accepts a numpy 50 | array and returns a numpy array. 
51 | """ 52 | grad = np.zeros_like(x) 53 | it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite']) 54 | while not it.finished: 55 | ix = it.multi_index 56 | 57 | oldval = x[ix] 58 | x[ix] = oldval + h 59 | pos = f(x).copy() 60 | x[ix] = oldval - h 61 | neg = f(x).copy() 62 | x[ix] = oldval 63 | 64 | grad[ix] = np.sum((pos - neg) * df) / (2 * h) 65 | it.iternext() 66 | return grad 67 | 68 | 69 | def eval_numerical_gradient_blobs(f, inputs, output, h=1e-5): 70 | """ 71 | Compute numeric gradients for a function that operates on input 72 | and output blobs. 73 | 74 | We assume that f accepts several input blobs as arguments, followed by a 75 | blob where outputs will be written. For example, f might be called like: 76 | 77 | f(x, w, out) 78 | 79 | where x and w are input Blobs, and the result of f will be written to out. 80 | 81 | Inputs: 82 | - f: function 83 | - inputs: tuple of input blobs 84 | - output: output blob 85 | - h: step size 86 | """ 87 | numeric_diffs = [] 88 | for input_blob in inputs: 89 | diff = np.zeros_like(input_blob.diffs) 90 | it = np.nditer(input_blob.vals, flags=['multi_index'], 91 | op_flags=['readwrite']) 92 | while not it.finished: 93 | idx = it.multi_index 94 | orig = input_blob.vals[idx] 95 | 96 | input_blob.vals[idx] = orig + h 97 | f(*(inputs + (output,))) 98 | pos = np.copy(output.vals) 99 | input_blob.vals[idx] = orig - h 100 | f(*(inputs + (output,))) 101 | neg = np.copy(output.vals) 102 | input_blob.vals[idx] = orig 103 | 104 | diff[idx] = np.sum((pos - neg) * output.diffs) / (2.0 * h) 105 | 106 | it.iternext() 107 | numeric_diffs.append(diff) 108 | return numeric_diffs 109 | 110 | 111 | def eval_numerical_gradient_net(net, inputs, output, h=1e-5): 112 | return eval_numerical_gradient_blobs(lambda *args: net.forward(), 113 | inputs, output, h=h) 114 | 115 | 116 | def grad_check_sparse(f, x, analytic_grad, num_checks=10, h=1e-5): 117 | """ 118 | sample a few random elements and only return numerical 119 | in this dimensions. 120 | """ 121 | 122 | for i in range(num_checks): 123 | ix = tuple([randrange(m) for m in x.shape]) 124 | 125 | oldval = x[ix] 126 | x[ix] = oldval + h # increment by h 127 | fxph = f(x) # evaluate f(x + h) 128 | x[ix] = oldval - h # increment by h 129 | fxmh = f(x) # evaluate f(x - h) 130 | x[ix] = oldval # reset 131 | 132 | grad_numerical = (fxph - fxmh) / (2 * h) 133 | grad_analytic = analytic_grad[ix] 134 | rel_error = (abs(grad_numerical - grad_analytic) / 135 | (abs(grad_numerical) + abs(grad_analytic))) 136 | print('numerical: %f analytic: %f, relative error: %e' 137 | %(grad_numerical, grad_analytic, rel_error)) 138 | 139 | 140 | def rel_error(x, y): 141 | """ returns relative error """ 142 | return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y)))) 143 | -------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | """ 3 | THIS IS FROM TAKEN FROM STANFORD'S CS231N COURSE, WHICH I HIGHLY RECOMMEND 4 | http://cs231n.github.io/ 5 | 6 | This file implements various first-order update rules that are commonly used for 7 | training neural networks. Each update rule accepts current weights and the 8 | gradient of the loss with respect to those weights and produces the next set of 9 | weights. Each update rule has the same interface: 10 | 11 | def update(w, dw, config=None): 12 | 13 | Inputs: 14 | - w: A numpy array giving the current weights. 
15 | - dw: A numpy array of the same shape as w giving the gradient of the 16 | loss with respect to w. 17 | - config: A dictionary containing hyperparameter values such as learning rate, 18 | momentum, etc. If the update rule requires caching values over many 19 | iterations, then config will also hold these cached values. 20 | 21 | Returns: 22 | - next_w: The next point after the update. 23 | - config: The config dictionary to be passed to the next iteration of the 24 | update rule. 25 | 26 | NOTE: For most update rules, the default learning rate will probably not perform 27 | well; however the default values of the other hyperparameters should work well 28 | for a variety of different problems. 29 | 30 | For efficiency, update rules may perform in-place updates, mutating w and 31 | setting next_w equal to w. 32 | """ 33 | 34 | 35 | 36 | 37 | 38 | def sgd(w, dw, config=None): 39 | """ 40 | Performs vanilla stochastic gradient descent. 41 | 42 | config format: 43 | - learning_rate: Scalar learning rate. 44 | """ 45 | if config is None: config = {} 46 | config.setdefault('learning_rate', 1e-2) 47 | 48 | w -= config['learning_rate'] * dw 49 | return w, config 50 | 51 | 52 | def adam(x, dx, config=None): 53 | """ 54 | Uses the Adam update rule, which incorporates moving averages of both the 55 | gradient and its square and a bias correction term. 56 | 57 | config format: 58 | - learning_rate: Scalar learning rate. 59 | - beta1: Decay rate for moving average of first moment of gradient. 60 | - beta2: Decay rate for moving average of second moment of gradient. 61 | - epsilon: Small scalar used for smoothing to avoid dividing by zero. 62 | - m: Moving average of gradient. 63 | - v: Moving average of squared gradient. 64 | - t: Iteration number. 65 | """ 66 | if config is None: config = {} 67 | config.setdefault('learning_rate', 1e-3) 68 | config.setdefault('beta1', 0.9) 69 | config.setdefault('beta2', 0.999) 70 | config.setdefault('epsilon', 1e-8) 71 | config.setdefault('m', np.zeros_like(x)) 72 | config.setdefault('v', np.zeros_like(x)) 73 | config.setdefault('t', 0) 74 | 75 | #print(config['learning_rate']) 76 | 77 | next_x = None 78 | beta1, beta2, eps = config['beta1'], config['beta2'], config['epsilon'] 79 | t, m, v = config['t'], config['m'], config['v'] 80 | m = beta1 * m + (1 - beta1) * dx 81 | v = beta2 * v + (1 - beta2) * (dx * dx) 82 | t += 1 83 | alpha = config['learning_rate'] * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t) 84 | x -= alpha * (m / (np.sqrt(v) + eps)) 85 | config['t'] = t 86 | config['m'] = m 87 | config['v'] = v 88 | next_x = x 89 | 90 | return next_x, config 91 | 92 | 93 | class AdamOptimizer(): 94 | def __init__(self, params, learning_rate=1e-3): 95 | # Configuration for Adam optimization 96 | self.optimization_config = {'learning_rate': learning_rate} 97 | self.adam_configs = {} 98 | for p in params: 99 | d = {k: v for k, v in self.optimization_config.items()} 100 | self.adam_configs[p] = d 101 | 102 | def apply_gradients(self, params, grads, learning_rate=None): 103 | for p in params: 104 | if learning_rate is not None: 105 | self.adam_configs[p]['learning_rate'] = learning_rate 106 | next_w, self.adam_configs[p] = adam(params[p], grads[p], config=self.adam_configs[p]) 107 | params[p] = next_w 108 | 109 | 110 | 111 | 112 | --------------------------------------------------------------------------------