├── .gitignore ├── LICENSE ├── README.md ├── fig_permuted_mnist ├── Figure permuted MNIST.ipynb ├── Permuted MNIST.ipynb └── data_path_int[omega_decay=sum,xi=0.1]_optadam_lr1.00e-03_bs256_ep20_tsks10.pkl.gz ├── fig_split_mnist └── Figure split MNIST.ipynb ├── fig_transfer_cifar ├── Figure barplot.ipynb ├── split_cifar10_data_path_int[omega_decay=sum,xi=0.001]_lr1.00e-03_ep60.pkl.gz └── train.py └── pathint ├── __init__.py ├── keras_utils.py ├── optimizers.py ├── protocols.py ├── regularizers.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ben Poole & Friedemann Zenke 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continual Learning Through Synaptic Intelligence 2 | 3 | This repository contains code to reproduce the key findings of our path integral approach to prevent catastrophic forgetting in continual learning. 4 | 5 | Zenke, F.¹, Poole, B.¹, and Ganguli, S. (2017). Continual Learning Through 6 | Synaptic Intelligence. In Proceedings of the 34th International Conference on 7 | Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention 8 | Centre, Sydney, Australia: PMLR), pp. 3987–3995. 9 | 10 | http://proceedings.mlr.press/v70/zenke17a.html 11 | 12 | ¹ Equal contribution 13 | 14 | ## BibTeX 15 | ``` 16 | @InProceedings{pmlr-v70-zenke17a, 17 | title = {Continual Learning Through Synaptic Intelligence}, 18 | author = {Friedemann Zenke and Ben Poole and Surya Ganguli}, 19 | booktitle = {Proceedings of the 34th International Conference on Machine Learning}, 20 | pages = {3987--3995}, 21 | year = {2017}, 22 | editor = {Doina Precup and Yee Whye Teh}, 23 | volume = {70}, 24 | series = {Proceedings of Machine Learning Research}, 25 | address = {International Convention Centre, Sydney, Australia}, 26 | month = {06--11 Aug}, 27 | publisher = {PMLR}, 28 | pdf = {http://proceedings.mlr.press/v70/zenke17a/zenke17a.pdf}, 29 | url = {http://proceedings.mlr.press/v70/zenke17a.html}, 30 | } 31 | ``` 32 | 33 | 34 | ## Requirements 35 | 36 | We have tested this maintenance release (v1.1) with the following configuration: 37 | 38 | * Python 3.5.2 39 | * Jupyter 4.4.0 40 | * Tensorflow 1.10 41 | * Keras 2.2.2 42 | 43 | 
Kudos to Mitra (https://github.com/MitraDarja) for making our code conform with Keras 2.2.2! 44 | 45 | 46 | ### Earlier releases 47 | 48 | For the original release (v1.0) we used the following configuration of the libraries which were available at the time: 49 | 50 | * Python 3.5.2 51 | * Jupyter 4.3.0 52 | * Tensorflow 1.2.1 53 | * Keras 2.0.5 54 | 55 | To revert to such an environment, we suggest using virtualenv (https://virtualenv.pypa.io): 56 | 57 | ``` 58 | virtualenv -p python3 env 59 | source env/bin/activate 60 | pip3 install -vI keras==2.0.5 61 | pip3 install jupyter matplotlib numpy tensorflow-gpu tqdm seaborn 62 | ``` 63 | -------------------------------------------------------------------------------- /fig_permuted_mnist/Figure permuted MNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Copyright (c) 2017 Ben Poole & Friedemann Zenke\n", 12 | "# MIT License -- see LICENSE for details\n", 13 | "# \n", 14 | "# This file is part of the code to reproduce the core results of:\n", 15 | "# Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through\n", 16 | "# Synaptic Intelligence. In Proceedings of the 34th International Conference on\n", 17 | "# Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention\n", 18 | "# Centre, Sydney, Australia: PMLR), pp. 3987–3995.\n", 19 | "# http://proceedings.mlr.press/v70/zenke17a.html" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": { 26 | "collapsed": false 27 | }, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "The autoreload extension is already loaded. 
To reload it, use:\n", 34 | " %reload_ext autoreload\n", 35 | "Populating the interactive namespace from numpy and matplotlib\n" 36 | ] 37 | }, 38 | { 39 | "name": "stderr", 40 | "output_type": "stream", 41 | "text": [ 42 | "Using TensorFlow backend.\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "%load_ext autoreload\n", 48 | "%autoreload 2\n", 49 | "%pylab inline\n", 50 | "\n", 51 | "import sys, os\n", 52 | "sys.path.extend([os.path.expanduser('..')])\n", 53 | "from pathint import utils\n", 54 | "import seaborn as sns\n", 55 | "sns.set_style(\"white\")\n", 56 | "\n", 57 | "# import operator\n", 58 | "import matplotlib.colors as colors\n", 59 | "import matplotlib.cm as cmx\n", 60 | "\n", 61 | "rcParams['pdf.fonttype'] = 42\n", 62 | "rcParams['ps.fonttype'] = 42" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 6, 68 | "metadata": { 69 | "collapsed": true 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "# Load data\n", 74 | "n_tasks = 10\n", 75 | "all_evals = utils.load_zipped_pickle(\"data_path_int[omega_decay=sum,xi=0.1]_optadam_lr1.00e-03_bs256_ep20_tsks10.pkl.gz\")" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 7, 81 | "metadata": { 82 | "collapsed": false 83 | }, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "[ 0. 
0.01 0.02 0.1 ]\n" 90 | ] 91 | } 92 | ], 93 | "source": [ 94 | "keys = list(all_evals.keys())\n", 95 | "sorted_keys = np.sort(keys)\n", 96 | "print(sorted_keys)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 8, 102 | "metadata": { 103 | "collapsed": true 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "sns.set_context(\"paper\")\n", 108 | "sns.set_style('ticks')" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 9, 114 | "metadata": { 115 | "collapsed": true 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "plt.rc('text', usetex=False)\n", 120 | "plt.rc('xtick', labelsize=8)\n", 121 | "plt.rc('ytick', labelsize=8)\n", 122 | "plt.rc('axes', labelsize=8)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 10, 128 | "metadata": { 129 | "collapsed": true 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "def simple_axis(ax):\n", 134 | " ax.spines['top'].set_visible(False)\n", 135 | " ax.spines['right'].set_visible(False)\n", 136 | " ax.get_xaxis().tick_bottom()\n", 137 | " ax.get_yaxis().tick_left()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 11, 143 | "metadata": { 144 | "collapsed": true 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "def plot_data():\n", 149 | " marker = iter(['o', 's', 's', 'd', 'o'])\n", 150 | " plot_kwargs = dict(alpha=0.9)\n", 151 | " \n", 152 | " for cval in sorted_keys:\n", 153 | " stuff = []\n", 154 | " for i in range(len(all_evals[cval])):#n_tasks):\n", 155 | " stuff.append(all_evals[cval][i][:i+1].mean())\n", 156 | " plot(range(1,n_tasks+1), stuff, '%s-'%next(marker), label=\"Test (c=%g)\"%cval, zorder=2, **plot_kwargs)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 12, 162 | "metadata": { 163 | "collapsed": false 164 | }, 165 | "outputs": [ 166 | { 167 | "data": { 168 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAC7CAYAAACqwUiwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXl8FEXa+L/dcyWTyUkOQkI4wmE4FFA88IYVr8VbuURR\ndJVVQBSEXXFxfyoriq7Rd19dUTlFI554i6CvaESQ04CBJJCQALkTckwyV/fvj55MMpCQBDJJyNT3\n8+lPV1d31zwzyfNU1VNVT0mqqqoIBAK/Q+5oAQQCQccglF8g8FOE8gsEfopQfoHATxHKLxD4KUL5\nBQI/RSi/QOCnCOUXCPwUofwCgZ8ilF8g8FOE8gsEfkqXUH6n00leXh5Op7OjRREIzhi6hPLn5+cz\nZswY8vPzO1oUgeCMoUsov0AgaD0+Uf6CggJuvvlmhg4dekJTfP/+/UycOJEJEyaQnp7eZJ5AIPAt\nPlH+sLAwli9fzrBhw064l5yczEsvvURycjLJyclN5gkEAt+i90WhJpMJk8nU6L2KigpiY2MBqKys\nbDJPIBD4Fp8o/8lQFMWTrgsi1FheU6SkpJCSkuKVZ7fb21BCgcA/aHfllyTJk5Zlucm8phg/fjzj\nx4/3ysvLy2PMmDFtKKVA0PVpd+UPDQ0lPz8fSZIICgpqMk8gEPgWnyi/w+Hg/vvvJz09nWnTpvHQ\nQw+xbds2pk+fzowZM3jkkUcAWLhwIUCjeQKBwLdIXSF6b12zf8OGDcTHx3e0OALBGYGY5CMQ+ClC\n+QUCP0Uov0DgpwjlFwj8FKH8AoGfIpRfIPBThPILBH6KUH6BwE9p9+m97c6SgQBYC7RLc4w7f86+\n0yrWun2HVt6I4adVjkDQUXR55R8dYQDgvh9dALyZpANg42mWW7psGdA2yi8MiaAj6PLKD9DnsEqf\nw/Xpg3HSyV9oBuv2HdTs3OlJn67StqUhEQhail8o/5itChKgd8KtGxT+b4RE+YcfASqqywWKgqoo\noKiguFBdCqgN0+777nTFV1/jqq5GAo4++SRhN98Eej2SwaAdevfZoEcyGpH0eu1w38dz34AtYz/W\n7dtBltvEkIhWhKCl+IXyA+hcEGiHuCK44UeVfalPodcZMeiNGAwBGHRGJFkGnYwkyaDTgSx50pIs\ngySBTsZVXg4OBwrgLC7i2JdfgsOJ6nB4Dlq4Xsp17BiqO85h7gMPEHTxxRh7xmOI74khPg5jQgKG\n7t01o9EC2roVIYxJ18UvlP+78yT6HFapMMN/b4CMnhI3Rl9EhlHhwLEDqKpKkMFAUrckBncbzJDI\nIZwVcRZBhsZjC1i37+DwrFkAxCUnN6oYqsuF6jzOINSl6/KdTgqffwFbRgaqoqCPikI2m6nZuYuK\nL77U3gHQ6TB0746hZ0+PYTD2jMfQsyf66GjNMNH23REQXZKuTJdXfntgJFvjHVzWoxIV2BYfTJBq\nZdae7zFdtwjrldeRXppOWnEae0r28MH+D1ixZwWSJNEntA9Dug1hcORgBncbTGxQLGPWahGD7ovS\nYg0+kPkYZMLGO7xdiJJOh6TTQROxDOuImj2bw7NmIQGxTz/tUTLV5cJZWIgjLw/7oVwch/Ow5+ZR\nnfoLjvyPwaU5MCWDAUN8PIb4eGp27UKx2ZBkmaLkZLo/tRDZHIQcZEYODNTkaQVtbUxEK6Jz0eWV\n/5Nx33Dj//zMt30zqbI5qcnqTYjFgOvsDfDtAsySxIghtzIiZgQAiqqQU5HDnpI97Cnew/bC7azL\nWgdAeEA4lfZK9LKeby40oZdbp0yNYR4xnEB3lOOGSiHpdBhiYzHExmIeOdLrHdXhwJGfj/3QIRx5\nh3Hk5WLPzcNZWIhSXQ1AzY4dHLrrbq/3JJMJOSgIOTAQ2WzWjILZjBQYqOWbzciBZs+9snfeQXXY\nQZIp/s9/6PH8YnTBwUhG4yl9V9GK6Fx0eeWPtJiYOaYfr2yAW
ocL1ergnIQIzNctAoMOvnlCe3DI\nrQDIkkyf0D70Ce3Dn/v+GYDy2nL2luwlrSSN/+76L1anlZ0xgAQ6Ww06SUdKegr9w/vTL7wfIcaQ\nVskYcc89rXpeMhgw9uyJsWdPr3zr9h3kzZoJikrMvMcx9uqFYrWi1NSgVFtRqqvd11btbLWiWmtw\nlZTiOJRb/6zVilpb6+WPsG7dysGbbtY+32RCDg5GFxysnS0W5JAQdMEWZEswupBgZEswcrAFXUgI\nssWCPTsb686dSIhWRGehyys/wPiRCYxJimF/QSU/ZRTz6c4j5JbX0vOqp7UHvnkCkGDILY2+HxYQ\nxqi4UYyKG8V76e+houJSXDhVJ07FiUtxsWzPMuwuLYpwjDmGfuH96BfWj/7h/RkQPoBuAd28ApUC\njH5/tPcHZWqn47sQLcU8YjjmYZoyhN5wwymVUYfqclG9eTNHHp8HqkrkrJkYe/RAqazEVVGJUlWJ\nq7IKpbICV2UVjtxcaisrtftVleDw3qzFY0gkibyHHiJo1Cj0UVHaER1Vn46KQg4OPuG3Oh4xz+L0\n8QvlB60FEGkxMbxnOD9mFPHit/v49/hhSFc9rXnmv/m79mATBqAhEhJ6WY8ePbhb/p/f/DmHKg+R\nWZZJZrl2fLD/A6odWjM81BRK/zCtZdA/rD/9w/v75Hu2thXRFJJOh+XiizGP0LpDEcdFTD4Zqqpq\nLQe3MVAqKyl4/gVsmZmgKOhCQ7VhzsxMqn/5BVdpqfdnm0xuoxCNPjKygZGIRh8ViSM/X8yzaAP8\nRvnrCDTqmP2nAfzto9/5MaOYywdEwdhnALcBkGQYfFOry9XLevqG9qVvaF/GMhbQlOBo9VEyyzPJ\nKMsgszyT9TnreS/9PQBKa0vRSTr0sh6drEMvaefToa3/gU/FmEiSpPkRAgMhOhqA6DlzPCMksYsW\necmpOhw4i4txFhVp58JCnIVFOIuKcBw+TM3OnThLSjxOTk8rQpbJ/etfCTx7qObYNJtPPIIapoM0\nudxp2WymNj29zUdIzhT8MoCnqqrMeHcHuaVWUh64iACDDhSX1vzf+wlcuxgG3egzeUtqSsgsz+SR\n7x/BqThxqk6vjUsGRAwgITiBXiG96B3aWzuH9CbUFNpkmSd0IdycahfCF+TNmAlA/KuvtPpd1eXC\nVVaGs7CQ/Kef8bQi9NHRWC67FMVa08Cn4Z3mJFu3u44d0yZ6SRL6bt0Iu+02TIl9MSYmYuzdG7mZ\n0ZrGOFO6EX5X84NWM829eiCTlv7KitRsHrg8EWQdXP0soMJX8wEJBp1ev7kpugV2o1tgNwL1gZ68\nOj+CS3Vxde+ryT6Wzdb8rXya9alnF6NQUyi9Q3qTEJJA7xDNKCSEJNAtoJtP5GxrTqdLIul0Whcg\nMpLouXPrWxHPPntSJVNVFdXh0IxBtRW1xuplGIpfex17drbW9ZNlqjZtovz997WXZRlDfDymxETN\nIPTVzvru3T1zKxrjTOlG+Ez5Fy1aRFpaGoMGDWLBggWe/J9//pnk5GRMJhNPPfUUiYmJPPvss57d\nedPT09m6dauvxPLQq1sQky5IYOUvOVw3NJaeEWa3AVik/SN8NU+b0Zc0zueygLcf4b6h93ny7S47\neZV5ZFdkc6jiENkV2ewu2s2XB7/EpWjN4CBDEMdsx9DJOnSSdsiSjE7Woapqs86z9qKtlKGp4dHG\nkCQJyWhENhohPPyE+7rwiPoJWy++iHnEcFxV1dgPHsR+IAtb1gFsB7Kwbt2KUlWllWkOxNSnL8bE\nvpj6JmLql4gxMRGdxXJGzY3wifLv2bMHq9XKmjVrWLhwIbt37+bss88G4D//+Q/Lly+nqqqKRYsW\n8fLLL/PEE9pw2969e3n77bd9IVKj3HtxH75Oy+el9ft56Y5zNCWRdXDNv7QHvnxcO/vIALSkSW7U\nGekb1pe+YX298h2KgyNVR
8ipyCH7WDav7ngVp+LEptqgQUfuzx//mWhzNFHmKKIDtXOMOaY+zxyN\nSXdi07azdyPayrHZmCHRWYIIHDqEwKFDPM+pqoqzsAh7Via2AwexZWVSm7aHii+/8nQr9NHROEtL\nUKxWkGUKl7xAzN+fQB8ehq5bN+SAgFbL58tWhE+Uf+fOnYwaNQqAUaNGsXPnTo/yA5jNZsxmM4cO\nHfJ6b/369YwdO9YXIjVKoFHHI3/qz98++p1NGcVcNiBKu+ExAKrbAEiQ9Od2k6slGGQDvUJ60Suk\nF5fFX8byPcsBrfugqIrnmJw0maKaIgqthewr28dPh3+iwl7hVVaIMcRjCGLMMUSZo7C5bFrrwd2K\n6Gy0pTK0xJBIkoQhJhpDTDRB7v9tANVux56biy0zC/uBLMpS3kex2UBRqE3bQ9706fVlmAPRh0eg\ni4hAFxHuSesjwrW88Pq0HBDgk+naDfGJ8ldWVtLTPQElODiYjIwMr/vFxcUcO3aMAwcOeOVv2rSJ\nv/zlLyctu6136R19VjTn94ngxfX7Ob9PhOb8A7cBeE7rAnw5V+sCnHX9KX9OeyEheZr+AJOSJp3w\nTK2z1mMQiqxFFFgLKLIWUVhTyI7CHRTVFFFlr/I8r5N1BOgCMOpObWZfZ+d0lEoyGt0+gUStrAsu\n9HQjuj/9/zDGx+MsKcVVVoqztBRXaRmu0lKcZaXU5ObhKi3FdezYCQvBZLMZZ3k5ak0NkslE6bJl\nZ4byBwcHU+XuH1VVVRESUj/jbe7cucyePZu4uDhGuMeQAbKzs4mJiSEwMPCE8hrS1rv0NnT+rfwl\nm79cllh/U9Zpnn9U+GIOIMFZ153S5/ia1jTHA/QB9AzuSc/gnk0+c+X7V3ockDaXjWpnNdXOaham\nLuTq3lczsvtIDHLLVhr6Ew27EcFXXAGAqd/J31GdTlzl5dpoRmkZrtISnGVllK1+B8fRo+Ajl41P\nlH/YsGGkpKRw3XXXkZqayi231E+cGT58OKtWrSI7O5vVq1d78tevX89VV13lC3Gapc75tyI1h2uH\nuJ1/dcg6uGaxuwUwR8vrpAagLWnogDTpTCiqgs1l43DlYRb8tICwgDDGJIzhmt7XkBiW2HyBfkSr\np2vr9Z6RjIbel4CzkjytiLbycTTEJ8o/ePBgjEYjkyZNIikpidjYWF577TWmT5/Oa6+9RmpqKuHh\n4fzzn//0vPPDDz/wv//7v74Qp0U06vyrQ6eHa58HVEiZDAGhcLyT7DRjAnZ2ZEkmUB/I0rFLySzP\n5Ovsr/ku5zs+3P8hfcP6cnXvq/lTwp8IDzjRo+5vdMSoxqngl5N8mmJjegHzP/ydJbefU+/8a4jL\nCc/Fg9N2ogHo4srfGA7Fwa9Hf+Wb7G/YfHQzAOd3P5+re1/NRbEXYdCJbsHpcsYN9Z2pjD4rhi1P\nxDT9gE4PplBo/aSvLolBNnBJ3CVcEnfJSZ/r7MOGnRlfThTqEjW/0+kkPz+f7t27o9cLeyYQtIQu\nofwCgaD1dL7ZGwKBoF0Qyi8Q+ClC+QUCP0Uov0DgpwjlFwj8FKH8AoGfIpRfIPBThPILBH6KUH6B\nwE/xifIXFBRw8803M3ToUJzHRU7dv38/EydOZMKECZ64fY3ltQan00leXt4JnyUQCE6C6gNqa2vV\n8vJy9c4771QdDofXvb/+9a/qkSNH1Pz8fPXBBx9sMq815ObmqgMGDFBzc3PbRH6BwB/wySoYk8mE\nqYl45xUVFcTGxgJauK+m8tqatlwa2dbLLM+UOO+CrkW7L4FruDmF6l5T1FheWzH6bS0C632faGGu\n37xJi2238d60Uy6zrSOqdtZ95zprWW1dnr8a33ZX/oYRcmT3xgeN5TXFqQTw7HNYpc9hLd07T+Fg\n/Km5Oka/P5o+h+zct7kMgIVLLuFggvGU16W3dXnQtoaks5bV1uX5g/FtjHZX/tDQUPLz85E
kiaCg\noCbzmuJUAniO2aIgq2B0wF1fquzs5+LrrddjNJrR6c3uzS7cYaqR6ze/QEaWZXRo92/PKWdwhp3A\nWq11cveH5eztbyJ/3zNadF9JAqmhMXOfPfca3JckbkivYNjeWgJsWnmTPy3n12FmSqqXI+n1SAY9\nkl4Pej2S3tB4nvv6ge8fIrbAwa2plajASwsv4lAPA++Oew9kWduDUELbaUaSQJbd+xTI2oGEJGvX\n13x0Lb3y7Ez7pRw4faPki40s2qq8tpatsxqlxvDJen6Hw8H999/Pnj17GDRoEA899BDbtm1j+vTp\npKen89RTTwGwcOFCkpKSGs1rDScL43Xl20O4/2MX/fLAZAdrABzsDmZV29tCkUBBRpVA9SiopgxI\n9ftfqBLYnTb6HQFLrZZXFQBZPcBisCAjISMjS7I7LSE1SMs0yHefj1YdIb7QRaBNMxO1RigJ09Ej\nIAbJqSC5FCSXC8npAiQ0myK5TYr77DY0pbWlmGsU9FrvBqcOrIEykaewlVdxbckJZVWbZaIsMdqn\n6nQgS0iyDiQJSSdrgU4bGhSdjCTJpJXuIabY6TFw1gCJQz2MXBg/CkmvA50eSZaRDHqQdVqerEPS\n6UCvQ6rL0+m1z9HrqfjyK5xHjwJgiIsj9Jab3UZNBln7Gza8ltzGDVn2GDgkLV2yfDn27BwATImJ\nRM14GMlo1AyswYBsNCIZDEgGA7jPksGIbHRfN2i1WrfvqN/9Jzn5tI1SW5XVFF0imMfJlH/020Po\nnadw/6fa11x6o8SBOIle5QuxWPOIo4DBgWUMMBURpeSjd5TglCRsEthMQdhD47AHx2KzRDP70Mf0\nzYPpn2o+iv/cKJMVL/Hg+XNxKA4cLgcOxYHNZdOuFQd2l92TX3dtV+w4FAd7i/fSL8/F7I8cqMDL\nN+vZHy/TLfA4hVVVZAV0Cuhcqucsu896F1RYS5j6nUJCofs3iYT3L5PpYYlHh4zefUhI6NFaNTok\n9JIO2Z0nq6BDJq3od67b7KB7qfY9C8NlNow0cVXCnzBJBkyyESN6jLIBk2TAgE47S3p0SKguBVQF\nVVFY8fsyzv29hrAKraxjwTJpAwK4rd8t4HKiOl2oLic4XaiKouW5lAZ5Lu3sqn/OduCAtisOIAcE\nYOjRw/15KigKqCqqqoCigsuFiqql3fca4tnxFy2Kri606c1QG8Wg9xgHZ0GhtjGoJKGzWDANHKgZ\nHp3bmMmyZsB0es0IuQ2fdl92Gz3tXPXjjziLipB0OoIuuuiUNjdtDr+IeXUwXuZgD60aOxAnIUkS\nyx+9ncPlNWw/VM72nDLeOlRGflUtJtXGBeFVjIqoYEhgKb3lQgIqc+HQdoxmlbwecFAbmOBID5VA\nVeXu4gIIjoXg3hDcHUJ6QGC4p1ZuitFvD6GkO2T30K5LYlUiXC4+v/lznKrTYzScitNjPOqunYoT\nu8vuSS/YMIN1l8j81W2YPr5M5lC8xJ8vuBOX4sKpOnEpLhRVwaE4UFQFl+rSylZdWN0x+p2Kky3Z\nf1AWLDHjQxsAK8bq2R+nsj/4D6wOq6ZMTWCQDZgNZoIMQZj1ZraFKGxLMHjKevOGAA4kGDnn0kuJ\nCIggLCCMcFM4gfrAFu0pqPlJDNyXon3PN8ebOJhQ2eIuiaqqmgFQFK5aexV9D1m4d63WvVl9YxCH\nYyVSrlmD6rCjOhyodgc4Hdpmn3YtD6cT1Z2uO17d8jKjKmroZtP+z8p1Fewy/M7EAeOhzhjWGTOX\n4jFqqAqqzYZyXL5itWqf08AZ3tZ0eeWv8+ofHfgrh0qtfHDlKCIt2jBkfLiZ+HAzN5yjad+R8hq2\n5ZSx41A5Kw6VcbS8BoA+kUGM6BeGevQ2AL4bCZJbBSQV2P8NVBWAy1H/wXqj2yDEasYguHv9dXCs\ndu1mw8h6B6QEmA3ufQNcDrBVatGCbdVa2l4FtgqwVTW
4rsSkKhyOcxsmCY7EQaCqMkUfA2EJENYT\njCf3p9SReiSVgkTITtAckUWJ4YQDn938GaqqUuOsweq0YnVYqXJUYXVYPdfVjmotz32940gqGbGQ\n0UMzGGkxtai2WhamLvT6TKPOqBkDUxjhAeGEm8IJCwgjIiDCkw43haOoCgcTjBzsqe0edDChdbsI\n1flbkGUUnURmHxMH3GXsHajtpWeMj2tVmQA/6JaSE2fgvhTtN1t9UxgHE4w8csfcVpcF3s1+X8Ts\nhxYo/z333MMyt+MB4NFHH+Wll17yiTC+JPbiC4ht5pkeYYH0CAtknNsY5B+rZfuhMrbllPHrwVLe\nL6xGcTcbJSSkXE1Zi+d8SaTZANZiqDgClflQebT+KM6Ag5ugusjr8zZWF2l95bq98A67m6WvXawp\nuPMkoxg6A5iCwWjRzm42nO9O1DVv182of8fczW0I3MbAk+7l3VKp0voOG4arXtegKY/ZYNYM1Mk3\nVwLgi93a/86PbgMXoWh+lLXj1lJuK6estowyW5l2ri2j3FZOaW0pGeUZnmunUj9zs6y2DCT4+Dxt\n9+Fym1ZrP7j+QfSyHoNsQCfrMMiG+mtJp6V1BvSS3nPP6tS6Dl9eoKlBjbMGCYl1Wes8+xRKkoRe\n0mu+HPfh2QXZfZYlGYfiYH+8xIF4razMnjokFJyKE73cujq2LtrxfVHanJcHMh+DzLaPdtykVJs3\nb2bz5s3k5OSQnJwMgMvlorCwsKlXuhzdQwO4bmgs1w3VzEbtc3oqap2gahtiqirIkkRGQSWRiZFg\nidaOpnDaoSpfMw4VR2Ddw+5+qLtpJ+s1Q3DOhHqlrjuOv9Z7T6LauGSglqhTyDL3+cGf4NghKMuB\nY7lQfgjKsuHg/0FNeX0BxiC3QejFxvwyzSgZ3Yap2P1MwR5QXKBqzVVNdndaVRq/5+4iHPRUpioS\n0E020i20L4SdvKmvqipVjiqPkfjr55NRgdxYrWSDU2tm9wvrh1N1erpBdf4Vq8Pq1VWq6065VBc2\nexUqkBbjltJta1/e9vJJZWqMilrtB//kXO37lteUADD2g7EE6AOwGCwEGYI8h8VgIcgYRJA+CIvR\n4skzG8w4FAcSEt9d1LKu0KnSpPL37NkTWZbJzc3loosu0h7W65vdSLMrU/XwHm7+n59xuBRcikpF\nrQNFhZcqbC0rQG+sr20BvprX+HOjZjSefypYorQj7twT79VW1BuE8lwoz9HSztp6g9SQVbecmNcM\nG6ubqCxePQ8MgWCJcRvNmAbpaE9aCoom2BhMsDGYBBIwNeFumDNyTusEU1yMXn5O4zJfsxoFcIHm\nG0FFARRUXKi4VBUVFRfgUhUU4M5PbkQFCuI0ZQ12jyY9et6jVNmrqHZUe44qRxUltSXkVOZgdVip\ntFdidVo9E9zqDMm2KK1lGVF1gohtQpPKHxcXR1xcHCUlJZx/vtaWVFWVr7/+mmuvvdY30nRyIi0m\nZo7pxysbMgGFSIuJ+PBAnv3yD44cq+WBy/oiy76z1CflVHYMCgiBgMEQM9g7f8lAtOaNy9s7fucH\n7vkC7hZBXZfFkz7+Wob/jMRrp0nV3SK4fonWAqou0vwlFYfhyHati9HQdwJal6TOIKB4l1fHjy+A\nowYcteA87uywagbNUeM+14LLDuG6xn+XN65EpnWr3nRNlPXnXZ9DUCQERYE5EsLPcl+78wxaU63O\nl1LtqObWlCtR0YyHL/+bmu2MvPvuux5llySJ9957z2+VH2D8yATGJMWwv6CSATHBRJiNLE/N5r8/\nZpFRUMk/bxxMcEALt6nq1Ft8SSDpvf/7ug89hXKOUyHJbSSSxjX+uKpCTZlmFCrzNWNQVaAd1UVs\nPFxU383wyCmB6yswmEEfAIYA0Adq3SNLjDsvsP5sCAR9IBu/nOP2cxynYnesqP8M1a2Gquqdpyru\nfC1v42ezvL9D3bN
hLshP0/xB1SUntqiMQRAUiRQUhTkoErM5Er3XcGQHNPvrcDgcHDt2jNDQUMrL\ny7HZWtjE7cJEWkyeEQOAey/pw4CYYP7xaRpTl21lye3n0CeyZZ71Lk9rDZwkgTlCO6IGnng/6/vG\n37v/FJxh3y5oPL/PZa0v66v5jeffurQ+rbg0P4u1WDNu1UVQXew+irT84ozj5iKo9Q7hNqZZ5Z8z\nZw4PPfQQqqoiyzKPP/64TwQ507mkfyTL7hnJ3LW7uXf5Vp66YTCXN7bZ55lAp26RnMHIOgjqph2N\nGTY3Gz3dLncXx0fK36IZfg6Hg9LSUmJiTrKJZQfSVrv0tgVVNidPrdvDj/uLuP/Svky7pE/H+QEE\nZyZLmjAMbWyUm635P/74Y9atW0dJSQkff/wxs2fP5pVX2n6qYVfBYtLz/K1n89ZPB1m66QDp+Zof\nwGLq8vOpBG1FO7W8mm1PrF27lmXLlhEaGopOp6O8vLy5V/weWZa4/7K+LLn9HLYfKuPeZVvJKanu\naLEEAi+aVX6dTkd1dTWSJFFbW+vTSQddjcsGRLFs6kgUVeWeZVvZlFHU/EsCQTvRrPLPnTuXmTNn\ncuDAAWbOnMljjz3WHnJ1GXpHBrHsnpEMSwjjsfd38eamAyjKGb+QUtAFOGlHVFVVMjMzeeutt9pL\nni5JcICBJbedw9JNB3jjxwPsy6/kqRsGEyT8AIIO5KQ1vyRJ/N///V97ydKlkWWJBy5PZPFtZ7M1\nu5R7l28lt9Ta0WIJ/Jhmq56ysjLGjRvHwIEDkSRtLfzzzz/fHrJ1Sa4cGE2vqWbmrN3F3cu28PSN\nQxjYPdgzY7Dh5CGBwJc0q/xz584lIiKiPWTxG/pGWVhx7/k8+Uka01dvQ5IkTHoZg05m5ph+jB+Z\n0NEiCvyAZh1+L7/8smeRT90hOH2CAwz87VotVmFlrYOSajul1XYWf72P7GIfLeMSCBrQbM0fHR3N\nG2+8wZAhQzzDfHVLfAWnR1ZxFQEGHXqdjM3hwu5SKLe6uPW1VC7sG8kl/SO5pF8kPSPMHS2qoAvS\nrPLHxcVht9vZvn27J08of9swICYYg04GFPQmPUFo61r+cmlftueW8z8bM/n3+v307hakGYL+kZwd\nF4peJ/ZXFZw+LZrbX1RURF5eHnFxcURHnyRSTQfRmeb2t5aUrYd4ZUMmDpeCQScza0x/7hjZE9DW\nCfx6oISfMov5ObOYcquD4AA9oxI1Q3BRYjdCjls+XFxlE85DQYtoVvnffPNNfv31V8466yz27t3L\nhRdeyP3aPVgJAAAYx0lEQVT3399e8rWIM1n5oWUK61JU9h6pYFNmET9lFJNZWIVOljgnPoyL+0dy\nWf9Ifskq4dWN9YZEOA8FJ6NZ5Z80aRJr1qzxXE+cOJF3333X54K1hjNd+U+Fo8dq2JShtQh+yy7D\n5nBRUevAoJMx6mX0OhmjTubThy8WLQBBozTbeTQYDGzfvp3a2lp+++039HoxK60zEBsayB3n9SR5\nwnDWP3oZ0y7tozkOnQrHahyUVNkorrLx7/X72X6oDLvTd/HfBWcmzdb8R48eZenSpRw6dIhevXox\nbdo0evTo0V7ytQh/rPmPp7jKxo0Ngova3ecoi4kahwuTQeac+DBG9o7g3F7hnNU9WDgO/Zxmq3Gb\nzcaTTz6JJEmoqkpOTk57yCVoJccHFw0xGJg1pj+3jIhjX0El23PK2Jpdxls/HeQ/32diMekZlhDG\neb0iGNk7nMQoywlBR4TzsGvTbM1/9913s2LFiiavOwOi5q+nOYV1uBT2HKngt+xSfssu4/fDx3C4\nFEIDDZzXO5xz3cZAOA+7Ps3W/LW1tZ60qqpe14LOx/HBRY/HoJMZ1jOMYT3DuO9SqHW42J13jK3Z\npWzPKePFb/fhcCocq3WglzXnoQq8siGTMUkxogXQhWhW+W+66SamTp3KoEGD+OOPP
7jppptaVPCi\nRYtIS0tj0KBBLFhQHyV19uzZFBcXY7fbqa2t5dNPP+XVV19l/fr1hIaGMnr0aO7x0d5kghMJMOg4\nv08E5/fR1m9U2Zys2ZxD8oYM7C4FW622I45Jr+PDbXnce0kf98QkwZlOs8o/ceJErr76avLy8rjv\nvvtatMhnz549WK1W1qxZw8KFC9m9ezdnn302AP/+978BWL9+PWlpaZ535s+fz6hRo071ewjaCItJ\nzy3nxrPilxzNeaiq2BwKdqfC0h8PsHZbHmOSorlmcHfOiQ8TwUnPYFo0bhcREdGqlX07d+70KPKo\nUaPYuXOnR/nrWL9+PXfffbfnesmSJYSEhDBv3jySkpJa/FmCtsfLeej2B8wc048RCeF8nZbPN3vz\n+Xj7YbqHBjB2UAzXDOlOv+jg5gsWdCp8MmhfWVlJz57aFNXg4GAyMjK87jscDvbv38/gwdo2UVOm\nTGHGjBlkZ2fz97//3WtSkaBjOH5norq+fv+YYB66sh87csv4Zk8BH+04zMpfcugXbeHqwd25enB3\nuocGdLD0gpbQIuWvqqqisrLSs5Fgc+P8wcHBVFVVed4NCQnxur9lyxbP/n8AYWFhAPTu3btZWVJS\nUkhJSfHKs9tPspW14JRpynkoyxLn9org3F4RzBk7kNSsYr7ZU8DSTQf4z/eZDOsZxjVDujPmrBhC\nzdraAzFs2PloVvmffPJJjhw54rWg51//+tdJ3xk2bBgpKSlcd911pKamcsst3ru7rl+/nuuvv95z\nXVVVhcViobS0FJfLddKyx48fz/jx473y6ob6BO2PUS9zxcBorhgYTZXNyffphXy9J5/nv97Hi9/u\n54K+EVhMer7bW4BTUcWwYSeiWeXPy8tj2bJlrSp08ODBGI1GJk2aRFJSErGxsbz22mtMnz4dVVXZ\nuXMn//jHPzzPP//88+zfvx9VVUV04DMYi0nPuHN6MO6cHhRV2vjujwLW7TzCrwdLkJAw6mXMRp0Y\nNuwkNDvJZ968eQwaNIgBAwZ48jrben4xyafzkppVzIw1O7A5FWodLhRVxaiXefL6QUy+sFdHi+fX\nNFvz9+zZk8rKSrZt2+bJ62zK3xpE37N9GRATTIBBh06WMBt11Dpd1NoV/r1+Pz9nFXPPxX04r1e4\n2AymA/CrYB7HB85oSd/T4XDwzjvvsGnTJs4//3wGDx7MJZdccsJzTqeTF198kXnz5nnysrKyyMjI\n4Jprrmn2OyQnJ2OxWIiOjmbUqFGsX7+eCRMmNPvemcDxv/uM0f3oZjGxPPUgGQVVDO4Rwj0X9+GS\nfpFi3kA70mzNfyYE82iK3FIrVTYnAGXVdl78dj9O9245TsXFi9/uJyHCTJjZCGh91uPj5RkMBqZO\nnUplZSUTJ07kvffeIysri8rKSm688UZWrFhB3759GTFiBBkZGV4jGevWrWPWrFlkZmaybt06YmJi\nuO222/jss888oyEWi4XRo0djMpmYNm0aixYtYty4ceTl5bXXz+Rzmho2/FNSND9nlrDs54PMWbuL\nxCgLUy/uzZizosWKw3agWeXfuHHjCcE8zgTlL6u2c/vrv6C4GzZ2p0JFreOE56av3o5Rr/2jyZLE\nV7MuJTzI2GS5mzdv5s9//jOFhYWEhYURExNDeXk58fHx9O/f32sIs7a2FlmW2bBhA3fffTfdunVr\nsty6Zm/dWVEUXC4XOp2u9V++E9LYsKEkSVzSP5KL+3Vj+6Eylv2czZOfpPHf8EDuuqg31w2N9fxt\nBG1Ps8pfF8xj0KBBpKWlnTHBPMKDjKx98CJPzV9utTPrvZ2emh9AL0skTxjmVfOfTPFBm7FYU1ND\n7969KS8vx2w2c+DAAaqrqykvLyc1NdUzuzEgIABFURgzZgwrV66ke/fu3HLLLdx2220nlGuz2Xj7\n7bcZMmQIALIsdxnFbw5Jqp83kHb4GCtSs1n05
R8s3XSAOy/sxU3D4gg0+sdv0Z74VTCPkwXL9AWt\n6fM3pLS0lG+//bbL9PlPhczCKlb+ks23ewoICdQzYWQCt58XT3BA208a8lcncJPKr6oqkiShKIrX\nNWi1UmeiNUN9/vqHPlPJLbWyenMOn+8+ikkvc+u58QQadLz108E2iTVwKk7grkKTyv+vf/2Lv/3t\nb0yZMsWj9HUGYOXKle0qZHOIcf6uT2FlLe9sPsQH2/LIr6jFpJfRuUcGdJLEA1ckYjHqUFRQVBVF\n1f5f1QbXiqqiKCoqWrqq1knK1lxc7rkHRp22ZZq/BD1tsgP/t7/9DYDp06d7LbX97bfffC+VL1gy\nsPH8OfvaVw7BKREdHMDsqwYwpEcIs9/fRY3D5VlrIiGx/OeDBBn1IGnGQJYkJAmvsyxp/oW6c1Wt\ng1qnCxUtqIlRL2Mx6dlfUOnfyl/H66+/7qX8y5cv57zzzvOpUJ2JjhjnHzduHAC5ubmekZapU6dy\n9OhRXn/9dR599FEGDBjAW2+9xeTJkwkI8J9VdMN7hRMaaMDcwAF4qrV1w6CnNqdClc1JudXBweJq\nLurbrctPPGpS+T/88EM+/PBD9u/fz+TJk1FVFVmWGTp0aHvKd3qU5YCtUksrJw7zAZBfH1AEUzCE\ne0857ahxfoBvvvmGqVOnYrfbWb9+PXfeeafXAqZLL72Ub7/9lhtuuKGNfrDOT8NYAw0dt6dSU3sH\nPQWz0UTvyCCSv8sgNbOE+dee1aX3SWxS+W+99VZuvfVWNm7cyOjRo9tTprbBWgpvXwOqO159TVnj\nz62+tT4tyTD9ZzA3Hbikvcb5j89vjB49evDNN980eb+r0tSkobYqKzWzmH99lc6kNzfzl8sSmTiy\nZ5ecdNRss/+7777zKL+qqixYsIBnn33W54KdNuYIuPfr+pp/ZRO1450f1qdNwSdVfGifcf7S0lI2\nb97M1VdfzfLlywEtanJWVhapqank5OTwwAMPcOTIEXr18s/FMc0FKj2dskb1i+S9v1zIaz9k8T8b\nM1i/t4AF1ycxIKZrRStqdpx/ypQprFq1qsnrzkCLvP0d4PA71XH+luKPff725ve8Yzz9xV5yS63c\neWEv7ru0DyZ915hw1GzNHx4eztq1axk+fDg7duwgPDy8PeRqezrAq5+YmEhiYqLPyp82bZrPyhZo\nDI0PZfW0C1iRms3y1Gy+Ty/k79cnMSLhDNWDBjTbkVm8eDHV1dWsXr2ampoaFi9e3B5yCQSdBqNe\n5v7L+rLi3vMJCTTw4KptPPdVOpWNrBU5k2jRkl673U5JSUmLY/i1N2KSj6C9cCkqH2zL5X9/yMJi\n0jPvmrO4bEBUR4t1SjTb7H/jjTf46aefOHDgAAkJCRiNRo8T6kxi9PuNj1hsvGNjO0siOJPRyRLj\nRyZwaf8onvsqnTlrd/GnQTE8dtUAup1hE4OabfZv2LCBlStX0qdPH9asWeOJtOsvOBwOli9fzrRp\n0/jvf//LTz/91OhzTqfzhC5RVlYWX3/9dYs+Jzk5mbfeeovPPvvMk5ebm8vixYtZvHgxBQUFbNiw\ngTfeeIN//OMfuFwu3nrrLbF9WgfRIyyQ5AnD+OcNg9l6sJQ7/vsLn+06QlFlLalZxRRX2TpaxGZp\ntuY3GrUlrgEBAWzdupWsrCyfC9VWHK46TJVdm0zjVJyNPrOvtN4RaDFaiLPEed3vjJN8/vnPf+Jw\nOPxykk9nQpIkrh0aywV9u/HS+n38/aPfcSoqgUYdJr3MzDH9mXj+qS8S8vUitGaV/4knnsButzN/\n/nzeffddHn/88TYXwheU15Zz11d3efwUx2zHGn1u+nfTPWlJkvhw3IeEBTTduunoST7Lli1j7Nix\nBAQE+O0kn85GRJCRWWMG8NXv+dTUOLBZtfDzCz7+nbd/OkhUsIlws5HwIAMRZiPhQUYigoyEmY1E\nmI2EmQ1EB
BkxG3Wev3d7rDY8qfKrqsrbb7/N888/T2JioteGm52dsIAwVl670lPz3/9t49GHXvvT\na560xWg5qeJDx07yeeONN9i3bx+SJDFkyBC/nuTT2cgorEQnS4QHGXG6FM+qwvP7RBASaKCs2k5h\nhY19+ZWUWR1U1Jw4UmDUy4SbjQSZ9OzKLQdAr9OMgS/CnTfr7X/hhRcYO3YsgwcP9qzjPxPX83eE\nw09M8vEfGi4SquNkC44cLoVyq4Nyq53SajtlVjtlVgel1Xb2HD7Gxn2FKCpIQGigFsDk5QnDGJUY\n2WYyN9vs3717N7t370aSpE67nr8ldIRXX0zy8R9au+DIoJOJCjYRFXzi/eIqG7saMSRtPb24SeWv\n20Krs03lPV2s23cAYB4xvIMlEXQ12mrBUVuuXDwZTSr/X//6V08N//e//51Fixa16Qd3FKXurceE\n8gt8QVstOGrLlYtN0aLOe1eJIW/dvoOanTup2bnT0wJojo4c5y8qKuKZZ55h9erVAHz//ffs3bu3\nReUJznwiLSZGJUb6LKpQkzV/Xl4eycnJqKrqSdcxa9YsnwjT1tjzDqNUV3mui159BdXp9KRj5s/3\nel4OsmCM7zzj/FFRUUyZMoVNmzYBcMUVV/DSSy8xaNAg3/xgAr+iSeV/7rnnPOmGYbzOFJxlZeRM\nngxKvdPEdeyYR/lrtm0n977jhv9kmT6ffIz+JCsX23uc//hn7Hb7Sb+3QNBSmlT+hv/EZyL68HB6\nvfOOV81f+0c6he6mefS8eQQkneX1jhxkOaniQ/uO819++eV89NFHHDx4kAsuuIB+/fphMp1Z88cF\nnZcWreo7FRYtWkRaWhqDBg3ymhw0f/58srKyCAgI4I477mDcuHEUFBQwd+5c7HY7M2fObHVLozWr\n+vJmzAQg/tVXWv+lWklbj/N///33xMTEiGa/oE3wyd5be/bswWq1smbNGhYuXMju3bs5++yzPfeX\nLFniNTNt6dKlzJo1i7POOosHH3zQp92MiHvu8VnZx9PW4/xXXnllm5UlEPhkqt7OnTs9Cjxq1Ch2\n7tzpuSdJEvPmzePBBx/k8OHDAOzbt48RI0YQFBREUFCQxxnmC8wjhothPoEAH9X8lZWV9Oyp7YEX\nHBxMRkaG5968efMICwvjt99+Y/HixbzyyisoiuJxdFksFioqKrBYLI2WnZKSQkpKileecIIJBK3H\nJ8ofHBzsqb2rqqoICQnx3KuLB3Deeefx4osvAt5rBY5//njGjx/P+PHjvfLq+vwCgaDl+ET5hw0b\nRkpKCtdddx2pqanccsstnnt104YPHDjgUfKBAweyY8cOBg4cSHV1dZO1flO4XNoSyvz8/Lb7EgLB\nGUb37t3R61uu0j5R/sGDB2M0Gpk0aRJJSUnExsby2muvMX36dObMmcOxY8eQJImnnnoKgPvuu4/H\nH38cm83GjBkzWv15RUVFAEyePLktv4ZAcEbR2hiWPhvqa09qa2tJS0sjKioKna59Yqo/+OCDvP76\n6+3yWa2ls8rWWeWCzitba+TqFDV/exMQENDum4cajcZOGym4s8rWWeWCziubL+XqXFE5BAJBuyGU\nXyDwU4TyCwR+iu6pOpe7oNUMGTKko0Voks4qW2eVCzqvbL6Sq0t4+wUCQesRzX6BwE8Ryi8Q+ClC\n+QUCP0Uov0DgpwjlbyW7du1iwoQJTJw4sdOGM1++fDkTJ07saDG8+OSTT7j77ruZMmUKBQUFHS2O\nh5qaGv7yl78wZcoUpk+f3uHLwwsKCrj55psZOnQoTne8yUWLFjFp0iSeeeaZNv0sofytpEePHqxY\nsYJ3332XkpIS9u3b1/xL7YjdbuePP/7oaDG8KCgoYMuWLaxYsYJVq1YRExPT0SJ52LRpE2effTar\nVq3i7LPP5scff+xQecLCwli+fDnDhg0DvKNiORwOdu/e3WafJZS/lURFRXm
CaBoMhnZbSNRS1q5d\ny0033dTRYnixadMmFEXh7rvv5umnn/Yswe4MJCQkUFNTA0BFRYUn3kRHYTKZCA0N9VyfLCrW6SKU\n/xRJT0+ntLSUfv36dbQoHhwOB1u2bOGiiy7qaFG8KCkpweFwsGLFCgICAtiwYUNHi+ShV69e7Ny5\nk+uvv560tDRGjBjR0SJ5UVlZ6YlvERwcTEVFRZuVLZT/FCgvL+fpp5/m2Wef7WhRvPj00089G350\nJiwWCyNHjgTgwgsvJCsrq4Mlqufjjz/myiuv5IsvvuCKK65g3bp1HS2SFyeLinW6COVvJU6nk7lz\n5zJv3jyioqI6WhwvDh48yLvvvsu0adPIzMzsNJusjhgxwuMb+eOPPzrV0llVVT3N7PDwcCorKztY\nIm+GDRvG5s2bAUhNTfX4AtoCMb23lXz++ec888wz9O/fH4BHH32U4cM7XzTgiRMn8u6773a0GB4W\nL15MWloa4eHhLFmyBKPR2NEiAVo/f/bs2djtdvR6Pf/+9787tN/vcDi4//772bNnD4MGDeLRRx/l\ns88+Y+/evSQlJfHkk0+22WcJ5RcI/BTR7BcI/BSh/AKBnyKUXyDwU4TyCwR+ilB+gcBPEcrfBfj1\n118ZPny4Z/bX/PnzycnJOaWyPvroI9auXduW4mG1WpkwYQIzZ870yv/ggw9aVc6UKVM8i10Ep49Q\n/i5CbGxsmyttS1EU5aT309PTOe+883jllVe88j/88ENfiiVohi6xaYcAxowZw/fff8/UqVM9ea++\n+irnnnsuo0aNYv78+Tz88MNs2bKFH374gdraWlwuF6NHj+bLL7+kd+/enunKGzdu5Ouvv8ZoNJKc\nnIzBYOCpp57i4MGDBAQE8MILL5Cens6yZcsAbULR5ZdfDmhz0efMmUNVVRVJSUksWLCAF154gfz8\nfHQ6HbNnzwa03Zb379/PlClTWLBgAWvXriU9PR1FUViyZAmRkZE8/PDD1NTUEBERQXJysud7ffbZ\nZ+zevZuHHnrIs73bwIEDWbBgQXv81F0GUfN3EWRZ5sorr+Tbb79t9tno6GjeeOMNevTogcPh4J13\n3uHo0aOUl5cD0K1bN9566y2GDx/O+vXr+f777+nRowcrV65k8uTJvPfee4A2G+3111/3KD5oSn3t\ntdfyzjvvUFNTw65du3jkkUe44YYbPIoP2m7LAwYMYNWqVQwcOJDHHnuM1atX8/DDD5OSkkJ+fj4R\nERGsWrWKl19+2fPe559/zq5du3jiiSf4448/OP/881m1ahVPPPFEW/2UfoOo+bsQt99+O4888gjR\n0dEASJLkuddwIueAAQMAzQjUTVOOjo72+AySkpI8599//x2DwcAXX3zBTz/9hNPp9MwvHzx48Aky\nHDp0yGMMhgwZQk5OTovW77/55pv88ssvOJ1OEhMTSUhIYMCAATz22GMMGTKEe+65B4ClS5eyZs0a\nQNvmfcuWLTz22GNceumlnW4pc2dHKH8XIiQkhD59+vDLL78A2mq6wsJCVFUlIyPD81xDo9CYgahb\nhJOenk5CQgIBAQHcdNNN3HvvvYBW42/fvt3r3ToSEhLYs2cP/fv3Jy0tjdtvvx2bzdaovHXvl5WV\nsWXLFtasWcPPP//MZ599ht1uZ+rUqciyzL333utZrfjcc88xd+5cXnnlFSRJYtasWQDceOONQvlb\niWj2dzGmTJnCgQMHABg7diwrV65k1qxZXgEimqO8vJx7772Xbdu2MXbsWMaMGcPhw4e56667uOuu\nu04a7eaOO+7giy++YNKkSRiNxpOuQouNjWXGjBmUlJRgNpu56667+OGHHwA4fPgwkydPZvz48YSH\nh9OtWzdAa41MmzaNxx9/nN27dzNx4kRuv/12T8ALQcsRC3sEAj9F1PwCgZ8ilF8g8FOE8gsEfopQ\nfoHATxHKLxD4KUL5BQI/RSi/QOCnCOU
XCPyU/w+X4IpEBK2BIgAAAABJRU5ErkJggg==\n", 169 | "text/plain": [ 170 | "" 171 | ] 172 | }, 173 | "metadata": {}, 174 | "output_type": "display_data" 175 | } 176 | ], 177 | "source": [ 178 | "fig = plt.figure(figsize=(3.3,2.4))\n", 179 | "# fig, ax = plt.subplots()\n", 180 | "# plt.rc('font', family='serif', serif='Times')\n", 181 | "gs = GridSpec(2, 1, height_ratios=[0.5, 1])\n", 182 | "ax = plt.subplot(gs[0])\n", 183 | "plot_data()\n", 184 | "# Training error for control network -- trained conventionally\n", 185 | "# plt.arrow(10.5, 0.995198, -0.2, 0, head_width=0.005, head_length=0.2, fc='k', ec='k')\n", 186 | "plt.tick_params(\n", 187 | " axis='x', # changes apply to the x-axis\n", 188 | " which='both', # both major and minor ticks are affected\n", 189 | " bottom='off', # ticks along the bottom edge are off\n", 190 | " top='off', # ticks along the top edge are off\n", 191 | " labelbottom='off')\n", 192 | "plt.tick_params(\n", 193 | " axis='y', # changes apply to the x-axis\n", 194 | " which='both', # both major and minor ticks are affected\n", 195 | " right='off', # ticks along the bottom edge are off\n", 196 | " left='on')\n", 197 | "\n", 198 | "\n", 199 | "ax.spines['top'].set_visible(False)\n", 200 | "ax.spines['right'].set_visible(False)\n", 201 | "ax.spines['bottom'].set_visible(False)\n", 202 | "# ax.get_yaxis().tick_left()\n", 203 | "ylim(0.965, 1.005)\n", 204 | "yticks([0.97, 1.0])\n", 205 | "xlim(0.5, 10.46)\n", 206 | "\n", 207 | "ax2 = plt.subplot(gs[1])\n", 208 | "plot_data()\n", 209 | "\n", 210 | "xlabel('Number of tasks')\n", 211 | "ylabel('Fraction correct')\n", 212 | "ylim(0.48, 1.02)\n", 213 | "xlim(0.5, 10.5)\n", 214 | "simple_axis(ax2)\n", 215 | "yticks([0.5, 0.75, 1.0])\n", 216 | "\n", 217 | "from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes\n", 218 | "\n", 219 | "legend(loc='lower left', fontsize=6)\n", 220 | "\n", 221 | "plt.subplots_adjust(left=.18, bottom=.20, right=.99, top=.97)\n", 222 | 
"plt.savefig(\"accuracy_vs_nbtasks.pdf\", pad_inches=0)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": { 229 | "collapsed": true 230 | }, 231 | "outputs": [], 232 | "source": [] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": { 238 | "collapsed": true 239 | }, 240 | "outputs": [], 241 | "source": [] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": { 247 | "collapsed": true 248 | }, 249 | "outputs": [], 250 | "source": [] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": { 256 | "collapsed": true 257 | }, 258 | "outputs": [], 259 | "source": [] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": { 265 | "collapsed": true 266 | }, 267 | "outputs": [], 268 | "source": [] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": { 274 | "collapsed": true 275 | }, 276 | "outputs": [], 277 | "source": [] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "collapsed": true 284 | }, 285 | "outputs": [], 286 | "source": [] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": { 292 | "collapsed": true 293 | }, 294 | "outputs": [], 295 | "source": [] 296 | } 297 | ], 298 | "metadata": { 299 | "kernelspec": { 300 | "display_name": "Python 3", 301 | "language": "python", 302 | "name": "python3" 303 | }, 304 | "language_info": { 305 | "codemirror_mode": { 306 | "name": "ipython", 307 | "version": 3 308 | }, 309 | "file_extension": ".py", 310 | "mimetype": "text/x-python", 311 | "name": "python", 312 | "nbconvert_exporter": "python", 313 | "pygments_lexer": "ipython3", 314 | "version": "3.5.2" 315 | } 316 | }, 317 | "nbformat": 4, 318 | "nbformat_minor": 1 319 | } 320 | 
-------------------------------------------------------------------------------- /fig_permuted_mnist/data_path_int[omega_decay=sum,xi=0.1]_optadam_lr1.00e-03_bs256_ep20_tsks10.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ganguli-lab/pathint/86c5e07c5603d6a44c2f53585c19ee70773ef15c/fig_permuted_mnist/data_path_int[omega_decay=sum,xi=0.1]_optadam_lr1.00e-03_bs256_ep20_tsks10.pkl.gz -------------------------------------------------------------------------------- /fig_split_mnist/Figure split MNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Copyright (c) 2017 Ben Poole & Friedemann Zenke\n", 12 | "# MIT License -- see LICENSE for details\n", 13 | "# \n", 14 | "# This file is part of the code to reproduce the core results of:\n", 15 | "# Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through\n", 16 | "# Synaptic Intelligence. In Proceedings of the 34th International Conference on\n", 17 | "# Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention\n", 18 | "# Centre, Sydney, Australia: PMLR), pp. 3987–3995.\n", 19 | "# http://proceedings.mlr.press/v70/zenke17a.html" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 10, 25 | "metadata": { 26 | "collapsed": false, 27 | "scrolled": false 28 | }, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "The autoreload extension is already loaded. 
To reload it, use:\n", 35 | " %reload_ext autoreload\n", 36 | "Populating the interactive namespace from numpy and matplotlib\n" 37 | ] 38 | }, 39 | { 40 | "name": "stderr", 41 | "output_type": "stream", 42 | "text": [ 43 | "/usr/local/lib/python3.5/dist-packages/IPython/core/magics/pylab.py:160: UserWarning: pylab import has clobbered these variables: ['colors']\n", 44 | "`%matplotlib` prevents importing * from pylab and numpy\n", 45 | " \"\\n`%matplotlib` prevents importing * from pylab and numpy\"\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "%load_ext autoreload\n", 51 | "%autoreload 2\n", 52 | "%pylab inline\n", 53 | "\n", 54 | "import tensorflow as tf\n", 55 | "slim = tf.contrib.slim\n", 56 | "graph_replace = tf.contrib.graph_editor.graph_replace\n", 57 | "\n", 58 | "import sys, os\n", 59 | "sys.path.extend([os.path.expanduser('..')])\n", 60 | "from pathint import utils\n", 61 | "import seaborn as sns\n", 62 | "sns.set_style(\"ticks\")\n", 63 | "\n", 64 | "from tqdm import trange, tqdm\n", 65 | "\n", 66 | "# import operator\n", 67 | "import matplotlib.colors as colors\n", 68 | "import matplotlib.cm as cmx\n", 69 | "\n", 70 | "rcParams['pdf.fonttype'] = 42\n", 71 | "rcParams['ps.fonttype'] = 42" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "## Parameters" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 11, 84 | "metadata": { 85 | "collapsed": true 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "select = tf.select if hasattr(tf, 'select') else tf.where" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 12, 95 | "metadata": { 96 | "collapsed": true 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "# Data params\n", 101 | "input_dim = 784\n", 102 | "output_dim = 10\n", 103 | "\n", 104 | "# Network params\n", 105 | "n_hidden_units = 256\n", 106 | "activation_fn = tf.nn.relu\n", 107 | "\n", 108 | "# Optimization params\n", 109 | "batch_size = 64\n", 110 | 
"epochs_per_task = 10\n", 111 | "\n", 112 | "n_stats = 10\n", 113 | "\n", 114 | "# Reset optimizer after each age\n", 115 | "reset_optimizer = True" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "## Construct datasets" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 13, 128 | "metadata": { 129 | "collapsed": true 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "task_labels = [[0,1], [2,3]]#, [4,5], [6,7], [8,9]]\n", 134 | "task_labels = [[0,1], [2,3], [4,5], [6,7], [8,9]]\n", 135 | "# task_labels = [[0,1,2,3,4], [5,6,7,8,9]]\n", 136 | "n_tasks = len(task_labels)\n", 137 | "training_datasets = utils.construct_split_mnist(task_labels, split='train')\n", 138 | "validation_datasets = utils.construct_split_mnist(task_labels, split='test')\n", 139 | "# training_datasets = utils.mk_training_validation_splits(full_datasets, split_fractions=(0.9, 0.1))" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "## Construct network, loss, and updates" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 14, 152 | "metadata": { 153 | "collapsed": true 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "tf.reset_default_graph()" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 15, 163 | "metadata": { 164 | "collapsed": true 165 | }, 166 | "outputs": [], 167 | "source": [ 168 | "config = tf.ConfigProto()\n", 169 | "config.gpu_options.allow_growth=True\n", 170 | "sess = tf.InteractiveSession(config=config)\n", 171 | "sess.run(tf.global_variables_initializer())" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 9, 177 | "metadata": { 178 | "collapsed": true 179 | }, 180 | "outputs": [], 181 | "source": [ 182 | "# tf.equal(output_mask[None, :], 1.0)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 16, 188 | "metadata": { 189 | 
"collapsed": true 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "import keras.backend as K\n", 194 | "import keras.activations as activations\n", 195 | "\n", 196 | "output_mask = tf.Variable(tf.zeros(output_dim), name=\"mask\", trainable=False)\n", 197 | "\n", 198 | "def masked_softmax(logits):\n", 199 | " # logits are [batch_size, output_dim]\n", 200 | " x = select(tf.tile(tf.equal(output_mask[None, :], 1.0), [tf.shape(logits)[0], 1]), logits, -1e32 * tf.ones_like(logits))\n", 201 | " return activations.softmax(x)\n", 202 | "\n", 203 | "def set_active_outputs(labels):\n", 204 | " new_mask = np.zeros(output_dim)\n", 205 | " for l in labels:\n", 206 | " new_mask[l] = 1.0\n", 207 | " sess.run(output_mask.assign(new_mask))\n", 208 | " print(sess.run(output_mask))\n", 209 | " \n", 210 | "def masked_predict(model, data, targets):\n", 211 | " pred = model.predict(data)\n", 212 | " print(pred)\n", 213 | " acc = np.argmax(pred,1)==np.argmax(targets,1)\n", 214 | " return acc.mean()" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 17, 220 | "metadata": { 221 | "collapsed": true 222 | }, 223 | "outputs": [], 224 | "source": [ 225 | "from keras.models import Sequential\n", 226 | "from keras.layers import Dense\n", 227 | "model = Sequential()\n", 228 | "model.add(Dense(n_hidden_units, activation=activation_fn, input_shape=(input_dim,)))\n", 229 | "model.add(Dense(n_hidden_units, activation=activation_fn))\n", 230 | "model.add(Dense(output_dim, kernel_initializer='zero', activation=masked_softmax))" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 9, 236 | "metadata": { 237 | "collapsed": true 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "from pathint import protocols\n", 242 | "from pathint.optimizers import KOOptimizer\n", 243 | "from keras.optimizers import Adam, RMSprop,SGD\n", 244 | "from keras.callbacks import Callback\n", 245 | "from pathint.keras_utils import LossHistory\n", 246 | "\n", 247 | 
"#protocol_name, protocol = protocols.PATH_INT_PROTOCOL(omega_decay='sum',xi=1e-3)\n", 248 | "protocol_name, protocol = protocols.PATH_INT_PROTOCOL(omega_decay='sum',xi=1e-3)\n", 249 | "# protocol_name, protocol = protocols.SUM_FISHER_PROTOCOL('sum')\n", 250 | "opt = Adam(lr=1e-3, beta_1=0.9, beta_2=0.999)\n", 251 | "# opt = SGD(1e-3)\n", 252 | "# opt = RMSprop(lr=1e-3)\n", 253 | "oopt = KOOptimizer(opt, model=model, **protocol)\n", 254 | "model.compile(loss='categorical_crossentropy', optimizer=oopt, metrics=['accuracy'])\n", 255 | "model._make_train_function()\n", 256 | "saved_weights = model.get_weights()\n", 257 | "\n", 258 | "history = LossHistory()\n", 259 | "callbacks = [history]\n", 260 | "datafile_name = \"split_mnist_data_%s.pkl.gz\"%protocol_name" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "## Train!" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 10, 273 | "metadata": { 274 | "collapsed": true, 275 | "scrolled": true 276 | }, 277 | "outputs": [], 278 | "source": [ 279 | "# diag_vals = dict()\n", 280 | "# all_evals = dict()\n", 281 | "# data = utils.load_zipped_pickle(\"comparison_data_%s.pkl.gz\"%protocol_name)\n", 282 | "# returns empty dict if file not found\n", 283 | "\n", 284 | "def run_fits(cvals, training_data, valid_data, eval_on_train_set=False, nstats=1):\n", 285 | " acc_mean = dict()\n", 286 | " acc_std = dict()\n", 287 | " for cidx, cval_ in enumerate(tqdm(cvals)):\n", 288 | " runs = []\n", 289 | " for runid in range(nstats):\n", 290 | " sess.run(tf.global_variables_initializer())\n", 291 | " # model.set_weights(saved_weights)\n", 292 | " cstuffs = []\n", 293 | " evals = []\n", 294 | " print(\"setting cval\")\n", 295 | " cval = cval_\n", 296 | " oopt.set_strength(cval)\n", 297 | " oopt.init_task_vars()\n", 298 | " print(\"cval is\", sess.run(oopt.lam))\n", 299 | " for age, tidx in enumerate(range(n_tasks)):\n", 300 | " print(\"Age %i, cval 
is=%f\"%(age,cval))\n", 301 | " print(\"settint output mask\")\n", 302 | " set_active_outputs(task_labels[age])\n", 303 | " stuffs = model.fit(training_data[tidx][0], training_data[tidx][1], batch_size, epochs_per_task, callbacks=callbacks)\n", 304 | " oopt.update_task_metrics(training_data[tidx][0], training_data[tidx][1], batch_size)\n", 305 | " oopt.update_task_vars()\n", 306 | " ftask = []\n", 307 | " for j in range(n_tasks):\n", 308 | " set_active_outputs(task_labels[j])\n", 309 | " if eval_on_train_set:\n", 310 | " f_ = masked_predict(model, training_data[j][0], training_data[j][1])\n", 311 | " else:\n", 312 | " f_ = masked_predict(model, valid_data[j][0], valid_data[j][1])\n", 313 | " ftask.append(np.mean(f_))\n", 314 | " evals.append(ftask)\n", 315 | " cstuffs.append(stuffs)\n", 316 | "\n", 317 | " # Re-initialize optimizater variables\n", 318 | " if reset_optimizer:\n", 319 | " oopt.reset_optimizer()\n", 320 | "\n", 321 | " evals = np.array(evals)\n", 322 | " runs.append(evals)\n", 323 | " \n", 324 | " runs = np.array(runs)\n", 325 | " acc_mean[cval_] = runs.mean(0)\n", 326 | " acc_std[cval_] = runs.std(0)\n", 327 | " return dict(mean=acc_mean, std=acc_std)\n" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 11, 333 | "metadata": { 334 | "collapsed": false 335 | }, 336 | "outputs": [ 337 | { 338 | "name": "stdout", 339 | "output_type": "stream", 340 | "text": [ 341 | "[0, 1.0]\n" 342 | ] 343 | } 344 | ], 345 | "source": [ 346 | "# cvals = np.concatenate(([0], np.logspace(-2, 2, 10)))\n", 347 | "# cvals = np.concatenate(([0], np.logspace(-1, 2, 2)))\n", 348 | "# cvals = np.concatenate(([0], np.logspace(-2, 0, 3)))\n", 349 | "cvals = np.logspace(-3, 3, 7)#[0, 1.0, 2, 5, 10]\n", 350 | "cvals = [0, 1.0]\n", 351 | "print(cvals)\n" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 12, 357 | "metadata": { 358 | "collapsed": true 359 | }, 360 | "outputs": [], 361 | "source": [ 362 | "%%capture\n", 363 | "\n", 
364 | "recompute_data = True\n", 365 | "\n", 366 | "if recompute_data:\n", 367 | " data = run_fits(cvals, training_datasets, validation_datasets, eval_on_train_set=True, nstats=n_stats)\n", 368 | " utils.save_zipped_pickle(data, datafile_name)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 13, 374 | "metadata": { 375 | "collapsed": false 376 | }, 377 | "outputs": [ 378 | { 379 | "name": "stdout", 380 | "output_type": "stream", 381 | "text": [ 382 | "[0, 1.0]\n" 383 | ] 384 | } 385 | ], 386 | "source": [ 387 | "data = utils.load_zipped_pickle(datafile_name)\n", 388 | "print(cvals)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 14, 394 | "metadata": { 395 | "collapsed": false 396 | }, 397 | "outputs": [ 398 | { 399 | "name": "stdout", 400 | "output_type": "stream", 401 | "text": [ 402 | "(-5, 0.0)\n" 403 | ] 404 | } 405 | ], 406 | "source": [ 407 | "cmap = plt.get_cmap('cool') \n", 408 | "cNorm = colors.Normalize(vmin=-5, vmax=np.log(np.max(list(data['mean'].keys()))))\n", 409 | "scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cmap)\n", 410 | "print(scalarMap.get_clim())" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 15, 416 | "metadata": { 417 | "collapsed": false 418 | }, 419 | "outputs": [ 420 | { 421 | "name": "stderr", 422 | "output_type": "stream", 423 | "text": [ 424 | "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py:13: RuntimeWarning: divide by zero encountered in log\n", 425 | " del sys.path[0]\n", 426 | "/usr/local/lib/python3.5/dist-packages/matplotlib/axes/_axes.py:545: UserWarning: No labelled objects found. Use label='...' kwarg on individual plots.\n", 427 | " warnings.warn(\"No labelled objects found. 
\"\n" 428 | ] 429 | }, 430 | { 431 | "data": { 432 | "image/png": "\n", 433 | "text/plain": [ 434 | "" 435 | ] 436 | }, 437 | "metadata": {}, 438 | "output_type": "display_data" 439 | } 440 | ], 441 | "source": [ 442 | "figure(figsize=(14, 2.5))\n", 443 | "axs = [subplot(1,n_tasks+1,1)]#, None, None]\n", 444 | "for i in range(1, n_tasks + 1):\n", 445 | " axs.append(subplot(1, n_tasks+1, i+1, sharex=axs[0], sharey=axs[0]))\n", 446 | " \n", 447 | "keys = list(data['mean'].keys())\n", 448 | "sorted_keys = np.sort(keys)\n", 449 | "\n", 450 | "for cval in sorted_keys:\n", 451 | " mean_vals = data['mean'][cval]\n", 452 | " std_vals = data['std'][cval]\n", 453 | " for j in range(n_tasks):\n", 454 | " colorVal = scalarMap.to_rgba(np.log(cval))\n", 455 | " # axs[j].plot(evals[:, j], c=colorVal)\n", 456 | " axs[j].errorbar(range(n_tasks), mean_vals[:, j], yerr=std_vals[:, j]/np.sqrt(n_stats), c=colorVal)\n", 457 | " label = \"c=%g\"%cval\n", 458 | " average = mean_vals.mean(1)\n", 459 | " axs[-1].plot(average, c=colorVal, label=label)\n", 460 | " \n", 461 | "for i, ax in enumerate(axs):\n", 462 | " ax.legend(loc='best')\n", 463 | " ax.set_title((['task %d'%j for j in range(n_tasks)] + ['average'])[i])\n", 464 | "gcf().tight_layout()\n", 465 | "sns.despine()" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 16, 471 | "metadata": { 472 | "collapsed": true 473 | }, 474 | "outputs": [], 475 | "source": [ 476 | "plt.rc('text', usetex=False)\n", 477 | "plt.rc('xtick', labelsize=8)\n", 478 | "plt.rc('ytick', labelsize=8)\n", 479 | "plt.rc('axes', labelsize=8)\n", 480 | "\n", 481 | "def simple_axis(ax):\n", 482 | " ax.spines['top'].set_visible(False)\n", 483 | " ax.spines['right'].set_visible(False)\n", 484 | " ax.get_xaxis().tick_bottom()\n", 485 | " ax.get_yaxis().tick_left()" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 17, 491 | "metadata": { 492 | "collapsed": false 493 | }, 494 | "outputs": [ 495 | { 496 | "data": { 
497 | "image/png": "\n", 498 | "text/plain": [ 499 | "" 500 | ] 501 | }, 502 | "metadata": {}, 503 | "output_type": "display_data" 504 | } 505 | ], 506 | "source": [ 507 | "fig = plt.figure(figsize=(3.3,2.5))\n", 508 | "ax = plt.subplot(111)\n", 509 | "\n", 510 | "for cval in sorted_keys:\n", 511 | " mean_stuff = []\n", 512 | " std_stuff = []\n", 513 | " for i in range(len(data['mean'][cval])):\n", 514 | " mean_stuff.append(data['mean'][cval][i][:i+1].mean())\n", 515 | " std_stuff.append(np.sqrt((data['std'][cval][i][:i+1]**2).sum())/(n_stats*np.sqrt(n_stats)))\n", 516 | " # plot(range(1,n_tasks+1), mean_stuff, 'o-', label=\"c=%g\"%cval)\n", 517 | " errorbar(range(1,n_tasks+1), mean_stuff, yerr=std_stuff, fmt='o-', label=\"c=%g\"%cval)\n", 518 | " \n", 519 | "axhline(data['mean'][cval][0][0], linestyle='--', color='k')\n", 520 | "xlabel('Number of tasks')\n", 521 | "ylabel('Fraction correct')\n", 522 | "legend(loc='best')\n", 523 | "xlim(0.5, 5.5)\n", 524 | "ylim(0.7, 1.02)\n", 525 | "# grid('on')\n", 526 | "# sns.despine()\n", 527 | "simple_axis(ax)" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 18, 533 | "metadata": { 534 | "collapsed": false 535 | }, 536 | "outputs": [ 537 | { 538 | "data": { 539 | "image/png": "\n", 540 | "text/plain": [ 541 | "" 542 | ] 543 | }, 544 | "metadata": {}, 545 | "output_type": "display_data" 546 | } 547 | ], 548 | "source": [ 549 | "fig = plt.figure(figsize=(3.3,2.0))\n", 550 | "ax = plt.subplot(111)\n", 551 | "\n", 552 | "plot_keys =sorted(data['mean'].keys())# [0,1]\n", 553 | "\n", 554 | "for cval in plot_keys:\n", 555 | " mean_stuff = []\n", 556 | " std_stuff = []\n", 557 | " for i in range(len(data['mean'][cval])):\n", 558 | " mean_stuff.append(data['mean'][cval][i][:i+1].mean())\n", 559 | " std_stuff.append(np.sqrt((data['std'][cval][i][:i+1]**2).sum())/(n_stats*np.sqrt(n_stats)))\n", 560 | " # plot(range(1,n_tasks+1), mean_stuff, 'o-', label=\"c=%g\"%cval)\n", 561 | " errorbar(range(1,n_tasks+1), 
mean_stuff, yerr=std_stuff, fmt='o-', label=\"c=%g\"%cval)\n", 562 | " \n", 563 | "axhline(data['mean'][cval][0][0], linestyle=':', color='k')\n", 564 | "xlabel('Number of tasks')\n", 565 | "ylabel('Fraction correct')\n", 566 | "legend(loc='best', fontsize=8)\n", 567 | "xlim(0.5, 5.5)\n", 568 | "plt.yticks([0.6,0.8,1.0])\n", 569 | "ylim(0.6, 1.02)\n", 570 | "# grid('on')\n", 571 | "# sns.despine()\n", 572 | "simple_axis(ax)\n", 573 | "plt.subplots_adjust(left=.15, bottom=.18, right=.99, top=.97)\n", 574 | "plt.savefig(\"split_mnist_accuracy.pdf\")" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 20, 580 | "metadata": { 581 | "collapsed": false, 582 | "scrolled": true 583 | }, 584 | "outputs": [ 585 | { 586 | "name": "stdout", 587 | "output_type": "stream", 588 | "text": [ 589 | "[0, 1.0]\n" 590 | ] 591 | }, 592 | { 593 | "data": { 594 | "image/png": "\n", 595 | "text/plain": [ 596 | "" 597 | ] 598 | }, 599 | "metadata": {}, 600 | "output_type": "display_data" 601 | } 602 | ], 603 | "source": [ 604 | "figure(figsize=(7, 1.8))\n", 605 | "axs = [subplot(1,n_tasks+1,1)]\n", 606 | "for i in range(1,n_tasks+1):\n", 607 | " axs.append(subplot(1, n_tasks+1, i+1, sharex=axs[0], sharey=axs[0]))\n", 608 | "fmts = ['o', 's']\n", 609 | "\n", 610 | "plot_keys =sorted(data['mean'].keys())\n", 611 | "# plot_keys = [0]\n", 612 | "print(plot_keys)\n", 613 | "\n", 614 | "for i, cval in enumerate(plot_keys):\n", 615 | " label = \"c=%g\"%cval\n", 616 | " mean_vals = data['mean'][cval]\n", 617 | " std_vals = data['std'][cval]\n", 618 | " for j in range(n_tasks+1):\n", 619 | " sca(axs[j])\n", 620 | " errorbar_kwargs = dict(fmt=\"%s-\"%fmts[i], markersize=5)\n", 621 | " if j < n_tasks:\n", 622 | " # print(j,mean_vals[:, j])\n", 623 | " norm= np.sqrt(n_stats) # np.sqrt(n_stats) for SEM or 1 for STDEV\n", 624 | " axs[j].errorbar(np.arange(n_tasks)+1, mean_vals[:, j], yerr=std_vals[:, j]/norm, label=label, **errorbar_kwargs)\n", 625 | " else:\n", 626 | " 
mean_stuff = []\n", 627 | " std_stuff = []\n", 628 | " for i in range(len(data['mean'][cval])):\n", 629 | " mean_stuff.append(data['mean'][cval][i][:i+1].mean())\n", 630 | " #std_stuff.append(data['mean'][cval][i][:i+1].std()/np.sqrt(n_stats))\n", 631 | " std_stuff.append(np.sqrt((data['std'][cval][i][:i+1]**2).sum())/(n_stats*np.sqrt(n_stats)))\n", 632 | " # plot(range(1,n_tasks+1), mean_stuff, 'o-', label=\"c=%g\"%cval)\n", 633 | " errorbar(range(1,n_tasks+1), mean_stuff, yerr=std_stuff, label=\"c=%g\"%cval, **errorbar_kwargs)\n", 634 | " plt.xticks(np.arange(5)+1)\n", 635 | " plt.xlim((1.0,5.5))\n", 636 | " if j == 0:\n", 637 | " axs[j].set_yticks([0.5,1])\n", 638 | " else:\n", 639 | " setp(axs[j].get_yticklabels(), visible=False)\n", 640 | " plt.ylim((0.45,1.1))\n", 641 | "\n", 642 | "for i, ax in enumerate(axs):\n", 643 | " if i < n_tasks:\n", 644 | " ax.set_title((['Task %d (%d or %d)'%(j+1,task_labels[j][0], task_labels[j][1]) for j in range(n_tasks)] + ['average'])[i], fontsize=8)\n", 645 | " else:\n", 646 | " ax.set_title(\"Average\", fontsize=8)\n", 647 | " #ax.set_title((['Task %d'%(j+1) for j in xrange(n_tasks)] + ['average'])[i], fontsize=8)\n", 648 | " # ax.axhline(0.5, linestyle=':', color='k')\n", 649 | " ax.axhline(0.5, color='k', linestyle=':', label=\"chance\", zorder=0)\n", 650 | "handles, labels = axs[-1].get_legend_handles_labels()\n", 651 | "# Reorder legend so chance is last\n", 652 | "axs[-1].legend([handles[j] for j in [1,2,0]], [labels[j] for j in [1,2,0]], loc='lower right', fontsize=8, bbox_to_anchor=(-1.3, -.7), ncol=3, frameon=True)\n", 653 | "# axs[-1].legend(loc='lower right', fontsize=8, bbox_to_anchor=(-1.3, -.7), ncol=3, frameon=True)\n", 654 | " \n", 655 | "axs[0].set_xlabel(\"Tasks\")\n", 656 | "axs[0].set_ylabel(\"Accuracy\")\n", 657 | "gcf().tight_layout()\n", 658 | "sns.despine()\n", 659 | "plt.savefig(\"split_mnist_tasks.pdf\")" 660 | ] 661 | }, 662 | { 663 | "cell_type": "code", 664 | "execution_count": null, 665 | 
"metadata": { 666 | "collapsed": true 667 | }, 668 | "outputs": [], 669 | "source": [] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": null, 674 | "metadata": { 675 | "collapsed": true 676 | }, 677 | "outputs": [], 678 | "source": [] 679 | } 680 | ], 681 | "metadata": { 682 | "kernelspec": { 683 | "display_name": "Python 3", 684 | "language": "python", 685 | "name": "python3" 686 | }, 687 | "language_info": { 688 | "codemirror_mode": { 689 | "name": "ipython", 690 | "version": 3 691 | }, 692 | "file_extension": ".py", 693 | "mimetype": "text/x-python", 694 | "name": "python", 695 | "nbconvert_exporter": "python", 696 | "pygments_lexer": "ipython3", 697 | "version": "3.5.2" 698 | } 699 | }, 700 | "nbformat": 4, 701 | "nbformat_minor": 2 702 | } 703 | -------------------------------------------------------------------------------- /fig_transfer_cifar/Figure barplot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Copyright (c) 2017 Ben Poole & Friedemann Zenke\n", 12 | "# MIT License -- see LICENSE for details\n", 13 | "# \n", 14 | "# This file is part of the code to reproduce the core results of:\n", 15 | "# Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through\n", 16 | "# Synaptic Intelligence. In Proceedings of the 34th International Conference on\n", 17 | "# Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention\n", 18 | "# Centre, Sydney, Australia: PMLR), pp. 
3987–3995.\n", 19 | "# http://proceedings.mlr.press/v70/zenke17a.html" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 19, 25 | "metadata": { 26 | "collapsed": false 27 | }, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "Populating the interactive namespace from numpy and matplotlib\n" 34 | ] 35 | }, 36 | { 37 | "name": "stderr", 38 | "output_type": "stream", 39 | "text": [ 40 | "/usr/local/lib/python3.5/dist-packages/IPython/core/magics/pylab.py:160: UserWarning: pylab import has clobbered these variables: ['colors']\n", 41 | "`%matplotlib` prevents importing * from pylab and numpy\n", 42 | " \"\\n`%matplotlib` prevents importing * from pylab and numpy\"\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "%pylab inline" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 20, 53 | "metadata": { 54 | "collapsed": true 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "import sys, os\n", 59 | "sys.path.extend([os.path.expanduser('..')])\n", 60 | "from pathint import utils\n", 61 | "import seaborn as sns\n", 62 | "\n", 63 | "rcParams['pdf.fonttype'] = 42\n", 64 | "rcParams['ps.fonttype'] = 42" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 21, 70 | "metadata": { 71 | "collapsed": true 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "datafile_name = \"split_cifar10_data_path_int[omega_decay=sum,xi=0.001]_lr1.00e-03_ep60.pkl.gz\"\n", 76 | "# backup all_evals to disk\n", 77 | "# all_evals = dict() # uncomment to delete on disk\n", 78 | "data = utils.load_zipped_pickle( datafile_name )" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 22, 84 | "metadata": { 85 | "collapsed": true 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "colors = (sns.color_palette(\"deep\"))\n", 90 | "colors[2] = 'lightgray'" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 23, 96 | "metadata": { 97 | "collapsed": false 98 | }, 99 | 
"outputs": [ 100 | { 101 | "data": { 102 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPIAAACrCAYAAABc6cGbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XlAVOX+P/D3sAgom5CgYCZpGIIomV66XSVRspuaSSIg\njBmiAuIKKq5RQqDigksKYqiI5Ia50L1WoNJVw3tdAzVEhUBhMEEZEBiYOb8/+HG+jMxwDtsgw+f1\nF5zlWWb4cLbnfB4BwzAMCCGdmkZHN4AQ0noUyISoAQpkQtQABTIhaoACmRA1QIFMiBqgQCZEDVAg\nE6IGKJAJUQMUyBwyMjIwZswYCIVCCIVCiMVihIeHQyqVtqi8Y8eONXufdevWtagudaXoO3nVNPU9\nZ2RkYMuWLW1an1abltaBJgWdbNX+pzdNVrruk08+weLFi9nfV61a1eJ6jh8/jqlTpzZrnzVr1rS4\nvo6UmZnZqv3t7OyUrnv5O6knk8mgoaG645Oy+lryPbeG2gSyKgmFQsTHx2PXrl3Iz89HcXExLC0t\nER4ejpKSEqxcuRIVFRUYMGAAQkND2f0OHz6M7OxsCIVCrF69GqGhoUhKSkJBQQF27NiByMhITJs2\nDdbW1vj9998RFBSE0aNHw9PTE0lJSQgJCUG3bt1w7949vP/++wgMDMTNmzcRGhoKKysrPHz4ECdO\nnOi4D6YDCYVCDBkyBMXFxfjyyy8RHByM8vJy2NjYYPXq1UhOTsb58+dRVVUFqVQKZ2dn/Pjjj+jf\nvz/Cw8PZciQSCQIDA1FZWQkTExNER0fj6tWr2LhxI7S1teHp6QlTU1PEx8cDADw9PfHrr7/i7t27\nkMlkiIqKwp07d9jv2c/PD7q6uo32z8rKgp+fH549e4a9e/eiR48ereo/nVrzcOrUKQiFQqxYsaLR\nusGDB2Pfvn0oLCxEWVkZYmNjMXfuXCQkJKBHjx64fv06u627uzusra2RkJCAQYMGKazr+fPnWLx4\nMWJjY3H48OFG60eNGoWkpCRcuHABAPDtt99i165dCAsLw+PHj9uox68+Rd+Ji4sLoqKicPjwYfzz\nn/9EYmIiKisrcfPmTQCAmZkZYmNjYWFhgZqaGiQmJqKwsBDPnj1jyygqKoKJiQkSEhKwdetWAMDm\nzZvx7bffIiEhAR999BEAoKamBrt374aTkxOCgoJw8OBBBAYG4vDhwxg7diz7Pb///vsK99fW1mb3\nv3z5cqs/Dzoi86DsNA4A3nrrLQB1fyRisRj379/Hpk2bIBAIUFFRAXt7e87yG76A1rNnT5iamgIA\nysrKlNanq6sLACgvL0fv3r0BAP379+ffqU5O0Xdia2sLAPjzzz/h5OQEoO70PC8vDwBgbW0NoO67\navi9lZWVwdjYGADQr18/WFtbIygoCHZ2dvjiiy/AMAxMTEwAgD2Nrq8LAOLi4nD58mXU1tZiwIAB\njdqqaP/6tpibm7fJNT4FcisJBAL2Z4ZhYGVlhU8++YS9vqutrVW6fXV1NQAgOztb4Xqu+gBAX18f\nIpEIhoaG7B9sV1UfJP369UNWVhbeeustZGZmws3NDQ8ePJD77F7+3upJJBLMnDkTGhoa8PHxwaRJ\nkyAQCFBaWoqePXtCJpPJ7V9aWoorV67g0KFDuHjxIk6fPt2ofEX7N9QWbxJTILcxPz8/rFmzBmKx\nGBoaGggLC0Pfvn3Z9X369MH8+fOxaNEifPDBB/D09MTQoUNbXF9AQAD8/f3Rr18/9OnTpy260OlN\nmzYNQUFBOHLkCAYNGoRhw4bhwYMHvPZ99OgRVq5cCZlMhr59+8LU1BRLliyBv7+/3DVuPSMjI3Tv\n3h0zZsyQu1yyt7dHQEAAfHx8mty/rQgosUDnVltbCy0tLbx48QI+Pj74/vvvO7pJpAPQEbm
Tu3bt\nGrZt24aKigrMmzevo5tDOggdkQlRA/T4iRA1QIFMiBqgQCZEDVAg83T58mUIhUJ4eXlh3rx5KC0t\nbbOyt2/fjkuXLuHOnTs4evSo3LqCggKEhIQo3bfh4PzWvMxBOje1uWs97bB/q/Y/4r5L6bqSkhLs\n3LkTu3fvhr6+Ph4+fIiamppW1aeIjY0NbGxsmrVPw8H5rXmZg3RuahPI7enChQuYPHky9PX1AQBW\nVla4e/cuFixYAJlMBi8vL0yePFnhSw2bN2/Gf//7X2hpaSEqKgpSqRQhISGQSCRwdnbGnDlz2Hoy\nMjJw6dIlLF68GNHR0cjIyMDAgQPZ9WFhYU0Ozt+9ezfi4+NRXFzcqI7t27c3esGDqA86tebhyZMn\n6NWrl9yyrVu3IioqComJiTh48CB7hH75pYZr164hMTERCQkJMDMzw549e7BgwQJ8//33yMjIgEgk\nalRfcXExbt26hUOHDmHEiBHscq7B+fWU1fHyCx5EfdARmYdevXqhuLhYbllZWRk79LJv374oKSkB\n0PilBl9fXyxfvhzGxsZYvHgx/vzzTwwePBhA3al0QUFBo/oeP37MDveztbXFxYsXAXAPzq+nrI6X\nX/AwNDRswadBXkV0RObByckJp06dQnl5OQAgLy8POjo6KCgoQE1NDfLz89m3W15+qcHR0REbN26E\nqakpzp8/zw7oB4A7d+7IjcOuZ2Fhwb5IcefOHQDyg/MXLlzIDrRX9JKFsjqUvShAOj86IvNgYmKC\ngIAA+Pn5gWEYGBkZISgoCMHBwZBKpfDy8oK2trbCfQMCAlBVVQUAiI6OxtChQxESEoKamhqMGTMG\n5ubmjfYxMzODra0tpk+fjrfffhsAv8H59Xx9fTnrIOqFhmgSogbo1JoQNUCBTIgaaLdAXrFiBd57\n7z1MnDhR4XqGYRAWFgYXFxdMmjSJvTlDCGk+zkD+448/WlSwq6sr4uLilK5PT09Hbm4ufvrpJ6xb\nt04u2yQhpHk4A/mrr77C1KlTkZiY2KwkYSNGjICRkZHS9ampqfj0008hEAgwbNgwlJWVNXpWSwjh\nhzOQDx06hKioKBQVFcHV1RVBQUHsAIXWEIlEbPZHAOjdu7fCUU7K1NbWoqCgoFFyO0K6Il7Pkfv3\n749FixbBzs4OYWFhuH37NhiGwZIlS/Dhhx+2dxtx+PDhRjmeJRIJ7t27h9TUVIWDKgjpSjgD+e7d\nu0hOTsaFCxfw97//Hbt374atrS1EIhE8PDxaHMjm5uYoKipify8qKlI6cMHd3R3u7u5yywoKCjB2\n7NgW1U2IuuE8tQ4LC8PgwYNx8uRJfPnll2xibnNzcyxcuLDFFTs7O+OHH34AwzC4ceMGDAwMYGZm\n1uLyCOnKOI/IMTEx0NXVhaamJoC6Sauqq6uhp6eHTz/9VOl+S5YswZUrV1BaWorRo0dj/vz57PWs\np6cnnJyccOHCBbi4uEBPTw/ffPNNG3WJkK6Hc4jmtGnTEB8fz04yVVFRgVmzZnV4/uT6U2u6RiaE\nx6l1dXW13ExxPXr0QGVlZbs2ihDSPJyBrKenJzfqKjMzk33XlhDyauC8Rl65ciUWLlwIMzMzMAyD\nv/76q81nWyeEtA5nINvb2+Nf//oXHj58CKAuX5Wyd28JIR2D14CQhw8fIicnBxKJBLdv3waAJu9Y\nE0JUizOQd+zYgYyMDNy/fx9OTk5IT0/H8OHDKZAJeYVw3uw6e/Ys9u/fj9deew0RERE4efJkm8yw\nTghpO5yBrKOjAw0NDWhpaaG8vBympqYoLCxURdsIITxxnlrb2dmhrKwMbm5ucHV1Rffu3eHg4KCK\nthFCeGoykBmGwdy5c2FoaAhPT0+MGjUK5eXlbGZHLunp6QgPD4dMJoObm5vcrApAXf7m5cuXQywW\nQyqVIjg4GE5OTi3vDSFdFcNh4sSJXJsoVFtby4wdO5b
5888/merqambSpEnMvXv35LZZvXo1k5iY\nyDAMw9y7d48ZM2YM7/Lz8/MZa2trJj8/v0XtI0SdcF4jDx48GLdu3Wr2P4hbt27hjTfewOuvv45u\n3bphwoQJSE1NldtGIBCwSd/FYjG9/URIC3FeI9+8eROnT5+GhYUF9PT02OWnT59ucr+XM4CYm5s3\n+ocQGBiIWbNm4eDBg6isrER8fHxz208IAY9A3rt3b7tVnpKSgilTpsDHxwfXr1/HsmXLcObMGWho\nyJ8oKMsQ0hqZmZlNrrezs2tV+aRzmhR0ssn1pzdNVlFLmoczkBXNLcTHyxlARCJRowwgx44dYzNt\nOjg4oLq6GqWlpTA1NZXbjjKEENI0zkCeO3cu+3N1dTUKCgpgZWWFlJSUJvcbMmQIcnNzkZ+fD3Nz\nc6SkpGDTpk1y2/Tp0weXL1+Gq6sr7t+/j+rqanYyNEIIf5yB/PK1cFZWFg4dOsRdsJYW1q5dC19f\nX0ilUnz22Wd46623EB0dDTs7O4wdOxYhISFYvXo19u3bB4FAgMjIyBafARDSlTV7NkZbW1ved7Gd\nnJwaPRdumOdr4MCBHZ5phLTOtMP+Ta4/4r5LRS1RjVe1v5yB3PBOskwmw+3bt+kxESGvGM5Arqio\nYH/W1NSEk5MTxo8f366NIoQ0D2cgBwYGqqIdbY7rFGit7TwVtaRro8d8qsEZyF988QWio6NhaGgI\nAHj+/DmWLFnSrs+XO7NX9RqKqDfOQC4pKWGDGACMjIzw9OnTdm0Uab6u9g+kq/WXC2cga2pq4vHj\nx7CwsAAAPHr0iB4RtQKdaqpGR33OHVUvZyAvWrQI06dPx4gRI8AwDK5evYqvv/66XRpDSD2uoZJ6\nI1XUkE6CM5BHjx6N5ORk3Lx5E0BdelwafUXIq4UzkH/++Wc4OjpizJgxAICysjL88ssvGDduXLs3\njnQ8OjJ2DryyaLq4uLC/GxoaYseOHRTInQxdm6s3zsQCMpms0TKpVMqr8PT0dIwfPx4uLi6IjY1V\nuM2PP/6Ijz/+GBMmTEBQUBCvcgkh8ngl34uIiICXlxcAIDExkZ0juSlSqRRff/014uPjYW5ujqlT\np8LZ2RkDBw5kt8nNzUVsbCySkpLosRYhrcB5RF6zZg20tbWxaNEiLFq0CN26dcPatWs5C+aT6ufI\nkSPw8vKCkZERADR6D5kQwg/nEbl79+4IDg5udsF8Uv3k5uYCADw8PCCTyRAYGIjRo0c3Kqs9MoQQ\nok54jezas2cPcnJyUF1dzS4/cOBAqyuXSqXIy8tDQkICioqK4O3tjdOnT8uNJAMoQwhAd49J0zhP\nrYODg/Hmm2+ioKAAgYGBsLS0xJAhQzgL5pPqx9zcHM7OztDW1sbrr7+O/v37s0dpQgh/nIH87Nkz\nuLm5QUtLCyNHjkRERAR+++03zoIbpvqRSCRISUmBs7Oz3Dbjxo3DlStXANQd+XNzc/H666+3sCuE\ndF2cp9ZaWnWbmJmZ4fz58zAzM8Pz58+5C+aR6mfUqFG4ePEiPv74Y2hqamLZsmXo2bNn63tFSBfD\nGcj+/v4Qi8VYvnw51q1bh4qKCqxYsYJX4VypfgQCAVasWMG7PEKIYpyBXD8008DAAAkJCe3eIEJI\n8zU7+V5XR3ePyauI82YXIeTVR4FMiBrgPLWWSCQ4e/YsHj16hNraWnZ5Z03KR4g64nXX2sDAALa2\ntujWrZsq2kQIaSbOQBaJRJQxk5BXHGcgOzg44I8//sCgQYNU0Z5m8w3/GdrdG6ceorvHpCvhDOSr\nV6/ixIkTsLS0lDu15pronBCiOpyBvGfPnhYXnp6ejvDwcMhkMri5uWHOnDkKtzt79iwWLFiAY8eO\n8XohgxAij/Pxk6WlJcRiMc6dO4dz585BLBbD0tKSs+D6DCFxcXFISUnBmTNnkJOT02i78vJyHDhw\nAEOHDm1ZDwgh3IG
8f/9+BAcH4+nTp3j69CmWLl3Ka6gmnwwhABAdHY3Zs2dDR0enZT0ghHCfWh87\ndgxHjhxB9+7dAQCzZ8+Gu7s7hEJhk/vxyRCSlZWFoqIifPDBB03eGacMIYQ0jddYa01NTYU/t4ZM\nJkNkZCQiIiI4t6UMIYQ0jTOQXV1d4ebmxua2/uWXX/DZZ59xFsyVIaSiogLZ2dmYMWMGAODJkyfw\n9/fHrl276IYXIc3Ea1rVkSNH4urVqwCAiIgIDB48mLPghhlCzM3NkZKSgk2bNrHrDQwMkJGRwf4u\nFAqxbNkyCmJCWkBpIJeXl0NfXx/Pnj2DpaWl3J3qZ8+ewdjYuOmCeWQIIYS0DaWBHBQUhJiYGLi6\nuspNo8owDAQCgcI70C/jyhDSECUtIKTllAZyTEwMACAtLU1ljSGEtAznc+TPP/+c1zJCSMdRekSu\nrq5GZWUlSktL8fz5czAMA6Du2lkkEqmsgYQQbkoD+fvvv8f+/ftRXFwMV1dXNpD19fXh7e2tsgYS\nQrgpDeTPP/8cn3/+ORISEjhHcRFCOhbnc2ShUIjs7Gzk5OTIDYv89NNP27VhhBD+OAN5x44dyMjI\nwP379+Hk5IT09HQMHz6cApmQVwjnXeuzZ89i//79eO211xAREYGTJ09CLBarom2EEJ44A1lHRwca\nGhrQ0tJCeXk5TE1NUVhYqIq2EUJ44jy1trOzQ1lZGdzc3ODq6oru3bvDwcGBV+FcGULi4+Nx9OhR\naGpqwsTEBN988w2vpAWEEHmcgRwaGgoA8PT0xKhRo1BeXo63336bs+D6DCHx8fEwNzfH1KlT4ezs\njIEDB7Lb2NjY4Pjx49DT08OhQ4ewceNGbN26teW9IaSLUhrIWVlZSnfKysqCra1tkwU3zBACgM0Q\n0jCQHR0d2Z+HDRuGU6dO8W44IeT/KA3kyMhIAHWZODIzM9l0uH/88Qfs7OwaZex4GZ8MIQ0dO3YM\no0ePVriOMoQQ0jSlgVz/NlJgYCCSk5PZQM7OzsaOHTvatBEnT55EZmYmDh48qHA9ZQghpGmc18gP\nHz6US05vbW2N+/fvcxbMlSGk3qVLl7B7924cPHiQpqQhpIU4A3nQoEFYtWoVPvnkEwB1ien5zDrB\nlSEEAG7fvo21a9ciLi4OpqamLewCIYQzkCMiIpCUlIQDBw4AAEaMGAFPT0/ugnlkCNmwYQNevHjB\nJhvo06cPdu/e3couEdL1cAayjo4OZs6ciZkzZza7cK4MIfv27Wt2mYSQxpQG8sKFCxEdHY1JkyYp\nXE9zPxHy6lAayKtWrQIAOtUlpBNQGshmZmYAQEMmCekElAayg4ODXPbMevVZNK9du9auDSOE8Kc0\nkK9fv67KdhBCWoHX3E8A8PTpU1RXV7O/W1hYtEuDCCHNxxnIqampWL9+PYqLi2FiYoLHjx9jwIAB\nSElJUUX7CCE8cCYWiI6OxuHDh9G/f3+kpaVh3759NCk5Ia8YzkDW0tJCz549IZPJIJPJ4OjoiMzM\nTFW0jRDCE+eptaGhISoqKjBixAgEBwfDxMSEnfScC1eGEIlEgmXLliErKwvGxsbYsmUL+vbt27Ke\nENKFcR6Rv/32W+jq6mLFihUYNWoU+vXrh127dnEWXJ8hJC4uDikpKThz5gxycnLktjl69CgMDQ3x\n888/Y+bMmYiKimp5TwjpwpQekb/66itMnDgRw4cPZ5dNmTKFd8F8MoSkpaUhMDAQADB+/Hh8/fXX\n7HNqLlKpFABQU/lM4XrN0som9y8uLm5yfUFBgcLlNS9KmtyP6qV6W1KvIr1794aWFr8HS0q36t+/\nPzZs2IAnT57go48+wsSJE3lNcF6PT4YQkUiEPn361DVESwsGBgYoLS2FiYmJ3HaKMoRUVFQAAAou\nKxlCyjGJpD8u8+lG81G9VG8bSU1N5X2pyTllzKNHj5CSkoKVK1eiqqoKEydOxIQJE
2BlZdVmDeai\nKENIVVUVMjMz0atXL2hqaja7TD8/vw4ZR071Ur18NTwQcuE8bltaWmLOnDmYM2cObt++jZUrV2Ln\nzp24c+dOk/vxyRBibm6OwsJC9O7dG7W1tRCLxejZsyevhuvq6uLdd9/lta0i3bp165Aba1Qv1dse\nOG921dbWIi0tDUFBQZg9ezasrKywfft2zoIbZgiRSCRISUmBs7Oz3DbOzs44ceIEgLoZLRwdHXld\nHxNC5Ck9Il+8eBFnzpxBeno6hgwZggkTJmDdunW8Hz3xyRAydepULF26FC4uLjAyMsKWLVvarGOE\ndCVKAzkmJgaTJk1CSEgIjIyMWlQ4V4YQHR0dbNu2rUVlE0L+j2Zo/VQSL5kyZQpsbW2hq6ur4iap\njp2dHdVL9apFvQKGYRiV10oIaVOcN7sIIa8+CmRC1ADvxAKvKrFYjICAAAB1Ce8HDx6Mvn37IiIi\nosn96qdzdXV1bbROIpHA29sb9+7dw+nTpxU+F2yPevPy8rBixQoIBAJYWFggMjKy0WCX9qj3r7/+\nQmBgILS0tGBoaIgtW7ZAR0en3eut9+OPPyIqKgppaY2HTbVHvbW1tXB0dISNjQ2AuvcJDAwMVNLf\n9PR0xMXFgWEYrFy5km1DqzFqxMPDg/e2R44cYY4fP65wnVQqZf766y8mKCiIyc/PV1m9paWljFgs\nZhiGYTZs2MCcP39eJfXW1tYyUqmUYRiG2bJlC/PTTz+ppF6GYRiZTMYsWbKEcXd35yyrreqtqalh\nvL29eZfVVvVWVFQw8+fPZ2pra3mXx5danlqfP38eQqEQrq6u7FStBw4cwLRp0yAUCnH37l1228LC\nQsyZMwdPnjxhl2loaLRoCpvW1mtsbAx9fX0Adc/h+Q49bW29mpqa0NCo+1NgGAb9+vVTSb1A3Ysz\no0aNatZAoLaoNzs7G9OnT2/W2IXW1nvt2jVoaGjA19cXy5cvR2Vl0y9gNEub/2voQPX/OV+8eMEw\nDMNIJBJ22cyZM5mqqiqGYeqOAkeOHGF27tzJzJ49mxGJRArLa+4Rua3qLSwsZDw8PDj/c7dlvdeu\nXWOmTJnCTJ8+nT0rUEW9CxYsYGpqangd9dqy3mfPnjEymYxZsWIF7zOf1tZ74sQJxtvbm6mtrWUO\nHDjA7N+/n7PPfHX6a2RFfv/9d+zcuRNSqRQPHjwAUDc97Nq1a6Gjo4NFixYBAA4dOoTg4GA2h/er\nUG9VVRVCQkIQFhbG+4jcFvU6ODggOTkZsbGxOHHiBIRCYbvX+5///AcjRozg/apeW/a3fpDTuHHj\nkJ2d3WjgUnvUa2BggOHDh0NTUxOOjo5KpxFuCbU8tY6NjcX69evx3Xffsaeqtra2WL9+Pd555x38\n8MMPAOq+hH//+99NTsCu6npXr16NGTNmYMCAASqrt+Gk8fr6+rwHAbW23pycHPz888+YNWsWcnJy\neI/ya229L168gEwmA1B3usv3UqK19drb27NTEt+5c6dNX65QyyPyhx9+iLlz58LGxgaGhoYA6gKk\nsLAQEokEkZGRuHbtGrp164aoqCgsWLAAq1evxptvvsmWMX/+fFy/fh1Lly7FnDlzMGbMmHav93//\n+x/S0tIgEokQHx+PmTNn8prMvbX1ZmVlISoqChoaGujZsyc2bNigks+54eSAnp6eWLBggUrqffDg\nAdasWQM9PT288cYbcHFxUUm9vXr1wtChQ+Ht7Q1dXV1s3ryZV7180MguQtSAWp5aE9LVUCATogYo\nkAlRA2oVyP369cM//vEPldQ1depUhYNG7OzsOuz1OaK+bty4wQ7aUUSt7lpXVlbi+fPnKqmrpKRE\nblK7euXl5Sqpn3QtIpEITd2XVqsjMiFdlVodkYG6kVFtNcCjKWKxWOk6iUSikjaQrqN+JJkyavUc\n2cbGRm7genuzsLDAo0eP5Jb97W9/w5UrV1TWB
tJ1aGhosDOsvEytArmlBAIBiouL0atXr45uSpfi\n4+ODpKSktn0LqIuiQEbdPFX29vYd3Ywup7KyEvn5+bC2tu7opnR6FMiEqIFOe7NLT08PVVVVHd0M\nhXR1dZt9uvgq94e0j5b8nSjTaY/IAoGgyedqHaklbXuV+0PaR1t+5/QcmRA1oDaBfPnyZQiFQnh5\neWHevHnw8/NDXl4ekpOTMX78eAiFQvj6+rLb79y5U+73htsFBASwL9uHhobC0dERR48elavL3d0d\nQqFQbsbJ9upLaWkpQkJCVNofvn309/fHu+++i0uXLrHLTp06BQ8PD8ydO5fXSDdPT0/+H04HO3bs\nWLO2r//e2l2bJQ1SsYZNf/r0KePl5cXmmnrw4AHj4+PD5ObmMsePH2eOHDnSaH9fX18mMDCQKSsr\nYxiGkdtu586dTFpaGsMwDCMSiRqV4e3tzYjFYubGjRtMaGhok21rbn8U9UUkEjHLly9XaX+4+lhP\nJBIx27ZtYy5evMgwTF0+K09PT6ampoZJSUlh9uzZw9n35mSpVLX67KL1mtvW+u9NkbYMP7U4Il+4\ncAGTJ09m069YWVk1+Uw4Pz8fffv2xbhx43D+/PlG6xuO2no571JlZSV0dXWhr6+PoUOHIicnp206\n8f8p6gtXTrG27k9z+vhyeXl5ebC2toaWlhbee+893LhxQ279kydP4OvrC6FQiE2bNsmti4mJgbe3\nN9zc3HD79m0AwPLly+Ht7Q2hUAiZTIbNmzfD09MTQqEQIpEIJSUl8PPzg1AoRP00ZomJiWxmy6ys\nLLk6pk2bhpCQELi6uuLcuXMAgJs3b0IoFMLDwwPHjx8HAAiFQmzYsAHLli1j901NTUV2djaEQiEu\nXrzIq731fvvtNwQFBaGmpkbpZ9kanfaudUNPnjxp8llkXFwcTp06hWHDhiEoKAi//PILxo8fDzs7\nO3z11VeYNGkSu92BAwdgaGiIJUuWKCyrrKyMDTIASkfatFdf6tvZnv1pTR8b7mtgYICysjK59TEx\nMZg5cyb+8Y9/yP2hA8CMGTMwd+5c5OXlYdu2bYiMjERRUREOHjwIhmEgEAhw7do1JCYmQkNDAwzD\nYP369Zg7dy4cHBywceNGXL9+HampqThw4AB0dXUb3UwqKSnB1q1bYWxsDB8fH4wZMwbbtm3Drl27\n0KNHD3zxxRfs5+fi4gIHBwd237Fjx8La2hoJCQkAgHfeeYezvQBw5coV/Pbbb4iMjIS2tjbvz7I5\n1CKQe/XqheLiYqXrfX194ebmxv5+4cIF/PrrrxAIBMjLy2PfYvL19YWrqyvmzZuH58+f47XXXmtU\nloGBgdx2b/n4AAADDklEQVR1H99Ml3xx9aW+ne3Zn9b0seG+5eXlbG6rerm5uWxwvPxa3smTJ3H6\n9Gl2uba2NqZMmYLg4GBYWlpi4cKFbE5oY2NjLF68GPfv38emTZsgEAhQUVEBe3t7zJ8/H6GhodDW\n1sbChQvl+m1sbAwLCwu5ft29exf+/v4AgNLSUpSWlgKoS6zXFD7tBeruX+zbt6/dghhQk5tdTk5O\nOHXqFPsHlJeX1ygheb0nT56gd+/e+O6777B3717MmjULFy9eZNdramrCy8sL+/btU7h/9+7dUVVV\nhYqKCty6datZ2S5b2pemArs9+qOsjyKRiLP9/fv3x7179yCVSnHp0iUMHTpUbr2VlRVu3rwJAI2O\nyIcOHUJCQgLWrVsHoO5MYMKECYiKikJJSQl+//13ODo6YuPGjTA1NcX58+dhZWWFkJAQJCQkIDk5\nGWPHjoWNjQ0iIyMxcuRIJCcny9Xx/PlzFBUVobKykj3TsLGxQUxMDBISEnDixAmYm5sDaPyPBoBc\nIn0+7QWAiIgIhIaGoqSkhPPzaym1OCKbmJggICAAfn5+YBgGRkZGSv/7paamYvjw4ezvI0eORFxc\nHEaMGMEue
//997F161ZIJBLs3bsXZ86cAcMwEIlECAwMhL+/P3x8fNCtWzesX7++3fsSHh6udPv2\n6o+iZcuWLUN8fLzcH3hYWBjOnTuHtLQ0eHh4wN3dHW5ubvDy8oKhoWGj6+A5c+YgJCQEu3btgoOD\ng9wpv729Pby8vNi2V1RUwN/fH1KpFPr6+rC2tkZAQAA7cCY6OhqOjo5Ys2YNxGIxNDQ0EBYWhu3b\nt6OgoAASiaTRXE09e/bE9u3bcefOHcybNw8AsGDBAvbzNjY2xvbt25V+3vb29ggICICPjw+v9gJ1\nL9esWrUKS5cuxbZt29CjRw+l5bcUDQhpB+o4IEQmkyE8PBxr1qzp6Ka0iqenJ5KSkjq6GQBoQAjp\nABoaGp0+iNUZHZHbgToekUnboyMyIUROp73Zpaur26ypOFWJ79xJL+/zqvaHtI+W/J0o8/8ADAEZ\nnqZHJdcAAAAASUVORK5CYII=\n", 103 | "text/plain": [ 104 | "" 105 | ] 106 | }, 107 | "metadata": {}, 108 | "output_type": "display_data" 109 | } 110 | ], 111 | "source": [ 112 | "sns.set_style('ticks')\n", 113 | "figure(figsize=(3.3, 2.2))\n", 114 | "ax = axes()\n", 115 | "cval = 0.01 # orig value\n", 116 | "cvals = [0, cval, 'scratch']\n", 117 | "# cvals = [cval]\n", 118 | "# cvals = [0, cval]\n", 119 | "n_tasks = 6\n", 120 | "group_width = 0.8\n", 121 | "bar_width = group_width/len(cvals)\n", 122 | "bar_width = group_width/3\n", 123 | "index = np.arange(n_tasks)\n", 124 | "xtick_labels = ['Task %i'%(i+1) for i in range(n_tasks)]\n", 125 | "# xtick_labels[0] = 'CIFAR10'\n", 126 | "\n", 127 | "def do_plot(eval_type=0, age=-1):\n", 128 | " for k,cv in enumerate(cvals):\n", 129 | " means = []\n", 130 | " stdevs = []\n", 131 | " # print(cv)\n", 132 | " for tid in range(n_tasks):\n", 133 | " if cv=='scratch':\n", 134 | " a = tid\n", 135 | " else:\n", 136 | " a = age\n", 137 | " means.append( data['mean'][cv][a, tid, eval_type] )\n", 138 | " stdevs.append( data['std'][cv][a, tid, eval_type] )\n", 139 | " # print(means)\n", 140 | " \n", 141 | " bar(index+k*bar_width, means, width=bar_width, yerr=stdevs, color=colors[k], ecolor='gray')\n", 142 | " xticks(index)\n", 143 | " # gca().set_xticklabels()\n", 144 | " if eval_type==0:\n", 145 | " ylabel('Validation accuracy')\n", 146 | " else:\n", 147 | " ylabel('Training accuracy')\n", 148 | " xticks(index+group_width/2, xtick_labels, fontsize=8)\n", 149 | " xlim(-0.1, 6.0)\n", 
150 | " # ylim(0.5, 1.0)\n", 151 | " yticks(np.arange(0.0, 1.1, 0.2))\n", 152 | " legend(('Fine tuning', 'Consolidation', 'From scratch'), bbox_to_anchor=(0., 1.02, 1., .102), loc=3,\n", 153 | " ncol=2, mode='expand', borderaxespad=0., fontsize=8)\n", 154 | " \n", 155 | " ax.annotate('CIFAR10', xy=(0.27, 0.12), xytext=(0.27, 0.02), xycoords='figure fraction', \n", 156 | " fontsize=8, ha='center', va='bottom',\n", 157 | " bbox=dict(boxstyle='square', fc='white'),\n", 158 | " arrowprops=dict(arrowstyle='-[, widthB=1.7, lengthB=0.5', lw=1.0))\n", 159 | "\n", 160 | " ax.annotate('CIFAR100, 10 classes per task', xy=(0.66, 0.12), xytext=(0.66, 0.02), xycoords='figure fraction', \n", 161 | " fontsize=8, ha='center', va='bottom',\n", 162 | " bbox=dict(boxstyle='square', fc='white'),\n", 163 | " arrowprops=dict(arrowstyle='-[, widthB=9.4, lengthB=0.5', lw=1.0))\n", 164 | "\n", 165 | "do_plot(eval_type=0)\n", 166 | "sns.despine()\n", 167 | "subplots_adjust(left=.21, bottom=.25, right=.98, top=.82)\n", 168 | "savefig(\"cifar10_cifar100_transfer_valid.pdf\")" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 24, 174 | "metadata": { 175 | "collapsed": false 176 | }, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPIAAACrCAYAAABc6cGbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XlYVGX7B/DvsJvshIihYhqKGGhmL72aJka+ZmqiCAhj\nhqSAiAsuYFqkkCi4YYYhhgmIG2gora+KGiqW+4KXCkKiMNALySLrzPn9wY9zMTIz5zAwgMP9ua6u\nZM45zzIz95wzZ57nfgQMwzAghLzQNDq7AYSQtqNAJkQNUCATogYokAlRAxTIhKgBCmRC1AAFMiFq\ngAKZEDVAgUyIGqBA5pCVlYXx48dDKBRCKBSioqIC4eHhEIvFSpV35MiRVh+zfv16pepSV7Jek65G\n0euclZWFrVu3tmt9Wu1aWieaEvRDm44/vnma3G1Tp07F0qVL2b8/++wzpetJSUnBzJkzW3XM2rVr\nla6vM926datNxw8bNkzutudfkyYSiQQaGh13fpJXnzKvc1uoTSB3JKFQiPj4eMTExODRo0coLi7G\nK6+8gvDwcJSWlmL16tWoqqrCwIEDERoayh538OBB3Lt3D0KhEGvWrEFoaCiSk5NRUFCAr7/+GhER\nEZg1axZsbGxw8+ZNBAUFYezYsfDw8EBycjKCg4Oho6OD+/fvY/To0QgICMD169cRGhqKAQMG4OHD\nhzh69GjnPTGdSCgU4vXXX0dxcTG++OILLF++HJWVlbC1tcWaNWuQmpqKjIwM1NTUQCwWw8nJCT/+\n+COsra0RHh7OllNXV4eAgABUV1fD1NQU27dvx+XLlxEZGQltbW14eHjAzMwM8fHxAAAPDw+cO3cO\nd+/ehUQiQVRUFLKzs9nX2dfXF3p6ei2Ov337Nnx9ffHPP/9gz5496NmzZ5v6T5fWPKSlpUEoFCIk\nJKTFtqFDh2Lv3r0oLCxEeXk5YmNjsWDBAiQkJKBnz564evUqu6+bmxtsbGyQkJCAwYMHy6zr6dOn\nWLp0KWJjY3Hw4MEW29955x0kJyfjzJkzAIBvvvkGMTExCAsLw5MnT9qpx12frNfE2dkZUVFROHjw\nICZNmoSkpCRUV1fj+vXrAIBevXohNjYWffr0QX19PZKSklBYWIh//vmHLaOoqAimpqZISEjAtm3b\nAABbtmzBN998g4SEBPznP/8BANTX12PXrl0YN24cgoKCkJiYiICAABw8eBATJkxgX+fRo0fLPF5b\nW5s9/sKFC21+PuiMzIO8yzgAeO211wA0vkkqKiqQk5ODzZs3QyAQoKqqCvb29pzlN5+AZmJiAjMz\nMwBAeXm53Pr09PQAAJWVlejduzcAwNramn+nXnCyXhM7OzsAwF9//YVx48YBaLw8z8/PBwDY2NgA\naHytmr9u5eXlMDY2BgD069cPNjY2CAoKwrBhw/DJJ5+AYRiYmpoCAHsZ3VQXAMTFxeHChQtoaGjA\nwIEDW7RV1vFNbbGwsGiX7/gUyG0kEAjYfzMMgwEDBmDq1Kns97uGhga5+9fW1gIA7t27J3M7V30A\noK+vD5FIBENDQ/YN2101BUm/fv1w+/ZtvPbaa7h16xZcXV2Rm5sr9dw9/7o1qaurw9y5c6GhoQFv\nb29MmTIFAoEAZWVlMDExgUQikTq+rKwMly5dwv79+5GZmYnjx4+3KF/W8c21x0xiCuR25uvri7Vr\n16KiogIaGhoICwuDlZUVu93S0hKLFi3CkiVL8O6778LDwwMODg5K1+fv7w8/Pz/069cPlpaW7dGF\nF96sWbMQFBSEQ4cOYfDgwRg+fDhyc3N5Hfv48WOsXr0aEokEVlZWMDMzw7Jly+Dn5yf1HbeJkZER\nXnrpJcyZM0fq65K9vT38/f3h7e2t8Pj2IqDEAi+2hoYGaGlp4dmzZ/D29saBAwc6u0mkE9AZ+QV3\n5coVREdHo6qqCgsXLuzs5pBOQmdkQtQA/fxEiBqgQCZEDVAgE
6IGKJB5unDhAoRCITw9PbFw4UKU\nlZW1W9k7duzA+fPnkZ2djcOHD0ttKygoQHBwsNxjmw/Ob8tkDvJiU5u71rMO+rXp+ENuMXK3lZaW\nYufOndi1axf09fXx8OFD1NfXt6k+WWxtbWFra9uqY5oPzm/LZA7yYlObQFalM2fOYNq0adDX1wcA\nDBgwAHfv3kVgYCAkEgk8PT0xbdo0mZMatmzZgj/++ANaWlqIioqCWCxGcHAw6urq4OTkhPnz57P1\nZGVl4fz581i6dCm2b9+OrKwsDBo0iN0eFhamcHD+rl27EB8fj+Li4hZ17Nixo8UED6I+6NKah5KS\nEpibm0s9tm3bNkRFRSEpKQmJiYnsGfr5SQ1XrlxBUlISEhIS0KtXL+zevRuBgYE4cOAAsrKyIBKJ\nWtRXXFyMGzduYP/+/Rg1ahT7ONfg/Cby6nh+ggdRH3RG5sHc3BzFxcVSj5WXl7NDL62srFBaWgqg\n5aQGHx8frFq1CsbGxli6dCn++usvDB06FEDjpXRBQUGL+p48ecIO97Ozs0NmZiYA7sH5TeTV8fwE\nD0NDQyWeDdIV0RmZh3HjxiEtLQ2VlZUAgPz8fOjq6qKgoAD19fV49OgRO7vl+UkNjo6OiIyMhJmZ\nGTIyMtgB/QCQnZ0tNQ67SZ8+fdiJFNnZ2QCkB+cvXryYHWgva5KFvDrkTRQgLz46I/NgamoKf39/\n+Pr6gmEYGBkZISgoCMuXL4dYLIanpye0tbVlHuvv74+amhoAwPbt2+Hg4IDg4GDU19dj/PjxsLCw\naHFMr169YGdnh9mzZ2PIkCEA+A3Ob+Lj48NZB1EvNESTEDVAl9aEqAEKZELUAAUyIWqAApkQNUCB\nTIgaeGEDuaGhAQUFBS2S2xHSHb2wgVxUVIQJEyagqKios5tCSKfjDGRZ6Tv5CAkJwdtvv40PP/xQ\n5naGYRAWFgZnZ2dMmTKFHYlECGk9zkB+//33sXnzZjx8+LBVBbu4uCAuLk7u9rNnzyIvLw+//vor\n1q9fL7W0CiGkdTgDOTU1FZaWllixYgU8PDxw5MgRVFVVcRY8atQoGBkZyd1+8uRJfPTRRxAIBBg+\nfDjKy8tbTEwghPDDOdba0NAQs2fPxuzZs5GVlYXly5cjPDwckyZNgp+fH/r27atUxSKRiF3qBAB6\n9+4NkUiEXr16tdj34MGDLdZBqqurU6rezsa1QqGiFQipXtXjWtVT0aqdnYkzkCUSCc6dO4eUlBTk\n5eVhzpw5mDp1Kv7880/4+Pjgl19+UXkj3dzc4ObmJvVYQUEBJkyYoHSZL9objLROd/vg4gzk999/\nHyNHjoRQKJSa5D558mT88ccfSldsYWEhdce5qKiIZukQoiTOQD527Bib4uZ5bblB5eTkhMTEREye\nPBnXr1+HgYGBzMtqQgg3zptdX331lVRamKdPn2LNmjWcBS9btgzu7u54+PAhxo4di8OHDyM5ORnJ\nyckAGifr9+3bF87Ozli7di2++OKLNnSDkO6N84x8584dqZQwRkZGvH7z3bJli8LtAoGAgpeQdsLr\nZldFRQUMDAwANJ6RaVhk18OVDvhzO1rgrT101eeZM5A//vhjuLm54YMPPgDDMPjpp5/g4+PTEW0j\nL4Cu+sbubjgDecaMGbCzs8PFixcBNF4yN+WRIi3RG5t0Bl7J94YMGYLevXujtrYWQONgDvqpiJCu\ngzOQMzIysGHDBhQVFcHExAQikQj9+/fHzz//3BHtI93UizrCqrNwBvLWrVuRnJwMb29vHDt2DJmZ\nmfjpp586om2EyEVfYaRxBrKmpiZMTU0hkUjAMAxGjx6NjRs3dkTb2oReaNKdcAaygYEBqqqqMHLk\nSKxcuRJmZmbsciiEkK6Bc2TXzp07oaenh9WrV+Ott96ChYUFdu3axavws2fPYuLEiXB2dkZsbGyL\n7U+ePIFQKMRHH32EKVOms
AufEUJaR+EZWSwWY9GiRYiPj4empiZcXV15FywWi7Fu3TrEx8fDwsIC\nM2fOhJOTk9QyoTExMZg0aRJmz56NBw8eYP78+Th16pTyvSGkm1J4RtbU1IRYLGYXL2uNGzduoH//\n/ujbty90dHQwefJknDx5UmofgUDAll1RUUGTJghREq/vyFOnTsWYMWPQo0cP9vGQkBCFxz2fOMDC\nwgI3btyQ2icgIADz5s1DYmIiqqurER8f39r2ExXj+hmox1sd1BCiEGcgv/vuu3j33XdVUnl6ejqm\nT58Ob29vXL16FStXrsSJEyegoSF9oaBOGUIIUQXOQG7N9+Lmnk8cIGs02JEjR9gEfSNGjEBtbS3K\nyspgZmYmtZ8qMoQQok54ZQiRtZg2V4qf119/HXl5eXj06BEsLCyQnp6OzZs3S+1jaWmJCxcuwMXF\nBTk5OaitrWUXDCeE8McZyPv372f/XVtbi59//hkVFRXcBWtp4fPPP4ePjw/EYjFmzJiB1157Ddu3\nb8ewYcMwYcIEBAcHY82aNdi7dy8EAgEiIiJkfmgQQhTjDOSXX35Z6u958+bBxcUFS5Ys4Sx83Lhx\nGDdunNRjixcvZv89aNAgHDhwgG9bCSFycAby3bt32X8zDINbt26hvr5epY0ihLQOZyCvW7eO/bem\npiasrKywdetWlTaKENI6rfqOTAjpmjjHWm/btq1FFs3o6GiVNooQ0jq8Egs0v7FlZGSE06dPIzAw\nUKUNI9JohBVRhPOMLBaLpUZR1dbW0s0uQroYzjPy5MmT4e3tjRkzZgAAUlJS5K55TAjpHJyB7Ovr\ni8GDB+P8+fMAAB8fH5WNvSaEKIczkJ88eYJ///vfGD9+PACgpqYGhYWFsLS05Cz87NmzCA8Ph0Qi\ngaurK+bPn99inx9//BFff/01BAIBhgwZ0mIYJyGEG+d35ICAAKlhkxoaGli0aBFnwU2JBeLi4pCe\nno4TJ07gwYMHUvvk5eUhNjYWycnJSE9Px+rVq5XoAiGE180uHR0d9m8dHR1eUwj5JBY4dOgQPD09\nYWRkBAAtZj0RQvjhDGRjY2OpXFqnT59mA08RWYkFRCKR1D55eXl4+PAh3N3dMWvWLJw9e7Y1bSeE\n/D/O78hffvklli1bhi+//BIAYGpqisjIyHapXCwWIz8/HwkJCSgqKoKXlxeOHz8utfoj0LUSC9Dv\nuaQr4gxka2trpKamsqO7ng8yefgkFrCwsICDgwO0tbXRt29fWFtbIy8vD/b29lL7UWIBQhTjtfbT\nuXPn8ODBA3btJ6DxZylF+CQWeO+995Ceno4ZM2agtLQUeXl56Nu3rxLdIKR74wzk0NBQVFRU4I8/\n/oCLiwt+/fVXODg4cBfMI7HAO++8g8zMTHzwwQfQ1NTEypUrYWJi0i4dI6Q74Qzky5cv4/jx45g6\ndSqWLFkCHx8fLFiwgFfhXIkFBAIBQkJCODNyEkIU47xr3bQ8jK6uLkpKSqCrq4vi4mKVN4wQwh/n\nGXns2LEoLy+Ht7c3PvroI2hoaGD69Okd0TZCCE+cgdw0imvSpEkYP348ampqYGxsrPKGEUL443XX\nuomenh6txEhIF8T5HZkQ0vVRIBOiBlqVDreJgYEBLC0tW6zR1Bl8wn+D9kstV6egoZKkO+EM5M8+\n+wx3797FoEGDwDAMcnNzMXDgQFRVVWH9+vV4++23O6KdhBAFOE+pr7zyClJSUvDDDz8gLS0NKSkp\nsLa2xp49exAREaHw2LNnz2LixIlwdnZGbGys3P1++eUXDB48GDdv3mx9Dwgh3IGcm5uLIUOGsH8P\nHjwYOTk56N+/v8Lj+CQWAIDKykrs27eP17BPQohsnIH86quvYv369bh8+TIuX76MsLAwvPrqq6ir\nq4Ompqbc4/gkFgCA7du349NPP4Wurm7bekJIN8YZyBs3boSFhQV2796N3bt3o1evXoiIiIC
mpia+\n//57ucfxSSxw+/ZtFBUVUTI/QtqI82ZXjx49ZCbNAxrvXitLIpEgIiICGzZs4Ny3KyUWIKQr4gzk\na9eu4euvv8aTJ08gFovZx7kWOudKLFBVVYV79+5hzpw5AICSkhL4+fkhJiYGr7/+ulRZlFiAEMU4\nAzkkJAQrVqyAnZ2dwu/Ez+NKLGBgYICsrCz2b6FQiJUrV7YIYkIIN85A1tfXh5OTU+sL5pFYgBDS\nPjgD2dHREZs3b4azs7NUWtzmP0nJw5VYoLmEhATO8gghsvHKENL8/0BjZo+kpCTVtYoQ0iq00Dkh\nakBuIJ84cQIffvgh9u3bJ3N7091mQkjnkxvIT58+BQCUlpZ2WGMIIcqRG8ienp4AgCVLlnRYYwgh\nyuH8jlxaWoqUlBQ8fvxYakDI+vXrVdowQgh/nIHs7++P4cOHY+TIka0aEEII6TicgVxdXY3g4OCO\naAshREmcs5/Gjh2L33//XanCuRILxMfH44MPPsCUKVPw8ccf4/Hjx0rVQ0h3x3lGPnDgAHbv3o0e\nPXpAW1sbDMNAIBDg0qVLCo9rSiwQHx8PCwsLzJw5E05OThg0aBC7j62tLVJSUtCjRw/s378fkZGR\n2LZtW9t7RUg3wxnIFy9eVKrg5okFALCJBZoHsqOjI/vv4cOHIy0tTam6COnu5AZyXl4erK2tcf/+\nfZnbucZay0oscOPGDbn7HzlyBGPHjuVqLyFEBrmBHBsbi6+++grr1q1rsa29x1r/8MMPuHXrFhIT\nE2Vup8QChCgmN5C/+uorAMqPteZKLNDk/Pnz2LVrFxITE6VmVzVHiQUIUYzX2k85OTnIyclBbW0t\n+9iUKVMUHsOVWAAA7ty5g88//xxxcXEwMzNTovmEEIBHIH/zzTfIzMxEbm4uxowZg99//x0jR47k\nDGQ+iQU2bdqEZ8+esXOULS0tsWvXrvbpGSHdCGcg//TTTzh27BimT5+OyMhIFBcXIyQkhFfhXIkF\n9u7d27rWEkJk4hwQoqurC01NTWhpaaGyshLm5uY0cIOQLobzjDx06FCUl5djxowZmDFjBvT19WFv\nb98RbSOE8KQwkBmGQUBAAAwNDeHp6YkxY8agsrISdnZ2HdU+QggPCi+tBQIBvL292b/79+9PQUxI\nF8T5HXnIkCG4c+dOR7SFEKIkuZfWDQ0N0NLSQnZ2NmbOnIm+ffvipZdeYidNHD16tCPbSQhRQG4g\nu7q64ujRo4iJienI9hBClCA3kBmGAQD069evwxpDCFGO3EAuLS1FfHy83AM/+eQTlTSIENJ6cm92\nSSQSVFVVyf2PD64MIXV1dViyZAmcnZ3h6uqKgoIC5XtCSDcm94xsbm6OgIAApQvmkyHk8OHDMDQ0\nxG+//Yb09HRERUVRhhBClMD5HVlZfDKEnDp1iv2wmDhxItatW8feFefSlJq3vvofmds1y6oVHl9c\nXKxwu7yrg/pnihP2U71UrzL1ytK7d29oafGaoCg/kNs6oYFPhhCRSARLS8vGhmhpwcDAAGVlZTA1\nNZXaT1ZigabL+4ILcmZLnVLcPj9c4NON1qN6qd52cvLkSVhZWfHaV24gGxsbt1uD2kpWYoGamhrc\nunUL5ubmSuXb9vX17ZQpk1Qv1ctX8xMhF37nbSXwyRBiYWGBwsJC9O7dGw0NDaioqICJiQmv8vX0\n9PDmm28q3T4dHR3en3btieqlelWBc4imsppnCKmrq0N6ejqcnJyk9nFycmJHiP3yyy9wdHTk9f2Y\nECJNZWdkPhlCZs6ciRUrVsDZ2RlGRkbYunWrqppDiFpTWSAD3BlCdHV1ER0drcomENItaIaGhoZ2\ndiM6y7Bhw6heqlct6hUwbf3BmBDS6VR2s4sQ0nEokAlRAyq92dURKioq4O/vD6Ax4f3QoUNhZWWF\nDRs2KDzu8OHD0NTUhIuLS4ttdXV18PLywv3793H8+HG
Zvwuqot78/HyEhIRAIBCgT58+iIiIaDHY\nRRX1/v333wgICICWlhYMDQ2xdetW6OrqqrzeJj/++COioqJw6lTLYVOqqLehoQGOjo6wtbUF0Ji7\n3cDAoEP6e/bsWcTFxYFhGKxevZptQ5sxasTd3Z33vocOHWJSUlJkbhOLxczff//NBAUFMY8ePeqw\nesvKypiKigqGYRhm06ZNTEZGRofU29DQwIjFYoZhGGbr1q3Mr7/+2iH1MgzDSCQSZtmyZYybmxtn\nWe1Vb319PePl5cW7rPaqt6qqilm0aBHT0NDAuzy+1PLSOiMjA0KhEC4uLuxSrfv27cOsWbMgFApx\n9+5ddt/CwkLMnz8fJSUl7GMaGhpKLWHT1nqNjY2hr68PoPF3eL5DT9tar6amJjQ0Gt8KDMPwTibR\n1nqBxokz77zzTqsGArVHvffu3cPs2bNbNXahrfVeuXIFGhoa8PHxwapVq1BdrXgCRqu0+0dDJ2r6\n5Hz27BnDMAxTV1fHPjZ37lympqaGYZjGs8ChQ4eYnTt3Mp9++ikjEolkltfaM3J71VtYWMi4u7tz\nfnK3Z71Xrlxhpk+fzsyePZu9KuiIegMDA5n6+npeZ732rPeff/5hJBIJExISwvvKp631Hj16lPHy\n8mIaGhqYffv2Md9//z1nn/l64b8jy3Lz5k3s3LkTYrEYubm5AICAgAB8/vnn0NXVxZIlSwA0rjS5\nfPly9OrVq8vUW1NTg+DgYISFhfE+I7dHvSNGjEBqaipiY2Nx9OhRCIVCldf7+++/Y9SoUbyn6rVn\nf42MjAAA7733Hu7du9di4JIq6jUwMMDIkSOhqakJR0dHucsIK0MtL61jY2OxceNGfPfdd+ylqp2d\nHTZu3Ig33ngDx44dA9D4Ivz8888KF2Dv6HrXrFmDOXPmYODAgR1Wb/O1pvX19aGnp9ch9T548AC/\n/fYb5s2bhwcPHvAe5dfWep89ewaJRAKg8XKX71eJttZrb2+PnJwcAEB2dna7Tq5QyzPy+++/jwUL\nFsDW1haGhoYAGgOksLAQdXV1iIiIwJUrV6Cjo4OoqCgEBgZizZo1ePXVV9kyFi1ahKtXr2LFihWY\nP38+xo8fr/J6//zzT5w6dQoikQjx8fGYO3curzWg21rv7du3ERUVBQ0NDZiYmGDTpk0d8jzPnTsX\nc+fOBQB4eHggMDCwQ+rNzc3F2rVr0aNHD/Tv3x/Ozs4dUq+5uTkcHBzg5eUFPT09bNmyhVe9fNDI\nLkLUgFpeWhPS3VAgE6IGKJAJUQNqFcj9+vXDmDFjOqSumTNnyhw0MmzYsE6bPkfU17Vr19hBO7Ko\n1V3r6upqPH36tEPqKi0tRW1tbYvHKysrO6R+0r2IRCKFKarV6oxMSHelVmdkoHFkVHsN8FCkoqJC\n7ra6uroOaQPpPppGksmjVr8j29raSg1cV7U+ffrg8ePHUo/961//wqVLlzqsDaT70NDQYFdYeZ5a\nBbKyBAIBiouLYW5u3tlN6Va8vb2RnJzcvrOAuikKZDSuU2Vvb9/Zzeh2qqur8ejRI9jY2HR2U154\nFMiEqIEX9mZXjx49UFNT09nNkElPT6/Vl4tduT9ENZR5n8jzwp6RBQJBm5d+VRVl2taV+0NUoz1f\nc/odmRA1oDaBfOHCBQiFQnh6emLhwoXw9fVFfn4+UlNTMXHiRAiFQvj4+LD779y5U+rv5vv5+/uz\nk+1DQ0Ph6OiIw4cPS9Xl5uYGoVAoteKkqvpSVlaG4ODgDu0P3z76+fnhzTffxPnz59nH0tLS4O7u\njgULFvAa6ebh4cH/yelkR44cadX+Ta+byrVb0qAO1rzp//vf/xhPT08211Rubi7j7e3N5OXlMSkp\nKcyhQ4daHO/j48MEBAQw5eXlDMMwUvvt3LmTOXXqFMMwDCMSiVqU4eXlxVRUVDDXrl1jQkNDFbat\ntf2R1ReRSMSsWrW
qQ/vD1ccmIpGIiY6OZjIzMxmGacxn5eHhwdTX1zPp6enM7t27OfvemiyVHa0p\nu2iT1ra16XWTpT3DTy3OyGfOnMG0adPY9CsDBgxQ+Jvwo0ePYGVlhffeew8ZGRkttjcftfV83qXq\n6mro6elBX18fDg4OePDgQft04v/J6gtXTrH27k9r+vh8efn5+bCxsYGWlhbefvttXLt2TWp7SUkJ\nfHx8IBQKsXnzZqlt3377Lby8vODq6oo7d+4AAFatWgUvLy8IhUJIJBJs2bIFHh4eEAqFEIlEKC0t\nha+vL4RCIZqWMUtKSmIzW96+fVuqjlmzZiE4OBguLi44ffo0AOD69esQCoVwd3dHSkoKAEAoFGLT\npk1YuXIle+zJkydx7949CIVCZGZm8mpvk4sXLyIoKAj19fVyn8u2eGHvWjdXUlKi8LfIuLg4pKWl\nYfjw4QgKCsJ///tfTJw4EcOGDcOXX36JKVOmsPvt27cPhoaGWLZsmcyyysvL2SADIHekjar60tRO\nVfanLX1sfqyBgQHKy8ultn/77beYO3cuxowZI/VGB4A5c+ZgwYIFyM/PR3R0NCIiIlBUVITExEQw\nDAOBQIArV64gKSkJGhoaYBgGGzduxIIFCzBixAhERkbi6tWrOHnyJPbt2wc9Pb0WN5NKS0uxbds2\nGBsbw9vbG+PHj0d0dDRiYmLQs2dPfPLJJ+zz5+zsjBEjRrDHTpgwATY2NkhISAAAvPHGG5ztBYBL\nly7h4sWLiIiIgLa2Nu/nsjXUIpDNzc1RXFwsd7uPjw9cXV3Zv8+cOYNz585BIBAgPz+fncXk4+MD\nFxcXLFy4EE+fPsXLL7/coiwDAwOp7318M13yxdWXpnaqsj9t6WPzYysrK9ncVk3y8vLY4Hh+Wt4P\nP/yA48ePs49ra2tj+vTpWL58OV555RUsXryYzQltbGyMpUuXIicnB5s3b4ZAIEBVVRXs7e2xaNEi\nhIaGQltbG4sXL5bqt7GxMfr06SPVr7t378LPzw8AUFZWhrKyMgCNifUU4dNeoPH+xd69e1UWxICa\n3OwaN24c0tLS2DdQfn5+i4TkTUpKStC7d29899132LNnD+bNm4fMzEx2u6amJjw9PbF3716Zx7/0\n0kuoqalBVVUVbty40apsl8r2RVFgq6I/8vooEok4229tbY379+9DLBbj/PnzcHBwkNo+YMAAXL9+\nHQBanJH379+PhIQErF+/HkDjlcDkyZMRFRWF0tJS3Lx5E46OjoiMjISZmRkyMjIwYMAABAcHIyEh\nAampqZgY+SSpAAAByElEQVQwYQJsbW0RERGBt956C6mpqVJ1PH36FEVFRaiurmavNGxtbfHtt98i\nISEBR48ehYWFBYCWHzQApBLp82kvAGzYsAGhoaEoLS3lfP6UpRZnZFNTU/j7+8PX1xcMw8DIyEju\np9/JkycxcuRI9u+33noLcXFxGDVqFPvY6NGjsW3bNtTV1WHPnj04ceIEGIaBSCRCQEAA/Pz84O3t\nDR0dHWzcuFHlfQkPD5e7v6r6I+uxlStXIj4+XuoNHhYWhtOnT+PUqVNwd3eHm5sbXF1d4enpCUND\nwxbfg+fPn4/g4GDExMRgxIgRUpf89vb28PT0ZNteVVUFPz8/iMVi6Ovrw8bGBv7+/uzAme3bt8PR\n0RFr165FRUUFNDQ0EBYWhh07dqCgoAB1dXUt1moyMTHBjh07kJ2djYULFwIAAgMD2efb2NgYO3bs\nkPt829vbw9/fH97e3rzaCzROrvnss8+wYsUKREdHo2fPnnLLVxYNCFEBdRwQIpFIEB4ejrVr13Z2\nU9rEw8MDycnJnd0MADQghHQCDQ2NFz6I1RmdkVVAHc/IpP3RGZkQIuWFvdmlp6fXqqU4OxLftZOe\nP6ar9oeohjLvE3n+D7fxBgwelo6gAAAAAElFTkSuQmCC\n", 181 | 
"text/plain": [ 182 | "" 183 | ] 184 | }, 185 | "metadata": {}, 186 | "output_type": "display_data" 187 | } 188 | ], 189 | "source": [ 190 | "sns.set_style('ticks')\n", 191 | "figure(figsize=(3.3, 2.2))\n", 192 | "ax = axes()\n", 193 | "do_plot(eval_type=1)\n", 194 | "sns.despine()\n", 195 | "subplots_adjust(left=.21, bottom=.25, right=.98, top=.82)\n", 196 | "savefig(\"cifar10_cifar100_transfer_train.pdf\")" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": { 203 | "collapsed": true 204 | }, 205 | "outputs": [], 206 | "source": [] 207 | } 208 | ], 209 | "metadata": { 210 | "kernelspec": { 211 | "display_name": "Python 3", 212 | "language": "python", 213 | "name": "python3" 214 | }, 215 | "language_info": { 216 | "codemirror_mode": { 217 | "name": "ipython", 218 | "version": 3 219 | }, 220 | "file_extension": ".py", 221 | "mimetype": "text/x-python", 222 | "name": "python", 223 | "nbconvert_exporter": "python", 224 | "pygments_lexer": "ipython3", 225 | "version": "3.5.2" 226 | } 227 | }, 228 | "nbformat": 4, 229 | "nbformat_minor": 2 230 | } 231 | -------------------------------------------------------------------------------- /fig_transfer_cifar/split_cifar10_data_path_int[omega_decay=sum,xi=0.001]_lr1.00e-03_ep60.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ganguli-lab/pathint/86c5e07c5603d6a44c2f53585c19ee70773ef15c/fig_transfer_cifar/split_cifar10_data_path_int[omega_decay=sum,xi=0.001]_lr1.00e-03_ep60.pkl.gz -------------------------------------------------------------------------------- /fig_transfer_cifar/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke 3 | # MIT License -- see LICENSE for details 4 | # 5 | # This file is part of the code to reproduce the core results of: 6 | # Zenke, F., Poole, B., and Ganguli, S. 
(2017). Continual Learning Through 7 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on 8 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention 9 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995. 10 | # http://proceedings.mlr.press/v70/zenke17a.html 11 | # 12 | 13 | import sys, os 14 | sys.path.extend([os.path.expanduser('..')]) 15 | 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import matplotlib.colors as colors 19 | import matplotlib.cm as cmx 20 | 21 | import seaborn as sns 22 | sns.set_style("white") 23 | 24 | from tqdm import trange, tqdm 25 | 26 | import tensorflow as tf 27 | 28 | from pathint import protocols 29 | from pathint.optimizers import KOOptimizer 30 | from keras.optimizers import Adam, RMSprop, SGD 31 | from keras.callbacks import Callback 32 | import keras.backend as K 33 | import keras.activations as activations 34 | from keras.models import Sequential 35 | from keras.layers import Dense, Dropout, Activation, Flatten 36 | from keras.layers import Conv2D, MaxPooling2D 37 | 38 | from pathint import utils 39 | from pathint.keras_utils import LossHistory 40 | 41 | # ## Parameters 42 | 43 | # Data params 44 | 45 | # Network params 46 | input_shape = (3,32,32) 47 | 48 | # size of pooling area for max pooling 49 | pool_size = (2, 2) 50 | 51 | # convolution kernel size 52 | kernel_size = (3, 3) 53 | 54 | # Optimization parameters 55 | batch_size = 256 56 | epochs_per_task = 60 57 | learning_rate = 1e-3 58 | nstats = 1 # repeats of experiment to compute stdev 59 | 60 | 61 | add_evals=False # Saves evals accross runs 62 | cvals = ['scratch', 0, 0.01, 0.05, 0.1, 0.2] 63 | cvals = ['scratch', 0, 0.01, 0.1] 64 | print("cvals %s"%cvals) 65 | 66 | 67 | debug=False 68 | if debug: 69 | cvals = [0.1] 70 | epochs_per_task = 1 71 | nstats = 1 72 | 73 | # Reset optimizer after each age 74 | reset_optimizer = True 75 | 76 | 77 | # ## Construct datasets 78 | n_tasks = 6 79 | output_dim = 
10*n_tasks 80 | nb_classes = output_dim 81 | # task_labels = [ range(i*10,(i+1)*10) for i in range(n_tasks) ] 82 | 83 | task_labels, training_datasets = utils.construct_transfer_cifar10_cifar100(n_tasks, split='train') 84 | _, validation_datasets = utils.construct_transfer_cifar10_cifar100(n_tasks, split='test') 85 | print(task_labels) 86 | 87 | # training_datasets = utils.construct_split_cifar10(task_labels, split='train') 88 | # validation_datasets = utils.construct_split_cifar10(task_labels, split='test') 89 | 90 | # ## Construct network, loss, and updates 91 | tf.reset_default_graph() 92 | 93 | config = tf.ConfigProto() 94 | config.gpu_options.allow_growth=True 95 | sess = tf.InteractiveSession(config=config) 96 | sess.run(tf.global_variables_initializer()) 97 | 98 | 99 | # Instantiate masking functions 100 | output_mask = tf.Variable(tf.zeros(output_dim), name="mask", trainable=False) 101 | 102 | select = tf.select if hasattr(tf, 'select') else tf.where 103 | 104 | def masked_softmax(logits): 105 | # logits are [batch_size, output_dim] 106 | x = select(tf.tile(tf.equal(output_mask[None, :], 1.0), [tf.shape(logits)[0], 1]), logits, -1e32 * tf.ones_like(logits)) 107 | return activations.softmax(x) 108 | 109 | def set_active_outputs(labels): 110 | new_mask = np.zeros(output_dim) 111 | for l in labels: 112 | new_mask[l] = 1.0 113 | sess.run(output_mask.assign(new_mask)) 114 | # print("setting output mask") 115 | # print(sess.run(output_mask)) 116 | 117 | def masked_predict(model, data, targets): 118 | pred = model.predict(data) 119 | # print(pred) 120 | acc = np.argmax(pred,1)==np.argmax(targets,1) 121 | return acc.mean() 122 | 123 | # Assemble the network model 124 | model = Sequential() 125 | 126 | model.add(Conv2D(32, (3, 3), padding='same', 127 | input_shape=training_datasets[0][0].shape[1:])) 128 | model.add(Activation('relu')) 129 | model.add(Conv2D(32, (3, 3))) 130 | model.add(Activation('relu')) 131 | model.add(MaxPooling2D(pool_size=(2, 2))) 132 | 
model.add(Dropout(0.25)) 133 | 134 | model.add(Conv2D(64, (3, 3), padding='same')) 135 | model.add(Activation('relu')) 136 | model.add(Conv2D(64, (3, 3))) 137 | model.add(Activation('relu')) 138 | model.add(MaxPooling2D(pool_size=(2, 2))) 139 | model.add(Dropout(0.25)) 140 | 141 | model.add(Flatten()) 142 | model.add(Dense(512)) 143 | model.add(Activation('relu')) 144 | model.add(Dropout(0.5)) 145 | # model.add(Dense(nb_classes)) 146 | model.add(Dense(nb_classes, kernel_initializer='zero', activation=masked_softmax)) 147 | 148 | 149 | # Define our training protocol 150 | protocol_name, protocol = protocols.PATH_INT_PROTOCOL(omega_decay='sum', xi=1e-3 ) 151 | opt = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999) 152 | # opt = RMSprop(lr=1e-3) 153 | # opt = SGD(1e-3) 154 | oopt = KOOptimizer(opt, model=model, **protocol) 155 | model.compile(loss='categorical_crossentropy', optimizer=oopt, metrics=['accuracy']) 156 | model._make_train_function() 157 | 158 | history = LossHistory() 159 | callbacks = [history] 160 | datafile_name = "split_cifar10_data_%s_lr%.2e_ep%i.pkl.gz"%(protocol_name, learning_rate, epochs_per_task) 161 | 162 | 163 | 164 | def run_fits(cvals, training_data, valid_data, nstats=1): 165 | acc_mean = dict() 166 | acc_std = dict() 167 | for cidx, cval_ in enumerate(cvals): 168 | runs = [] 169 | for runid in range(nstats): 170 | evals = [] 171 | sess.run(tf.global_variables_initializer()) 172 | # model.set_weights(saved_weights) 173 | cstuffs = [] 174 | if cval_=='scratch': 175 | print("Scratch mode -- inits net before each age") 176 | cval = 0 177 | else: 178 | print("setting cval") 179 | cval = cval_ 180 | oopt.set_strength(cval) 181 | oopt.init_task_vars() 182 | print("cval is %f"%sess.run(oopt.lam)) 183 | for age, tidx in enumerate(range(n_tasks)): 184 | if cval_=='scratch': 185 | sess.run(tf.global_variables_initializer()) 186 | oopt.reset_optimizer() 187 | print("Age %i, cval is=%f"%(age,cval)) 188 | set_active_outputs(task_labels[age]) 189 | 
stuffs = model.fit(training_data[tidx][0], training_data[tidx][1], batch_size, epochs_per_task, callbacks=callbacks, verbose=0) 190 | oopt.update_task_metrics(training_data[tidx][0], training_data[tidx][1], batch_size) 191 | oopt.update_task_vars() 192 | ftask = [] 193 | for j in range(n_tasks): 194 | set_active_outputs(task_labels[j]) 195 | train_err = masked_predict(model, training_data[j][0], training_data[j][1]) 196 | valid_err = masked_predict(model, valid_data[j][0], valid_data[j][1]) 197 | ftask.append( (np.mean(valid_err), np.mean(train_err)) ) 198 | evals.append(ftask) 199 | cstuffs.append(stuffs) 200 | 201 | # Re-initialize optimizater variables 202 | if reset_optimizer: 203 | oopt.reset_optimizer() 204 | 205 | evals = np.array(evals) 206 | runs.append(evals) 207 | 208 | runs = np.array(runs) 209 | acc_mean[cval_] = runs.mean(0) 210 | acc_std[cval_] = runs.std(0) 211 | return dict(mean=acc_mean, std=acc_std) 212 | 213 | 214 | # Run the sim 215 | data = run_fits(cvals, training_datasets, validation_datasets, nstats=nstats) 216 | 217 | 218 | # data = dict(mean={0.1:0.0}, std={0.1:0.0}) 219 | # print(data) 220 | if add_evals: 221 | old_data = utils.load_zipped_pickle(datafile_name) 222 | # returns empty dict if file not found 223 | for k in old_data.keys(): 224 | for l in old_data[k].keys(): 225 | data[k][l] = old_data[k][l] 226 | 227 | # Save the data 228 | utils.save_zipped_pickle(data, datafile_name) 229 | 230 | # To overwrite the data in the file uncomment this 231 | # all_evals = dict() # uncomment to delete on disk 232 | # utils.save_zipped_pickle(data, datafile_name) 233 | -------------------------------------------------------------------------------- /pathint/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ganguli-lab/pathint/86c5e07c5603d6a44c2f53585c19ee70773ef15c/pathint/__init__.py -------------------------------------------------------------------------------- 
/pathint/keras_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke 2 | # MIT License -- see LICENSE for details 3 | # 4 | # This file is part of the code to reproduce the core results of: 5 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through 6 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on 7 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention 8 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995. 9 | # http://proceedings.mlr.press/v70/zenke17a.html 10 | # 11 | 12 | # Keras-specific functions and utils 13 | from keras.callbacks import Callback 14 | from keras.models import Model 15 | import keras.backend as K 16 | from keras.layers import Dense 17 | import tensorflow as tf 18 | 19 | class LossHistory(Callback): 20 | def __init__(self, *args, **kwargs): 21 | super(LossHistory, self).__init__(*args, **kwargs) 22 | self.losses = [] 23 | self.regs = [] 24 | 25 | def on_batch_end(self, batch, logs={}): 26 | self.losses.append(logs.get('loss')) 27 | self.regs.append(K.get_session().run(self.model.optimizer.regularizer)) 28 | 29 | 30 | # Create a callback that tracks FisherInformation 31 | from keras.models import Model 32 | import keras.backend as K 33 | def compute_fishers(model): 34 | # Check that model only contains Dense layers 35 | for l in model.layers: 36 | if not isinstance(l, Dense): 37 | raise ValueError("All layers of the model must be Dense, got %s"%l) 38 | # Create new model used to extract activations at each layer 39 | new_model = Model(inputs=model.input, outputs=[l.output for l in model.layers]) 40 | acts = new_model(model.input) 41 | 42 | out = acts[-1] 43 | out_dim = out.get_shape().as_list()[-1] 44 | #assert len(model.weights) %2 == 0 45 | n_weights = len(model.weights) // 2 46 | 47 | 48 | def _get_fishers(): 49 | fisher_weights = [0.0] * n_weights 50 | fisher_biases = [0.0] * 
n_weights 51 | for idx in range(out_dim): 52 | # Clips output 53 | # https://github.com/fchollet/keras/blob/master/keras/backend/tensorflow_backend.py#L2743 54 | output = out[:, idx] 55 | epsilon = tf.convert_to_tensor(x) 56 | if epsilon.dtype != output.dtype.base_dtype: 57 | epsilon = tf.cast(epsilon, output.dtype.base_dtype) 58 | 59 | output = tf.clip_by_value(output, epsilon, 1. - epsilon) 60 | y = K.log(output) 61 | # From the post-nonlinearity outputs of each layer, we walk back up through the graph 62 | # to find the linear activation corresponding to h=XW + b. 63 | # Then we identify the weights W, biases b, and previous activation X. 64 | # 1. TensorFlow can compute dy/dh, giving us a [batch_size, n_neurons] matrix 65 | # 2. We manually comute dy/dW=dy/dh X and dy/db=dy/dh 66 | # 3. We sum the squared Jacobians 67 | 68 | # Identify pre-nonlinearity activation corresponding to h=XW+b 69 | def _walk_up_until_add(x): 70 | if x.op.type == 'BiasAdd': 71 | return x 72 | elif x.op.type == 'Select': 73 | return x.op.inputs[1] 74 | else: 75 | return _walk_up_until_add(x.op.inputs[0]) 76 | 77 | linear = [_walk_up_until_add(a) for a in acts] 78 | # Identify previous activation, X 79 | prev_acts = [l.op.inputs[0].op.inputs[0] for l in linear] 80 | # Compute dy/dh 81 | dy_dlinear = [tf.gradients(y,l)[0] for l in linear] 82 | 83 | # Figure out which Jacobians correspond to which weights 84 | if idx == 0: 85 | val_to_var = {v.value():v for v in model.weights} 86 | weight_vars = [val_to_var[l.op.inputs[0].op.inputs[1]] for l in linear] 87 | bias_vars = [val_to_var[l.op.inputs[1]] for l in linear] 88 | 89 | # Compute the sum of the Jacobian squared 90 | # Because each of the Jacobians are rank-1, we can compute this by first squaring and then summing: 91 | # \sum_i (u_i v_i^T)^2 = \sum_i (u_i^2 (v_i^2)^T) 92 | weights_sum_jacobian_squared = [tf.matmul(tf.transpose(a)**2, dh**2) for dh,a in zip(dy_dlinear, prev_acts)] 93 | bias_sum_jacobian_squared = [tf.reduce_sum(dh**2, 0) 
for dh in dy_dlinear] 94 | # Keep track of aggregate across outputs 95 | for jj in range(n_weights): 96 | fisher_weights[jj] += weights_sum_jacobian_squared[jj] 97 | fisher_biases[jj] += bias_sum_jacobian_squared[jj] 98 | var_to_fisher = dict(zip(weight_vars+bias_vars, fisher_weights+fisher_biases)) 99 | return {w: var_to_fisher[w] for w in model.weights} 100 | 101 | fishers = _get_fishers() 102 | return fishers 103 | 104 | # Allocate space for accumulated Fisher 105 | avg_fishers = [K.zeros(w.get_shape().as_list()) for w in model.weights] 106 | # Create updates to reset avg fisher, update, etc. 107 | update_fishers = tf.group(*[tf.assign_add(avg_f, f) for avg_f, f in zip(avg_fishers, fishers)]) 108 | zero_fishers = tf.group(*[tf.assign(avg_f, 0.0 * avg_f) for avg_f in avg_fishers]) 109 | return fishers, avg_fishers, update_fishers, zero_fishers 110 | 111 | -------------------------------------------------------------------------------- /pathint/optimizers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke 2 | # MIT License -- see LICENSE for details 3 | # 4 | # This file is part of the code to reproduce the core results of: 5 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through 6 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on 7 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention 8 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995. 
# http://proceedings.mlr.press/v70/zenke17a.html
#
"""Optimization algorithms."""

import tensorflow as tf

import numpy as np
import keras
from keras import backend as K
from keras.optimizers import Optimizer
from keras.callbacks import Callback
from pathint.utils import extract_weight_changes, compute_updates
from pathint.regularizers import quadratic_regularizer
from collections import OrderedDict


class KOOptimizer(Optimizer):
    """An optimizer whose loss depends on its own updates.

    Wraps a Keras optimizer ``opt`` and minimizes
    ``loss + lam * regularizer_fn(weights, vars)``. ``vars`` maps a name to a
    dict ``{weight: Variable}`` of auxiliary per-weight variables. The
    variables are refreshed by user-supplied update functions:
    - after every training step (``step_updates``);
    - after every task (``task_updates``, via update_task_vars());
    - on request (``init_updates``, via init_task_vars()).
    """

    # Keys that get_updates() always writes into self.vars itself. A
    # user-defined optimization variable with one of these names would be
    # silently overwritten, so __init__ rejects them.
    _RESERVED_NAMES = frozenset(['grads', 'unreg_grads', 'deltas', 'oopt', 'nb_data'])

    def _allocate_var(self, name=None):
        """Return a dict mapping each weight in self.weights to a zero Variable of its shape."""
        return {w: K.zeros(w.get_shape(), name=name) for w in self.weights}

    def _allocate_vars(self, names):
        """Populate self.vars with one per-weight variable dict for each name."""
        # TODO: add names, better shape/init checking
        self.vars = {name: self._allocate_var(name=name) for name in names}

    def __init__(self, opt, step_updates=None, task_updates=None, init_updates=None,
                 task_metrics=None, regularizer_fn=quadratic_regularizer,
                 lam=1.0, model=None, compute_average_loss=False,
                 compute_average_weights=False, **kwargs):
        """Instantiate an optimizer that depends on its own updates.

        Args:
            opt: Keras optimizer that performs the actual weight updates.
            step_updates: OrderedDict or list of (name, update_fn) tuples that
                run after each training step. See below for the update_fn
                signature. Defaults to no updates.
            task_updates: same format as step_updates. These run in
                update_task_vars(), i.e. after each task.
            init_updates: same format. These run in init_task_vars().
            task_metrics: dict mapping a name to a function ``fn(optimizer)``
                that returns ``{weight: Tensor}``. update_task_metrics() sums
                these tensors over mini-batches into self.vars[name].
            regularizer_fn (optional): function taking (weights, vars) and
                returning a scalar. None disables regularization.
                Defaults to quadratic_regularizer.
            lam: scalar penalty that multiplies the regularization term.
            model: Keras model to be optimized. update_task_metrics() needs
                it, as do task metrics that inspect the model (e.g. Fisher
                information).
            compute_average_loss: maintain an EMA of the unregularized loss.
            compute_average_weights: maintain an EMA of the weights, exposed
                as vars['average_weights'].

        Raises:
            ValueError: if opt is not a Keras optimizer, or a variable name
                collides with a reserved key.

        A variable is allocated for every name used in step_updates,
        task_updates, init_updates and task_metrics. The following names are
        reserved, because get_updates() stores its own values under them:
        - 'grads': gradients of the full loss;
        - 'unreg_grads': gradients of the loss without regularization;
        - 'deltas': the weight updates of each step;
        - 'oopt': this optimizer;
        - 'nb_data': the number of data samples;
        - 'average_weights': the weight EMAs, only when
          compute_average_weights is set.
        Access them via the vars dict, e.g. oopt.vars['grads'].

        Update functions have the signature::

            def update_fn(vars, weight, prev_val):
                '''Compute the new value for a variable.

                Args:
                    vars: optimization variables (KOOptimizer.vars)
                    weight: weight Variable in the model this variable belongs to
                    prev_val: previous value of this variable
                Returns:
                    Tensor representing the new value'''

        Running both task and step updates on the same variable lets you,
        for example, reset a step accumulator after each task.
        """
        super(KOOptimizer, self).__init__(**kwargs)
        if not isinstance(opt, keras.optimizers.Optimizer):
            raise ValueError("opt must be an instance of keras.optimizers.Optimizer but got %s"%type(opt))

        def _as_ordered_dict(updates):
            # None replaces the former shared mutable defaults. An OrderedDict
            # is kept as-is so the caller's update order is preserved.
            if updates is None:
                return OrderedDict()
            if isinstance(updates, OrderedDict):
                return updates
            return OrderedDict(updates)

        step_updates = _as_ordered_dict(step_updates)
        task_updates = _as_ordered_dict(task_updates)
        init_updates = _as_ordered_dict(init_updates)
        if task_metrics is None:
            task_metrics = {}

        # init_updates names are included so that init-only variables get
        # allocated too (previously they caused a KeyError in get_updates).
        self.names = set().union(step_updates.keys(), task_updates.keys(),
                                 init_updates.keys(), task_metrics.keys())
        reserved = set(self._RESERVED_NAMES)
        if compute_average_weights:
            reserved.add('average_weights')
        clashes = self.names & reserved
        if clashes:
            raise ValueError("Optimization variables cannot use the reserved names %s" % sorted(clashes))

        self.step_updates = step_updates
        self.task_updates = task_updates
        self.init_updates = init_updates
        self.compute_average_loss = compute_average_loss
        self.regularizer_fn = regularizer_fn
        # Keep these as graph variables so they can be changed later via
        # set_strength()/set_nb_data() without rebuilding the graph.
        self.lam = K.variable(value=lam, dtype=tf.float32, name="lam")
        self.nb_data = K.variable(value=1.0, dtype=tf.float32, name="nb_data")
        self.opt = opt
        self.model = model
        self.task_metrics = task_metrics
        self.compute_average_weights = compute_average_weights

    def set_strength(self, val):
        """Set the regularization strength ``lam``."""
        K.set_value(self.lam, val)

    def set_nb_data(self, nb):
        """Set the value exposed to update functions as vars['nb_data']."""
        K.set_value(self.nb_data, nb)

    def get_updates(self, params, loss, model=None):
        """Build the graph ops for training and for the auxiliary variables.

        Args:
            params: list of weight Variables to optimize.
            loss: scalar loss Tensor without the regularization term.
            model: unused; accepted for signature compatibility.

        Returns:
            list of update ops that Keras runs at every training step.
        """
        self.weights = params
        # Allocate one per-weight variable dict for every optimization variable
        with tf.variable_scope("KOOptimizer"):
            self._allocate_vars(self.names)

        # Regularized loss, and the weight updates of the wrapped optimizer
        self.regularizer = 0.0 if self.regularizer_fn is None else self.regularizer_fn(params, self.vars)
        self.initial_loss = loss
        self.loss = loss + self.lam * self.regularizer
        with tf.variable_scope("wrapped_optimizer"):
            self._weight_update_op, self._grads, self._deltas = compute_updates(self.opt, self.loss, params)

        wrapped_opt_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "wrapped_optimizer")
        self.init_opt_vars = tf.variables_initializer(wrapped_opt_vars)

        # Reserved entries that update functions may read
        self.vars['unreg_grads'] = dict(zip(params, tf.gradients(self.initial_loss, params)))
        self.vars['grads'] = dict(zip(params, self._grads))
        self.vars['deltas'] = dict(zip(params, self._deltas))
        # Pointer to self so update functions can reach the optimizer
        self.vars['oopt'] = self
        # Number of data samples, for normalization purposes
        self.vars['nb_data'] = self.nb_data

        if self.compute_average_weights:
            with tf.variable_scope("weight_emga") as scope:
                weight_ema = tf.train.ExponentialMovingAverage(decay=0.99, zero_debias=True)
                self.maintain_weight_averages_op = weight_ema.apply(self.weights)
                self.vars['average_weights'] = {w: weight_ema.average(w) for w in self.weights}
                self.weight_ema_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name)
                self.init_weight_ema_vars = tf.variables_initializer(self.weight_ema_vars)
                K.get_session().run(self.init_weight_ema_vars)
        if self.compute_average_loss:
            with tf.variable_scope("ema") as scope:
                ema = tf.train.ExponentialMovingAverage(decay=0.99, zero_debias=True)
                self.maintain_averages_op = ema.apply([self.initial_loss])
                self.ema_loss = ema.average(self.initial_loss)
                self.prev_loss = tf.Variable(0.0, trainable=False, name="prev_loss")
                self.delta_loss = tf.Variable(0.0, trainable=False, name="delta_loss")
                self.ema_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name)
                self.init_ema_vars = tf.variables_initializer(self.ema_vars)

        def _var_update(var_dict, update_fn):
            # One assign per weight; update_fn sees all vars, the weight and the old value
            return tf.group(*[tf.assign(var_dict[w], update_fn(self.vars, w, var_dict[w]))
                              for w in params])

        def _compute_vars_update_op(updates):
            # Chain the updates with control dependencies so that they run
            # sequentially, in the order given
            update_op = tf.no_op()
            for name, update_fn in updates.items():
                with tf.control_dependencies([update_op]):
                    update_op = _var_update(self.vars[name], update_fn)
            return update_op

        self._vars_step_update_op = _compute_vars_update_op(self.step_updates)
        self._vars_task_update_op = _compute_vars_update_op(self.task_updates)
        self._vars_init_update_op = _compute_vars_update_op(self.init_updates)

        # Task metrics are accumulators: reset to zero, then add per batch
        reset_ops = []
        update_ops = []
        for name, metric_fn in self.task_metrics.items():
            metric = metric_fn(self)
            for w in params:
                reset_ops.append(tf.assign(self.vars[name][w], 0*self.vars[name][w]))
                update_ops.append(tf.assign_add(self.vars[name][w], metric[w]))
        self._reset_task_metrics_op = tf.group(*reset_ops)
        self._update_task_metrics_op = tf.group(*update_ops)

        # Each step updates the weights (wrapped optimizer) and the step variables
        self.step_op = tf.group(self._weight_update_op, self._vars_step_update_op)
        self.updates.append(self.step_op)
        # After each task, run the task-specific variable updates
        self.task_op = self._vars_task_update_op
        self.init_op = self._vars_init_update_op

        if self.compute_average_weights:
            self.updates.append(self.maintain_weight_averages_op)

        if self.compute_average_loss:
            self.update_loss_op = tf.assign(self.prev_loss, self.ema_loss)
            bupdates = self.updates
            # The loss EMA is refreshed only after all other updates and after
            # caching the previous EMA value in prev_loss
            with tf.control_dependencies(bupdates + [self.update_loss_op]):
                self.updates = [tf.group(*[self.maintain_averages_op])]
            self.delta_loss = self.prev_loss - self.ema_loss

        return self.updates

    def init_task_vars(self):
        """Run the init_updates."""
        K.get_session().run([self.init_op])

    def init_acc_vars(self):
        """Re-initialize the loss EMA variables (requires compute_average_loss=True)."""
        K.get_session().run(self.init_ema_vars)

    def init_loss(self, X, y, batch_size):
        """Currently a no-op; kept so existing callers keep working."""
        pass

    def update_task_vars(self):
        """Run the task_updates (call after finishing a task)."""
        K.get_session().run(self.task_op)

    def update_task_metrics(self, X, y, batch_size):
        """Reset the task-metric accumulators and sum the metrics over (X, y).

        Only the first ``len(X) // batch_size`` full batches are used; a
        trailing partial batch is ignored.
        """
        n_batch = len(X) // batch_size

        sess = K.get_session()
        sess.run(self._reset_task_metrics_op)
        for i in range(n_batch):
            batch = slice(i * batch_size, (i + 1) * batch_size)
            xi, yi, sample_weights = self.model._standardize_user_data(X[batch], y[batch], batch_size=batch_size)
            sess.run(self._update_task_metrics_op,
                     {self.model.input: xi[0],
                      self.model.targets[0]: yi[0],
                      self.model.sample_weights[0]: sample_weights[0]})

    def reset_optimizer(self):
        """Reset the optimizer variables"""
        K.get_session().run(self.init_opt_vars)

    def get_config(self):
        raise ValueError("KOOptimizer does not implement get_config(); it cannot be serialized.")

    def get_numvals_list(self, key='omega'):
        """ Returns list of numerical values such as for instance omegas in reproducible order """
        variables = self.vars[key]
        # Flatten on the NumPy side; building a tf.reshape op on every call
        # would keep growing the graph.
        return [np.ravel(K.get_value(variables[p])) for p in self.weights]

    def get_numvals(self, key='omega'):
        """ Returns concatenated list of numerical values such as for instance omegas in reproducible order """
        return np.concatenate(self.get_numvals_list(key))

    def _state_variables(self):
        """Yield the allocated optimization variables in a reproducible order.

        Only variables allocated for user-defined names are yielded; the
        reserved entries ('grads', 'deltas', 'nb_data', ...) are skipped.
        Some of them are plain Tensors or Variables rather than per-weight
        dicts, which made the previous implementation crash.
        """
        for name in sorted(self.names):
            for w in self.weights:
                yield self.vars[name][w]

    def get_state(self):
        """Return the values of all optimization variables as a list of arrays."""
        return [K.get_value(v) for v in self._state_variables()]

    def set_state(self, state):
        """Restore optimization variables from a list produced by get_state().

        Raises:
            ValueError: if ``state`` does not have one entry per variable.
        """
        variables = list(self._state_variables())
        if len(state) != len(variables):
            raise ValueError("Expected %d state entries but got %d" % (len(variables), len(state)))
        for v, value in zip(variables, state):
            K.set_value(v, value)

# --------------------------------------------------------------------------------
# /pathint/protocols.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Ben Poole & Friedemann Zenke
# MIT License -- see LICENSE for details
#
# This file is part of the code to reproduce the core results of:
# Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through
# Synaptic Intelligence. In Proceedings of the 34th International Conference on
# Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention
# Centre, Sydney, Australia: PMLR), pp. 3987-3995.
9 | # http://proceedings.mlr.press/v70/zenke17a.html 10 | # 11 | from pathint.utils import ema 12 | from pathint.regularizers import quadratic_regularizer, get_power_regularizer 13 | from pathint.keras_utils import compute_fishers 14 | import tensorflow as tf 15 | import numpy as np 16 | 17 | """ 18 | A protocol is a function that takes as input some parameters and returns a tuple: 19 | (protocol_name, optimizer_kwargs) 20 | The protocol name is just a string that describes the protocol. 21 | The optimizer_kwargs is a dictionary that will get passed to KOOptimizer. It typically contains: 22 | step_updates, task_updates, task_metrics, regularizer_fn 23 | """ 24 | 25 | 26 | 27 | PATH_INT_PROTOCOL = lambda omega_decay, xi: ( 28 | 'path_int[omega_decay=%s,xi=%s]'%(omega_decay,xi), 29 | { 30 | 'init_updates': [ 31 | ('cweights', lambda vars, w, prev_val: w.value() ), 32 | ], 33 | 'step_updates': [ 34 | ('grads2', lambda vars, w, prev_val: prev_val -vars['unreg_grads'][w] * vars['deltas'][w] ), 35 | ], 36 | 'task_updates': [ 37 | ('omega', lambda vars, w, prev_val: tf.nn.relu( ema(omega_decay, prev_val, vars['grads2'][w]/((vars['cweights'][w]-w.value())**2+xi)) ) ), 38 | #('cached_grads2', lambda vars, w, prev_val: vars['grads2'][w]), 39 | #('cached_cweights', lambda vars, w, prev_val: vars['cweights'][w]), 40 | ('cweights', lambda opt, w, prev_val: w.value()), 41 | ('grads2', lambda vars, w, prev_val: prev_val*0.0 ), 42 | ], 43 | 'regularizer_fn': quadratic_regularizer, 44 | }) 45 | 46 | 47 | FISHER_PROTOCOL = lambda omega_decay:( 48 | 'fisher[omega_decay=%s]'%omega_decay, 49 | { 50 | 'task_updates': [ 51 | ('omega', lambda vars, w, prev_val: ema(omega_decay, prev_val, vars['task_fisher'][w]/vars['nb_data'])), 52 | ('cweights', lambda opt, w, prev_val: w.value()), 53 | ], 54 | 'task_metrics': { 55 | 'task_fisher': lambda opt: compute_fishers(opt.model), 56 | }, 57 | 'regularizer_fn': quadratic_regularizer, 58 | }) 59 | 60 | def sum_regularizer_fn(weights, vars): 61 | 
reg = 0.0 62 | for w in weights: 63 | reg += tf.reduce_sum(vars['sum_omega'][w] * w**2 64 | - 2 * vars['sum_omega_cweights'][w] * w 65 | + vars['sum_omega_cweights_squared'][w]) 66 | 67 | return reg 68 | 69 | -------------------------------------------------------------------------------- /pathint/regularizers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke 2 | # MIT License -- see LICENSE for details 3 | # 4 | # This file is part of the code to reproduce the core results of: 5 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through 6 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on 7 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention 8 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995. 9 | # http://proceedings.mlr.press/v70/zenke17a.html 10 | # 11 | import tensorflow as tf 12 | 13 | def quadratic_regularizer(weights, vars, norm=2): 14 | """Compute the regularization term. 15 | 16 | Args: 17 | weights: list of Variables 18 | _vars: dict from variable name to dictionary containing the variables. 19 | Each set of variables is stored as a dictionary mapping from weights to variables. 
20 | For example, vars['grads'][w] would retreive the 'grads' variable for weight w 21 | norm: power for the norm of the (weights - consolidated weight) 22 | 23 | Returns: 24 | scalar Tensor regularization term 25 | """ 26 | reg = 0.0 27 | for w in weights: 28 | reg += tf.reduce_sum(vars['omega'][w] * (w - vars['cweights'][w])**norm) 29 | return reg 30 | 31 | def get_power_regularizer(power=2.0): 32 | """Power regularizers with different norms""" 33 | def _regularizer_fn(weights, vars): 34 | return quadratic_regularizer(weights, vars, norm=power) 35 | return _regularizer_fn 36 | -------------------------------------------------------------------------------- /pathint/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke 3 | # MIT License -- see LICENSE for details 4 | # 5 | # This file is part of the code to reproduce the core results of: 6 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through 7 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on 8 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention 9 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995. 10 | # http://proceedings.mlr.press/v70/zenke17a.html 11 | # 12 | """Utility functions for benchmarking online learning""" 13 | from __future__ import division 14 | import numpy as np 15 | import keras 16 | from keras.utils import np_utils 17 | 18 | from keras.datasets import mnist, cifar10, cifar100 19 | from keras.optimizers import Adam, RMSprop, SGD 20 | import keras.backend as K 21 | 22 | import pickle 23 | import gzip 24 | 25 | import tensorflow as tf 26 | 27 | def ema(decay, prev_val, new_val): 28 | """Compute exponential moving average. 
29 | 30 | Args: 31 | decay: 'sum' to sum up values, otherwise decay in [0, 1] 32 | prev_val: previous value of accumulator 33 | new_val: new value 34 | Returns: 35 | updated accumulator 36 | """ 37 | if decay == 'sum': 38 | return prev_val + new_val 39 | return decay * prev_val + (1.0 - decay) * new_val 40 | 41 | def leak(decay, prev_val, new_val): 42 | """Compute leaky integrator. 43 | 44 | Like ema, but expectation value depends on decay time constant. 45 | 46 | Args: 47 | decay: 'sum' to sum up values, otherwise decay in [0, 1] 48 | prev_val: previous value of accumulator 49 | new_val: new value 50 | Returns: 51 | updated accumulator 52 | """ 53 | if decay == 'sum': 54 | return prev_val + new_val 55 | return decay * prev_val + new_val 56 | 57 | def extract_weight_changes(weights, update_ops): 58 | """Given a list of weights and Assign ops, identify the change in weights. 59 | 60 | Args: 61 | weights: list of Variables 62 | update_ops: list of Assign ops, typically computed using Keras' opt.get_updates() 63 | 64 | Returns: 65 | list of Tensors containing the weight update for each variable 66 | """ 67 | name_to_var = {v.name: v.value() for v in weights} 68 | weight_update_ops = list(filter(lambda x: x.op.inputs[0].name in name_to_var, update_ops)) 69 | nonweight_update_ops = list(filter(lambda x: x.op.inputs[0].name not in name_to_var, update_ops)) 70 | # Make sure that all the weight update ops are Assign ops 71 | for weight in weight_update_ops: 72 | if weight.op.type != 'Assign': 73 | raise ValueError('Update op for weight %s is not of type Assign.'%weight.op.inputs[0].name) 74 | weight_changes = [(new_w.op.inputs[1] - name_to_var[new_w.op.inputs[0].name]) for new_w, old_w in zip(weight_update_ops, weights)] 75 | # Recreate the update ops, ensuring that we compute the weight changes before updating the weights 76 | with tf.control_dependencies(weight_changes): 77 | new_weight_update_ops = [tf.assign(new_w.op.inputs[0], new_w.op.inputs[1]) for new_w in 
weight_update_ops] 78 | return weight_changes, tf.group(*(nonweight_update_ops + new_weight_update_ops)) 79 | 80 | 81 | def compute_updates(opt, loss, weights): 82 | update_ops = opt.get_updates(weights, [], loss) 83 | deltas, new_update_op = extract_weight_changes(weights, update_ops) 84 | grads = tf.gradients(loss, weights) 85 | # Make sure that deltas are computed _before_ the weight is updated 86 | return new_update_op, grads, deltas 87 | 88 | 89 | def split_dataset_by_labels(X, y, task_labels, nb_classes=None, multihead=False): 90 | """Split dataset by labels. 91 | 92 | Args: 93 | X: data 94 | y: labels 95 | task_labels: list of list of labels, one for each dataset 96 | nb_classes: number of classes (used to convert to one-hot) 97 | Returns: 98 | List of (X, y) tuples representing each dataset 99 | """ 100 | if nb_classes is None: 101 | nb_classes = len(np.unique(y)) 102 | datasets = [] 103 | for labels in task_labels: 104 | idx = np.in1d(y, labels) 105 | if multihead: 106 | label_map = np.arange(nb_classes) 107 | label_map[labels] = np.arange(len(labels)) 108 | data = X[idx], np_utils.to_categorical(label_map[y[idx]], len(labels)) 109 | else: 110 | data = X[idx], np_utils.to_categorical(y[idx], nb_classes) 111 | datasets.append(data) 112 | return datasets 113 | 114 | def split_dataset_randomly(X, y, nb_splits, nb_classes=None): 115 | """Split dataset by labels. 
116 | 117 | Args: 118 | X: data 119 | y: labels 120 | nb_splits: number of splits to return 121 | task_labels: list of list of labels, one for each dataset 122 | nb_classes: number of classes (used to convert to one-hot) 123 | Returns: 124 | List of (X, y) tuples representing each dataset 125 | """ 126 | if nb_classes is None: 127 | nb_classes = len(np.unique(y)) 128 | datasets = [] 129 | idx = range(len(y)) 130 | np.random.shuffle(idx) 131 | split_size = len(y)//nb_splits 132 | for i in range(nb_splits): 133 | data = X[idx[split_size*i:split_size*(i+1)]], np_utils.to_categorical(y[idx[split_size*i:split_size*(i+1)]], nb_classes) 134 | datasets.append(data) 135 | return datasets 136 | 137 | def get_mnist_variations(dsetnames=['MNIST_Rotated', 'MNIST_Basic'], datashape=(-1,1,28,28), validationset_fraction=0.1, multihead=False): 138 | """ Uses skdata package to import some MNIST variations 139 | 140 | The following dataset names exist in skdata: 141 | all = ['MNIST_Basic', 142 | 'MNIST_BackgroundImages', 143 | 'MNIST_BackgroundRandom', 144 | 'MNIST_Rotated', 145 | 'MNIST_Noise1', 146 | 'MNIST_Noise2', 147 | 'MNIST_Noise3', 148 | 'MNIST_Noise4', 149 | 'MNIST_Noise5', 150 | 'MNIST_Noise6' ] 151 | 152 | args: 153 | dsetnames: the names of the data sets from above list 154 | datashape: tuple with shape of the data (default (-1,1,28,28) 155 | validationset_fraction: the fraction of data to hold out 156 | multihead: whether to generate a multihead dataset or a single head one 157 | 158 | returns: 159 | doublet of training and validation set each being a list of tasks consisting of (X,y) tuples 160 | """ 161 | 162 | from skdata import larochelle_etal_2007 as L2007 163 | def dset(name): 164 | rval = getattr(L2007, name)() 165 | return rval 166 | 167 | n_tasks = len(dsetnames) 168 | training_datasets = [] 169 | validation_datasets = [] 170 | 171 | for i, dsname in enumerate(dsetnames): 172 | aa = dset(dsname) 173 | task = aa.classification_task() 174 | raw_data, raw_labels = 
task 175 | nb_datapoints = len(raw_data) 176 | label_offset = 0 177 | if multihead: 178 | nb_classes = 10*n_tasks 179 | label_offset = i*10 180 | else: 181 | nb_classes = 10 182 | nb_training_examples = int(nb_datapoints*(1.0-validationset_fraction)) 183 | data = raw_data.reshape(datashape) 184 | labels = np_utils.to_categorical(raw_labels+label_offset, nb_classes) 185 | training_datasets.append( (data[:nb_training_examples], labels[:nb_training_examples]) ) 186 | validation_datasets.append( (data[nb_training_examples:], labels[nb_training_examples:]) ) 187 | 188 | return training_datasets, validation_datasets 189 | 190 | def load_mnist(split='train'): 191 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 192 | X_train = X_train.reshape(-1, 784) 193 | X_test = X_test.reshape(-1, 784) 194 | X_train = X_train.astype('float32') 195 | X_test = X_test.astype('float32') 196 | X_train /= 255 197 | X_test /= 255 198 | 199 | if split == 'train': 200 | X, y = X_train, y_train 201 | else: 202 | X, y = X_test, y_test 203 | nb_classes = 10 204 | y = np_utils.to_categorical(y, nb_classes) 205 | return X, y 206 | 207 | def construct_split_mnist(task_labels, split='train', multihead=False): 208 | """Split MNIST dataset by labels. 
209 | 210 | Args: 211 | task_labels: list of list of labels, one for each dataset 212 | split: whether to use train or testing data 213 | 214 | Returns: 215 | List of (X, y) tuples representing each dataset 216 | """ 217 | # Load MNIST data and normalize 218 | nb_classes = 10 219 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 220 | X_train = X_train.reshape(-1, 784) 221 | X_test = X_test.reshape(-1, 784) 222 | X_train = X_train.astype('float32') 223 | X_test = X_test.astype('float32') 224 | X_train /= 255 225 | X_test /= 255 226 | 227 | if split == 'train': 228 | X, y = X_train, y_train 229 | else: 230 | X, y = X_test, y_test 231 | 232 | return split_dataset_by_labels(X, y, task_labels, nb_classes, multihead) 233 | 234 | 235 | def construct_randomly_split_mnist(nb_splits=10, mode='train'): 236 | """Split MNIST dataset by labels. 237 | 238 | Args: 239 | nb_splits: numer of splits 240 | mode: whether to use train or testing data 241 | 242 | Returns: 243 | List of (X, y) tuples representing each dataset 244 | """ 245 | # Load MNIST data and normalize 246 | nb_classes = 10 247 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 248 | X_train = X_train.reshape(-1, 784) 249 | X_test = X_test.reshape(-1, 784) 250 | X_train = X_train.astype('float32') 251 | X_test = X_test.astype('float32') 252 | X_train /= 255 253 | X_test /= 255 254 | 255 | if mode == 'train': 256 | X, y = X_train, y_train 257 | else: 258 | X, y = X_test, y_test 259 | 260 | return split_dataset_randomly(X, y, nb_splits, nb_classes) 261 | 262 | def construct_transfer_cifar10_cifar100(nb_tasks=4, split='train'): 263 | """ 264 | Returns a two task dataset in which the first task is the full CIFAR10 dataset and the second task are 10 from CIFAR100 265 | classes from the CIFAR100 dataset. 
266 | 267 | params: 268 | nb_tasks The total number of tasks 269 | split Whether to return training or validation data 270 | 271 | returns: 272 | A list with two tuples containing the two data sets 273 | """ 274 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 275 | # X_train = X_train.reshape(-1, 3, 32, 32) 276 | # X_test = X_test.reshape(-1, 32**2) 277 | X_train = X_train.astype('float32') 278 | X_test = X_test.astype('float32') 279 | no = X_train.max() 280 | X_train /= no 281 | X_test /= no 282 | 283 | if split == 'train': 284 | X, y = X_train, y_train 285 | else: 286 | X, y = X_test, y_test 287 | 288 | nb_classes = nb_tasks*10 289 | datasets = [(X,np_utils.to_categorical(y, nb_classes))] 290 | 291 | # Load CIFAR100 data and normalize 292 | (X_train, y_train), (X_test, y_test) = cifar100.load_data() 293 | X_train = X_train.astype('float32') 294 | X_test = X_test.astype('float32') 295 | m = np.max( (np.max(X_train), np.max(X_test) ) ) 296 | X_train /= m 297 | X_test /= m 298 | 299 | if split == 'train': 300 | X, y = X_train, y_train 301 | else: 302 | X, y = X_test, y_test 303 | 304 | # split dataset by labels 305 | task_labels = [ range(10*i,10*(i+1)) for i in range(1,nb_tasks) ] 306 | for labels in task_labels: 307 | idx = np.in1d(y+10, labels) 308 | data = X[idx], np_utils.to_categorical(y[idx]+10, nb_classes) 309 | datasets.append(data) 310 | 311 | 312 | all_task_labels = [range(10)] 313 | all_task_labels.extend(task_labels) 314 | return all_task_labels, datasets 315 | 316 | def construct_split_cifar100(num_tasks=3, num_classes=10): 317 | """Split CIFAR100 dataset and relabel classes num_classes 318 | 319 | Args: 320 | num_tasks: the number of tasks 321 | num_classes: the number of classes per task 322 | 323 | Returns: 324 | List of (X, y) tuples representing each dataset 325 | """ 326 | # Load CIFAR100 data and normalize 327 | (X_train, y_train), (X_test, y_test) = cifar100.load_data() 328 | X_train = X_train.astype('float32') 329 | X_test = 
X_test.astype('float32') 330 | m = np.max( (np.max(X_train), np.max(X_test) ) ) 331 | X_train /= m 332 | X_test /= m 333 | 334 | X, y = X_train, y_train 335 | 336 | # split dataset by labels 337 | # here we also flatten the labels of cifar100 to match num_classes via modulus operation 338 | task_labels = [ range(num_classes*i,num_classes*(i+1)) for i in range(num_tasks) ] 339 | datasets = [] 340 | for labels in task_labels: 341 | idx = np.in1d(y, labels) 342 | data = X[idx], np_utils.to_categorical(y[idx]%num_classes, num_classes) 343 | datasets.append(data) 344 | 345 | return datasets 346 | 347 | def construct_permute_mnist(num_tasks=2, split='train', permute_all=False, subsample=1): 348 | """Create permuted MNIST tasks. 349 | 350 | Args: 351 | num_tasks: Number of tasks 352 | split: whether to use train or testing data 353 | permute_all: When set true also the first task is permuted otherwise it's standard MNIST 354 | subsample: subsample by so much 355 | 356 | Returns: 357 | List of (X, y) tuples representing each dataset 358 | """ 359 | # Load MNIST data and normalize 360 | nb_classes = 10 361 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 362 | X_train = X_train.reshape(-1, 784) 363 | X_test = X_test.reshape(-1, 784) 364 | X_train = X_train.astype('float32') 365 | X_test = X_test.astype('float32') 366 | X_train /= 255 367 | X_test /= 255 368 | 369 | X_train, y_train = X_train[::subsample], y_train[::subsample] 370 | X_test, y_test = X_test[::subsample], y_test[::subsample] 371 | 372 | permutations = [] 373 | # Generate random permutations 374 | for i in range(num_tasks): 375 | idx = np.arange(X_train.shape[1],dtype=int) 376 | if permute_all or i>0: 377 | np.random.shuffle(idx) 378 | permutations.append(idx) 379 | 380 | both_datasets = [] 381 | for (X, y) in ((X_train, y_train), (X_test, y_test)): 382 | datasets = [] 383 | for perm in permutations: 384 | data = X[:,perm], np_utils.to_categorical(y, nb_classes) 385 | datasets.append(data) 386 | 
both_datasets.append(datasets) 387 | return both_datasets 388 | 389 | 390 | def construct_split_cifar10(task_labels, split='train'): 391 | """Split CIFAR10 dataset by labels. 392 | 393 | Args: 394 | task_labels: list of list of labels, one for each dataset 395 | split: whether to use train or testing data 396 | 397 | Returns: 398 | List of (X, y) tuples representing each dataset 399 | """ 400 | # Load CIFAR10 data and normalize 401 | nb_classes = 10 402 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 403 | # X_train = X_train.reshape(-1, 3, 32, 32) 404 | # X_test = X_test.reshape(-1, 32**2) 405 | X_train = X_train.astype('float32') 406 | X_test = X_test.astype('float32') 407 | no = X_train.max() 408 | X_train /= no 409 | X_test /= no 410 | 411 | if split == 'train': 412 | X, y = X_train, y_train 413 | else: 414 | X, y = X_test, y_test 415 | 416 | return split_dataset_by_labels(X, y, task_labels, nb_classes) 417 | 418 | 419 | def online_benchmark(datasets, model, loss, optimizer, epochs_per_dataset=1, 420 | ages=1, batch_size=256, callbacks=None, **kwargs): 421 | """Benchmark online learning. 422 | 423 | Sequentially optimize a set of tasks, and compute 424 | the predictions for each task over time. 
425 | 426 | Args: 427 | datasets: list of (inputs, labels) tuples 428 | model: Keras model 429 | loss: string or function 430 | optimizer: string or Keras Optimizer object 431 | epochs_per_dataset: number of passes through an individual dataset 432 | ages: number of passes over datasets 433 | batch_size: batch size 434 | callbacks: list of functions to call with the model at each iteration 435 | 436 | Returns: 437 | labels: 438 | predictions: 439 | """ 440 | 441 | # Build the model 442 | model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy']) 443 | 444 | ndataset = len(datasets) 445 | predictions = [[] for i in range(ndataset)] 446 | labels = [[] for i in range(ndataset)] 447 | if callbacks is not None: 448 | callback_outputs = [[] for i in range(len(callbacks))] 449 | for cidx, callback in enumerate(callbacks): 450 | callback_outputs[cidx].append(callback(model)) 451 | 452 | optimization_data = [[] for i in range(len(model.get_weights())) ] 453 | for age in range(ages): 454 | for didx, dataset in enumerate(datasets): 455 | 456 | model.fit(*dataset, batch_size=batch_size, nb_epoch=epochs_per_dataset, verbose=1) 457 | # Log w, g, g2, ... 458 | # For all variables, ... 
, 459 | if isinstance(optimizer, Adam): 460 | weights = model.get_weights() 461 | opt_vars = optimizer.weights[1:] 462 | ms = opt_vars[:len(opt_vars)//2] 463 | vs = opt_vars[len(opt_vars)//2:] 464 | sess = K.get_session() 465 | stuff = sess.run([ms, vs]) 466 | for i in range(len(model.get_weights())): 467 | optimization_data[i].append([stuff[0][i], stuff[1][i]]) 468 | #optimization_data.append(stuff) 469 | 470 | # Evaluate on all datasets 471 | for eval_didx, eval_dataset in enumerate(datasets): 472 | # Evaluate model on dataset 473 | preds = model.predict(eval_dataset[0]) 474 | predictions[eval_didx].append(preds) 475 | # Convert from 1-hot back to categorical 476 | labels[eval_didx].append(np.argmax(eval_dataset[1], 1)) 477 | print(model.evaluate(*eval_dataset)) 478 | print("") 479 | if callbacks is not None: 480 | for cidx, callback in enumerate(callbacks): 481 | callback_outputs[cidx].append(callback(model)) 482 | 483 | if callbacks is None: 484 | callback_outputs = None 485 | # TODO(ben): might break some shit 486 | 487 | 488 | return dict(labels=labels, predictions=predictions, 489 | callback_outputs=callback_outputs, 490 | optimization_data=optimization_data) 491 | 492 | 493 | def save_zipped_pickle(obj, filename, protocol=-1): 494 | with gzip.open(filename, 'wb') as f: 495 | pickle.dump(obj, f, protocol) 496 | 497 | 498 | def load_zipped_pickle(filename): 499 | try: 500 | with gzip.open(filename, 'rb') as f: 501 | loaded_object = pickle.load(f) 502 | return loaded_object 503 | except IOError: 504 | print("Warning: IO Error returning empty dict.") 505 | return dict() 506 | 507 | 508 | def split_dataset(ds, split_sizes, permute_data=True): 509 | """ Helper function to split a single dataset into train, valid and test set. 
510 | 511 | args: 512 | ds the dataset being a tuple of (data,labels) 513 | split_sizes a list of fractional split sizes of howto divide up the dataset 514 | 515 | returns: 516 | a list of datasets with the respective split ratios 517 | """ 518 | raw_data, raw_labels = ds 519 | if permute_data: 520 | idx = range(len(raw_data)) 521 | np.random.shuffle(idx) 522 | data = raw_data[idx] 523 | labels = raw_labels[idx] 524 | else: 525 | data = raw_data 526 | labels = raw_labels 527 | nelems = len(labels) 528 | nbegin = 0 529 | splits = [] 530 | for split in split_sizes: 531 | nend = nbegin+int(split*nelems) 532 | splits.append( (data[nbegin:nend], labels[nbegin:nend]) ) 533 | nbegin = nend 534 | return splits 535 | 536 | def mk_training_validation_splits( full_datasets, split_fractions = (0.8, 0.1, 0.1) ): 537 | """ Splits multiple a list of tasks into training, validation and test sets 538 | 539 | args: 540 | full_datasets: The full dataset as a list of tasks each being of the form (data, labels) 541 | split_fractions: A list of split fractions which should sum up to 1.0 542 | 543 | returns: 544 | a list of length len(split_fractions) each containing a list of tasks 545 | """ 546 | results = [ [] for i in range(len(split_fractions)) ] 547 | for ds in full_datasets: 548 | splits = split_dataset(ds, split_fractions) 549 | for i,sp in enumerate(splits): 550 | results[i].append(sp) 551 | return results 552 | 553 | def mk_joined_dataset( full_datasets, split_fractions = (0.9, 0.1) ): 554 | """ Joins datasets from multiple tasks to a single dataset as a baseline control and returns training and validation splints. 
""" 555 | l = len(full_datasets) 556 | data = np.concatenate([ full_datasets[i][0] for i in range(l) ], 0) 557 | labels = np.concatenate([ full_datasets[i][1] for i in range(l) ], 0) 558 | return split_dataset((data, labels), split_fractions) 559 | 560 | 561 | 562 | def main(): 563 | """ Test code for permute MNIST task 564 | 565 | Plots the first digit of the first two tasks. """ 566 | import matplotlib.pyplot as plt 567 | ds = construct_split_cifar100() 568 | plt.subplot(121) 569 | plt.imshow(ds[0][0][0].transpose((1,2,0) ), interpolation='nearest') 570 | plt.subplot(122) 571 | plt.imshow(ds[1][0][0].transpose((1,2,0)), interpolation='nearest') 572 | plt.show() 573 | 574 | if __name__ == "__main__": 575 | main() 576 | 577 | --------------------------------------------------------------------------------