├── MAML-Sines.ipynb ├── README.md ├── Sine tests.ipynb └── src ├── __pycache__ └── tasks.cpython-36.pyc └── tasks.py /README.md: -------------------------------------------------------------------------------- 1 | # maml-pytorch 2 | 3 | A PyTorch reimplementation of MAML, replicating some of the experiments from [Finn et al. (2017): Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400). 4 | 5 | Currently, MAML-Sines.ipynb qualitatively reproduces the paper's supervised learning experiment on a distribution of sine-wave regression tasks. It is designed to be as simple and clear as possible while still replicating the paper's behaviour. 6 | -------------------------------------------------------------------------------- /Sine tests.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This notebook demonstrates the Sine_Task and Sine_Task_Distribution classes by training a simple model to fit a sampled sine wave."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "import torch.nn as nn\n", 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from src.tasks import Sine_Task_Distribution" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "tasks = Sine_Task_Distribution(0.1, 5, 0, np.pi, -5, 5)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 18, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "data": { 39 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD8CAYAAACfF6SlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEElJREFUeJzt3X9sXWd9x/H3Fzcd1hjzWK3SOM1SbVGkQFkjeR0TGpsgKOlASVYBaidYEKBoEpWYmLw16lRBpakdlhhIVBtRQesYrIUuhABhXimV2B+D1SWFkBbTrII1TiEB6gGqB0n63R++aW8s/7jOObnn+j7vlxTlPM95dJ+vIutzbp5z/JzITCRJZXlB0wVIkrrP8JekAhn+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUCGvyQV6JKmC1jMZZddlhs2bGi6DElaVR5++OEfZubwcuNqCf+I2A58CBgA7srMO+ad/zPgXcBZ4GfAnsx8dKnP3LBhA5OTk3WUJ0nFiIjvdTKu8rJPRAwAdwLXAZuBGyNi87xhn8zMqzPzGuD9wAeqzitJunB1rPlfCxzLzCcy8xfAPcDO9gGZ+ZO25i8D7iYnSQ2qY9lnBHiyrX0c+N35gyLiXcB7gEuB1yz0QRGxB9gDsH79+hpKkyQtpGtP+2TmnZn5m8BfAX+9yJh9mTmamaPDw8ver5AkXaA6wn8auLKtva7Vt5h7gF01zCtJukB1hP9DwMaIuCoiLgVuAA62D4iIjW3N1wOP1zCvJOkCVV7zz8wzEXETMMHco54fy8yjEXEbMJmZB4GbImIrcBp4GthddV5J0oWr5Tn/zDwEHJrXd2vb8bvrmEeSVA+3d5CkAhn+klSgnt3bR3Dg8DTjE1OcmJll7dAgY9s2sWvLSNNlSeoDhn+POnB4mr37jzB7+iwA0zOz7N1/BMALgKTKXPbpUeMTU88F/zmzp88yPjHVUEWS+onh36NOzMyuqF+SVsLw71FrhwZX1C9JK2H496ixbZsYXDNwXt/gmgHGtm1qqCJJ/cQbvj3q3E1dn/aRdDEY/j1s15YRw17SReGyjyQVyPCXpAIZ/pJUIMNfkgpk+EtSgQx/SSqQ4S9JBTL8JalAhr8kFcjwl6QCGf6SVCDDX5IK5MZu6irfSyz1BsNfXeN7iaXe4bKPusb3Eku9w/BX1/heYql3GP7qGt9LLPUOw19d43uJpd7hDV91je8llnpHLeEfEduBDwEDwF2Zece88+8B3gmcAU4Bb8/M79Uxt1YX30ss9YbKyz4RMQDcCVwHbAZujIjN84YdBkYz8xXAfcD7q84
rSbpwdaz5Xwscy8wnMvMXwD3AzvYBmflgZj7Tan4VWFfDvJKkC1RH+I8AT7a1j7f6FvMO4IsLnYiIPRExGRGTp06dqqE0SdJCuvq0T0S8BRgFxhc6n5n7MnM0M0eHh4e7WZokFaWOG77TwJVt7XWtvvNExFbgFuAPMvPnNcwrSbpAdXzzfwjYGBFXRcSlwA3AwfYBEbEF+AiwIzNP1jCnJKmCyuGfmWeAm4AJ4DHgU5l5NCJui4gdrWHjwIuAT0fEIxFxcJGPkyR1QS3P+WfmIeDQvL5b24631jGPJKkebu8gSQUy/CWpQIa/JBXI8JekAhn+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUCGvyQVyPCXpAIZ/pJUoFp29ex3Bw5PMz4xxYmZWdYODTK2bRO7tiz1pkpJ6m2G/zIOHJ5m7L5vcPpsAjA9M8vYfd8A8AIgadVy2WcZ7/vc0eeC/5zTZ5P3fe5oQxVJUnWG/zKefub0ivolaTUw/CWpQIb/MoYG16yoX5JWA8N/Ge/d8TLWvCDO61vzguC9O17WUEWSVJ1P+yzj3BM9PuopqZ8Y/h3YtWXEsJfUV1z2kaQCGf6SVCDDX5IKZPhLUoEMf0kqkOEvSQXyUU9J6hHd3D6+lm/+EbE9IqYi4lhE3LzA+VdHxNcj4kxEvLGOOSWpnxw4PM3e/UeYnpklmds+fu/+Ixw4PH1R5qsc/hExANwJXAdsBm6MiM3zhv0P8Dbgk1Xnk6R+ND4xxezps+f1zZ4+y/jE1EWZr45ln2uBY5n5BEBE3APsBB49NyAzv9s692wN80lS3zkxM7ui/qrqWPYZAZ5sax9v9a1YROyJiMmImDx16lQNpUnS6rB2aHBF/VX11NM+mbkvM0czc3R4eLjpciSpa8a2bWJwzcB5fYNrBhjbtumizFfHss80cGVbe12rT5LUoW7vIFxH+D8EbIyIq5gL/RuAP6nhcyWpKN3cQbjysk9mngFuAiaAx4BPZebRiLgtInYARMTvRMRx4E3ARyLCt59LUoNq+SWvzDwEHJrXd2vb8UPMLQdJknpAT93wlSR1h+EvSQUy/CWpQIa/JBXI8JekAhn+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUCGvyQVyPCXpAIZ/pJUIMNfkgpk+EtSgQx/SSqQ4S9JBTL8JalAhr8kFcjwl6QCGf6SVCDDX5IKZPhLUoEuabqAuh04PM34xBQnZmZZOzTI2LZN7Noy0nRZktRT+ir8DxyeZu/+I8yePgvA9Mwse/cfAfACIElt+mrZZ3xi6rngP2f29FnGJ6YaqkiSelNfhf+JmdkV9UtSqfoq/NcODa6oX5JKVUv4R8T2iJiKiGMRcfMC538pIu5tnf9aRGyoY975xrZtYnDNwHl9g2sGGNu26WJMJ0mrVuXwj4gB4E7gOmAzcGNEbJ437B3A05n5W8DfAX9bdd6F7Noywu3XX83I0CABjAwNcvv1V3uzV5LmqeNpn2uBY5n5BEBE3APsBB5tG7MTeG/r+D7gwxERmZk1zH+eXVtGDHtJWkYdyz4jwJNt7eOtvgXHZOYZ4H+BX69hbknSBeipG74RsSciJiNi8tSpU02XI0l9q47wnwaubGuva/UtOCYiLgF+FfjR/A/KzH2ZOZqZo8PDwzWUJklaSB3h/xCwMSKuiohLgRuAg/PGHAR2t47fCHz5Yqz3S5I6U/mGb2aeiYibgAlgAPhYZh6NiNuAycw8CHwU+HhEHAN+zNwFQpLUkFr29snMQ8CheX23th3/H/CmOuaSJFXXUzd8JUndYfhLUoEMf0kqkOEvSQUy/CWpQIa/JBXI8JekAhn+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUCGvyQVyPCXpAIZ/pJUIMNfkgpk+EtSgQx/SSqQ4S9JBTL8JalAhr8kFcjwl6QCGf6SVCDDX5IKZPhLUoEMf0kqkOEvSQUy/CWpQJXCPyJeEhH3R8Tjrb9/bZF
x/xYRMxHx+SrzSZLqUfWb/83AA5m5EXig1V7IOPDWinNJkmpSNfx3Ane3ju8Gdi00KDMfAH5acS5JUk2qhv/lmflU6/j7wOUVP0+S1AWXLDcgIr4EvHSBU7e0NzIzIyKrFBMRe4A9AOvXr6/yUZKkJSwb/pm5dbFzEfGDiLgiM5+KiCuAk1WKycx9wD6A0dHRShcSSdLiqi77HAR2t453A5+t+HmSpC6oGv53AK+LiMeBra02ETEaEXedGxQR/wF8GnhtRByPiG0V55UkVbDsss9SMvNHwGsX6J8E3tnW/v0q80iS6uVv+EpSgQx/SSqQ4S9JBTL8JalAhr8kFcjwl6QCGf6SVCDDX5IKZPhLUoEMf0kqkOEvSQWqtLePJK0GBw5PMz4xxYmZWdYODTK2bRO7tow0XVajDH9Jfe3A4Wn27j/C7OmzAEzPzLJ3/xGAoi8ALvtI6mvjE1PPBf85s6fPMj4x1VBFvcHwl9TXTszMrqi/FIa/pL62dmhwRf2lMPwl9bWxbZsYXDNwXt/gmgHGtm1qqKLe4A1fSX3t3E1dn/Y5n+Evqe/t2jJSfNjP57KPJBXI8JekAhn+klQg1/wl9RW3cuiM4S+pb7iVQ+dc9pHUN9zKoXOGv6S+4VYOnTP8JfUNt3LonOEvqW+4lUPnvOErqW+4lUPnKoV/RLwEuBfYAHwXeHNmPj1vzDXA3wMvBs4Cf5OZ91aZV5IW41YOnam67HMz8EBmbgQeaLXnewb408x8GbAd+GBEDFWcV5JUQdXw3wnc3Tq+G9g1f0BmficzH28dnwBOAsMV55UkVVA1/C/PzKdax98HLl9qcERcC1wK/HfFeSVJFSy75h8RXwJeusCpW9obmZkRkUt8zhXAx4HdmfnsImP2AHsA1q9fv1xpkqQLtGz4Z+bWxc5FxA8i4orMfKoV7icXGfdi4AvALZn51SXm2gfsAxgdHV30QiJJqqbqss9BYHfreDfw2fkDIuJS4DPAP2XmfRXnkyTVoGr43wG8LiIeB7a22kTEaETc1RrzZuDVwNsi4pHWn2sqzitJqiAye3N1ZXR0NCcnJ5suQ2qcWxRrJSLi4cwcXW6cv+Er9TC3KD6fF8L6uLeP1MPcovh55y6E0zOzJM9fCA8cnm66tFXJ8Jd6mFsUP88LYb0Mf6mHuUXx87wQ1svwl3qYWxQ/zwthvQx/qYft2jLC7ddfzcjQIAGMDA1y+/VXF3mT0wthvXzaR+pxblE8x73662X4S1o1vBDWx2UfSSqQ4S9JBTL8JalAhr8kFcjwl6QCGf6SVCDDX5IKZPhLUoEMf0kqkOEvSQXq2dc4RsQp4HsNl3EZ8MOGa7gQ1t1d1t1d1r2038jM4eUG9Wz494KImOzkXZi9xrq7y7q7y7rr4bKPJBXI8JekAhn+S9vXdAEXyLq7y7q7y7pr4Jq/JBXIb/6SVCDDfwkRMR4R346Ib0bEZyJiqOmaOhERb4qIoxHxbET0zNMFi4mI7RExFRHHIuLmpuvpVER8LCJORsS3mq6lUxFxZUQ8GBGPtn5G3t10TZ2IiBdGxH9FxDdadb+v6ZpWIiIGIuJwRHy+6VrOMfyXdj/w8sx8BfAdYG/D9XTqW8D1wFeaLmQ5ETEA3AlcB2wGboyIzc1W1bF/BLY3XcQKnQH+IjM3A68E3rVK/r1/DrwmM38buAbYHhGvbLimlXg38FjTRbQz/JeQmf+emWdaza8C65qsp1OZ+VhmTjVdR4euBY5l5hOZ+QvgHmBnwzV1JDO/Avy46TpWIjOfysyvt45/ylwg9fxLcXPOz1rNNa0/q+KGZUSsA14P3NV0Le0M/869Hfhi00X0oRHgybb2cVZBGPWDiNgAbAG+1mwlnWktnTwCnATuz8xVUTfwQeAvgWebLqTdJU0X0LSI+BLw0gVO3ZKZn22NuYW5/y5/opu1LaWTuqXFRMSLgH8F/jwzf9J0PZ3IzLPANa17b5+JiJdnZk/fb4mINwAnM/PhiPj
DputpV3z4Z+bWpc5HxNuANwCvzR56Lna5uleRaeDKtva6Vp8ukohYw1zwfyIz9zddz0pl5kxEPMjc/ZaeDn/gVcCOiPgj4IXAiyPinzPzLQ3X5bLPUiJiO3P/XduRmc80XU+fegjYGBFXRcSlwA3AwYZr6lsREcBHgccy8wNN19OpiBg+97RdRAwCrwO+3WxVy8vMvZm5LjM3MPez/eVeCH4w/JfzYeBXgPsj4pGI+IemC+pERPxxRBwHfg/4QkRMNF3TYlo31G8CJpi7+fipzDzabFWdiYh/Af4T2BQRxyPiHU3X1IFXAW8FXtP6mX6k9a20110BPBgR32TuC8P9mdkzj02uRv6GryQVyG/+klQgw1+SCmT4S1KBDH9JKpDhL0kFMvwlqUCGvyQVyPCXpAL9P3nR2GWlgA6lAAAAAElFTkSuQmCC\n", 40 | "text/plain": [ 41 | "
" 42 | ] 43 | }, 44 | "metadata": {}, 45 | "output_type": "display_data" 46 | } 47 | ], 48 | "source": [ 49 | "sine = tasks.sample_task()\n", 50 | "plt.scatter(*sine.sample_data(10))\n", 51 | "plt.show()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 6, 57 | "metadata": { 58 | "scrolled": true 59 | }, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "image/png": "\n", 64 | "text/plain": [ 65 | "
" 66 | ] 67 | }, 68 | "metadata": {}, 69 | "output_type": "display_data" 70 | } 71 | ], 72 | "source": [ 73 | "# set up model and task\n", 74 | "model = nn.Sequential(\n", 75 | " nn.Linear(1,40),\n", 76 | " nn.ReLU(),\n", 77 | " nn.Linear(40,40),\n", 78 | " nn.ReLU(),\n", 79 | " nn.Linear(40,1)\n", 80 | ")\n", 81 | "task = tasks.sample_task()\n", 82 | "optimiser = torch.optim.Adam(model.parameters(), lr=0.01)\n", 83 | "criterion = nn.MSELoss()\n", 84 | "\n", 85 | "# fit the model\n", 86 | "losses = []\n", 87 | "\n", 88 | "for i in range(1000):\n", 89 | " model.zero_grad()\n", 90 | " x, y = task.sample_data(10)\n", 91 | " y_hat = model(x)\n", 92 | " loss = criterion(y_hat, y)\n", 93 | " loss.backward()\n", 94 | " optimiser.step()\n", 95 | " losses.append(loss.item())\n", 96 | " \n", 97 | "# plot the result\n", 98 | "x = np.linspace(-5, 5, 100)\n", 99 | "y = model(torch.tensor(x, dtype=torch.float).view(-1, 1))\n", 100 | "\n", 101 | "plt.figure(figsize=(15,5))\n", 102 | "\n", 103 | "plt.subplot(1, 2, 1)\n", 104 | "plt.plot(x, task.true_function(x), '--', label='true function')\n", 105 | "plt.plot(x, y.detach().numpy(), label='model')\n", 106 | "plt.legend(loc='lower right')\n", 107 | "plt.title(\"Model fit\")\n", 108 | "\n", 109 | "plt.subplot(1, 2, 2)\n", 110 | "plt.plot(losses)\n", 111 | "plt.title(\"Loss over time\")\n", 112 | "plt.show()" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [] 121 | } 122 | ], 123 | "metadata": { 124 | "kernelspec": { 125 | "display_name": "Python 3", 126 | "language": "python", 127 | "name": "python3" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | }, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.6.4" 140 | } 141 | }, 142 | "nbformat": 4, 143 | 
"nbformat_minor": 2 144 | } 145 | -------------------------------------------------------------------------------- /src/__pycache__/tasks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vmikulik/maml-pytorch/f4a6c070d10f9fa3005d446cad38bba5f12f1308/src/__pycache__/tasks.cpython-36.pyc -------------------------------------------------------------------------------- /src/tasks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | class Sine_Task(): 7 | """ 8 | A sine wave data distribution object with interfaces designed for MAML. 9 | """ 10 | 11 | def __init__(self, amplitude, phase, xmin, xmax): 12 | self.amplitude = amplitude 13 | self.phase = phase 14 | self.xmin = xmin 15 | self.xmax = xmax 16 | 17 | def true_function(self, x): 18 | """ 19 | Compute the true function on the given x. 20 | """ 21 | 22 | return self.amplitude * np.sin(self.phase + x) 23 | 24 | def sample_data(self, size=1): 25 | """ 26 | Sample data from this task. 
27 | 28 | returns: 29 | x: the feature vector of length size 30 | y: the target vector of length size 31 | """ 32 | 33 | x = np.random.uniform(self.xmin, self.xmax, size) 34 | y = self.true_function(x) 35 | 36 | x = torch.tensor(x, dtype=torch.float).unsqueeze(1) 37 | y = torch.tensor(y, dtype=torch.float).unsqueeze(1) 38 | 39 | return x, y 40 | 41 | class Sine_Task_Distribution(): 42 | """ 43 | The task distribution for sine regression tasks for MAML 44 | """ 45 | 46 | def __init__(self, amplitude_min, amplitude_max, phase_min, phase_max, x_min, x_max): 47 | self.amplitude_min = amplitude_min 48 | self.amplitude_max = amplitude_max 49 | self.phase_min = phase_min 50 | self.phase_max = phase_max 51 | self.x_min = x_min 52 | self.x_max = x_max 53 | 54 | def sample_task(self): 55 | """ 56 | Sample from the task distribution. 57 | 58 | returns: 59 | Sine_Task object 60 | """ 61 | amplitude = np.random.uniform(self.amplitude_min, self.amplitude_max) 62 | phase = np.random.uniform(self.phase_min, self.phase_max) 63 | return Sine_Task(amplitude, phase, self.x_min, self.x_max) --------------------------------------------------------------------------------