├── .gitignore
├── LICENSE
├── README.md
├── fig_permuted_mnist
│   ├── Figure permuted MNIST.ipynb
│   ├── Permuted MNIST.ipynb
│   └── data_path_int[omega_decay=sum,xi=0.1]_optadam_lr1.00e-03_bs256_ep20_tsks10.pkl.gz
├── fig_split_mnist
│   └── Figure split MNIST.ipynb
├── fig_transfer_cifar
│   ├── Figure barplot.ipynb
│   ├── split_cifar10_data_path_int[omega_decay=sum,xi=0.001]_lr1.00e-03_ep60.pkl.gz
│   └── train.py
└── pathint
    ├── __init__.py
    ├── keras_utils.py
    ├── optimizers.py
    ├── protocols.py
    ├── regularizers.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Ben Poole & Friedemann Zenke
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Continual Learning Through Synaptic Intelligence
2 |
3 | This repository contains code to reproduce the key findings of our path integral approach to prevent catastrophic forgetting in continual learning.
4 |
5 | Zenke, F.¹, Poole, B.¹, and Ganguli, S. (2017). Continual Learning Through
6 | Synaptic Intelligence. In Proceedings of the 34th International Conference on
7 | Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention
8 | Centre, Sydney, Australia: PMLR), pp. 3987–3995.
9 |
10 | http://proceedings.mlr.press/v70/zenke17a.html
11 |
12 | ¹ Equal contribution
13 |
14 | ## BibTeX
15 | ```
16 | @InProceedings{pmlr-v70-zenke17a,
17 | title = {Continual Learning Through Synaptic Intelligence},
18 | author = {Friedemann Zenke and Ben Poole and Surya Ganguli},
19 | booktitle = {Proceedings of the 34th International Conference on Machine Learning},
20 | pages = {3987--3995},
21 | year = {2017},
22 | editor = {Doina Precup and Yee Whye Teh},
23 | volume = {70},
24 | series = {Proceedings of Machine Learning Research},
25 | address = {International Convention Centre, Sydney, Australia},
26 | month = {06--11 Aug},
27 | publisher = {PMLR},
28 | pdf = {http://proceedings.mlr.press/v70/zenke17a/zenke17a.pdf},
29 | url = {http://proceedings.mlr.press/v70/zenke17a.html},
30 | }
31 | ```
32 |
33 |
34 | ## Requirements
35 |
36 | We have tested this maintenance release (v1.1) with the following configuration:
37 |
38 | * Python 3.5.2
39 | * Jupyter 4.4.0
40 | * Tensorflow 1.10
41 | * Keras 2.2.2
42 |
43 | Kudos to Mitra (https://github.com/MitraDarja) for making our code conform with Keras 2.2.2!
44 |
45 |
46 | ### Earlier releases
47 |
48 | For the original release (v1.0), we used the following library versions, which were current at the time:
49 |
50 | * Python 3.5.2
51 | * Jupyter 4.3.0
52 | * Tensorflow 1.2.1
53 | * Keras 2.0.5
54 |
55 | To revert to such an environment, we suggest using virtualenv (https://virtualenv.pypa.io):
56 |
57 | ```
58 | virtualenv -p python3 env
59 | source env/bin/activate
60 | pip3 install -vI keras==2.0.5
61 | pip3 install jupyter matplotlib numpy tensorflow-gpu==1.2.1 tqdm seaborn
62 | ```
63 |
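The path-integral regularizer itself lives in `pathint/`, which is not reproduced in this excerpt. As a rough guide to what the `c` and `xi` values in the result file names control, here is a minimal NumPy sketch of the consolidation rule described in the paper: each parameter's importance is the path integral of -g·Δθ accumulated while training on a task, normalised by the squared total change of that parameter plus a damping term ξ (`xi`), and the next task's loss gains a quadratic penalty scaled by `c`. Everything here (the toy quadratic loss, plain gradient descent instead of the optimizer used in the experiments, and all function names) is an illustrative assumption, not the repository's implementation.

```python
import numpy as np


def train_task(theta, a, target, lr=0.1, steps=100):
    """Plain gradient descent on a toy quadratic loss
    L(theta) = 0.5 * sum(a * (theta - target)**2).

    Along the way it accumulates the per-parameter path integral
    omega_k = -sum_t g_k(t) * dtheta_k(t), i.e. each parameter's
    first-order contribution to the drop in loss.
    """
    theta = np.array(theta, dtype=float)
    omega = np.zeros_like(theta)
    for _ in range(steps):
        g = a * (theta - target)  # gradient dL/dtheta at the current point
        step = -lr * g            # gradient-descent update dtheta
        omega += -g * step        # accumulate -g * dtheta
        theta = theta + step
    return theta, omega


def update_importance(Omega, omega, theta_start, theta_end, xi):
    """Add this task's normalised importance omega / (Delta**2 + xi), where
    Delta is the total parameter change over the task and xi keeps the
    ratio bounded when Delta is small."""
    delta = theta_end - theta_start
    return Omega + omega / (delta ** 2 + xi)


def surrogate_penalty(theta, theta_ref, Omega, c):
    """Quadratic term c * sum_k Omega_k * (theta_k - theta_ref_k)**2, added to
    the next task's loss to keep important parameters near theta_ref."""
    return c * np.sum(Omega * (theta - theta_ref) ** 2)


# Toy task in which only the first parameter influences the loss.
theta0 = np.zeros(2)
a = np.array([1.0, 0.0])
target = np.array([1.0, 5.0])

theta1, omega = train_task(theta0, a, target)
Omega = update_importance(np.zeros(2), omega, theta0, theta1, xi=0.1)
print(Omega)  # first entry positive, second entry exactly zero

# Moving the unimportant parameter costs nothing; moving the important one does.
print(surrogate_penalty(theta1 + [0.0, 3.0], theta1, Omega, c=0.1))  # 0.0
print(surrogate_penalty(theta1 + [0.5, 0.0], theta1, Omega, c=0.1))  # > 0
```

With `c = 0` the penalty vanishes and training reduces to plain sequential fine-tuning, presumably the `c=0` curve among the plotted permuted-MNIST results.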
--------------------------------------------------------------------------------
/fig_permuted_mnist/Figure permuted MNIST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# Copyright (c) 2017 Ben Poole & Friedemann Zenke\n",
12 | "# MIT License -- see LICENSE for details\n",
13 | "# \n",
14 | "# This file is part of the code to reproduce the core results of:\n",
15 | "# Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through\n",
16 | "# Synaptic Intelligence. In Proceedings of the 34th International Conference on\n",
17 | "# Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention\n",
18 | "# Centre, Sydney, Australia: PMLR), pp. 3987–3995.\n",
19 | "# http://proceedings.mlr.press/v70/zenke17a.html"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {
26 | "collapsed": false
27 | },
28 | "outputs": [
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "The autoreload extension is already loaded. To reload it, use:\n",
34 | " %reload_ext autoreload\n",
35 | "Populating the interactive namespace from numpy and matplotlib\n"
36 | ]
37 | },
38 | {
39 | "name": "stderr",
40 | "output_type": "stream",
41 | "text": [
42 | "Using TensorFlow backend.\n"
43 | ]
44 | }
45 | ],
46 | "source": [
47 | "%load_ext autoreload\n",
48 | "%autoreload 2\n",
49 | "%pylab inline\n",
50 | "\n",
51 | "import sys, os\n",
52 | "sys.path.extend([os.path.expanduser('..')])\n",
53 | "from pathint import utils\n",
54 | "import seaborn as sns\n",
55 | "sns.set_style(\"white\")\n",
56 | "\n",
57 | "# import operator\n",
58 | "import matplotlib.colors as colors\n",
59 | "import matplotlib.cm as cmx\n",
60 | "\n",
61 | "rcParams['pdf.fonttype'] = 42\n",
62 | "rcParams['ps.fonttype'] = 42"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 6,
68 | "metadata": {
69 | "collapsed": true
70 | },
71 | "outputs": [],
72 | "source": [
73 | "# Load data\n",
74 | "n_tasks = 10\n",
75 | "all_evals = utils.load_zipped_pickle(\"data_path_int[omega_decay=sum,xi=0.1]_optadam_lr1.00e-03_bs256_ep20_tsks10.pkl.gz\")"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 7,
81 | "metadata": {
82 | "collapsed": false
83 | },
84 | "outputs": [
85 | {
86 | "name": "stdout",
87 | "output_type": "stream",
88 | "text": [
89 | "[ 0. 0.01 0.02 0.1 ]\n"
90 | ]
91 | }
92 | ],
93 | "source": [
94 | "keys = list(all_evals.keys())\n",
95 | "sorted_keys = np.sort(keys)\n",
96 | "print(sorted_keys)"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 8,
102 | "metadata": {
103 | "collapsed": true
104 | },
105 | "outputs": [],
106 | "source": [
107 | "sns.set_context(\"paper\")\n",
108 | "sns.set_style('ticks')"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 9,
114 | "metadata": {
115 | "collapsed": true
116 | },
117 | "outputs": [],
118 | "source": [
119 | "plt.rc('text', usetex=False)\n",
120 | "plt.rc('xtick', labelsize=8)\n",
121 | "plt.rc('ytick', labelsize=8)\n",
122 | "plt.rc('axes', labelsize=8)"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 10,
128 | "metadata": {
129 | "collapsed": true
130 | },
131 | "outputs": [],
132 | "source": [
133 | "def simple_axis(ax):\n",
134 | " ax.spines['top'].set_visible(False)\n",
135 | " ax.spines['right'].set_visible(False)\n",
136 | " ax.get_xaxis().tick_bottom()\n",
137 | " ax.get_yaxis().tick_left()"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 11,
143 | "metadata": {
144 | "collapsed": true
145 | },
146 | "outputs": [],
147 | "source": [
148 | "def plot_data():\n",
149 | " marker = iter(['o', 's', 's', 'd', 'o'])\n",
150 | " plot_kwargs = dict(alpha=0.9)\n",
151 | " \n",
152 | " for cval in sorted_keys:\n",
153 | " stuff = []\n",
154 | " for i in range(len(all_evals[cval])):#n_tasks):\n",
155 | " stuff.append(all_evals[cval][i][:i+1].mean())\n",
156 | " plot(range(1,n_tasks+1), stuff, '%s-'%next(marker), label=\"Test (c=%g)\"%cval, zorder=2, **plot_kwargs)"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 12,
162 | "metadata": {
163 | "collapsed": false
164 | },
165 | "outputs": [],
177 | "source": [
178 | "fig = plt.figure(figsize=(3.3,2.4))\n",
179 | "# fig, ax = plt.subplots()\n",
180 | "# plt.rc('font', family='serif', serif='Times')\n",
181 | "gs = GridSpec(2, 1, height_ratios=[0.5, 1])\n",
182 | "ax = plt.subplot(gs[0])\n",
183 | "plot_data()\n",
184 | "# Training error for control network -- trained conventionally\n",
185 | "# plt.arrow(10.5, 0.995198, -0.2, 0, head_width=0.005, head_length=0.2, fc='k', ec='k')\n",
186 | "plt.tick_params(\n",
187 | " axis='x', # changes apply to the x-axis\n",
188 | " which='both', # both major and minor ticks are affected\n",
189 | " bottom=False, # ticks along the bottom edge are off\n",
190 | " top=False, # ticks along the top edge are off\n",
191 | " labelbottom=False) # tick labels along the bottom edge are off\n",
192 | "plt.tick_params(\n",
193 | " axis='y', # changes apply to the y-axis\n",
194 | " which='both', # both major and minor ticks are affected\n",
195 | " right=False, # ticks along the right edge are off\n",
196 | " left=True) # ticks along the left edge are on\n",
197 | "\n",
198 | "\n",
199 | "ax.spines['top'].set_visible(False)\n",
200 | "ax.spines['right'].set_visible(False)\n",
201 | "ax.spines['bottom'].set_visible(False)\n",
202 | "# ax.get_yaxis().tick_left()\n",
203 | "ylim(0.965, 1.005)\n",
204 | "yticks([0.97, 1.0])\n",
205 | "xlim(0.5, 10.46)\n",
206 | "\n",
207 | "ax2 = plt.subplot(gs[1])\n",
208 | "plot_data()\n",
209 | "\n",
210 | "xlabel('Number of tasks')\n",
211 | "ylabel('Fraction correct')\n",
212 | "ylim(0.48, 1.02)\n",
213 | "xlim(0.5, 10.5)\n",
214 | "simple_axis(ax2)\n",
215 | "yticks([0.5, 0.75, 1.0])\n",
216 | "\n",
217 | "from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes\n",
218 | "\n",
219 | "legend(loc='lower left', fontsize=6)\n",
220 | "\n",
221 | "plt.subplots_adjust(left=.18, bottom=.20, right=.99, top=.97)\n",
222 | "plt.savefig(\"accuracy_vs_nbtasks.pdf\", pad_inches=0)"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {
229 | "collapsed": true
230 | },
231 | "outputs": [],
232 | "source": []
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": null,
237 | "metadata": {
238 | "collapsed": true
239 | },
240 | "outputs": [],
241 | "source": []
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": null,
246 | "metadata": {
247 | "collapsed": true
248 | },
249 | "outputs": [],
250 | "source": []
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": null,
255 | "metadata": {
256 | "collapsed": true
257 | },
258 | "outputs": [],
259 | "source": []
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {
265 | "collapsed": true
266 | },
267 | "outputs": [],
268 | "source": []
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "metadata": {
274 | "collapsed": true
275 | },
276 | "outputs": [],
277 | "source": []
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": null,
282 | "metadata": {
283 | "collapsed": true
284 | },
285 | "outputs": [],
286 | "source": []
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": null,
291 | "metadata": {
292 | "collapsed": true
293 | },
294 | "outputs": [],
295 | "source": []
296 | }
297 | ],
298 | "metadata": {
299 | "kernelspec": {
300 | "display_name": "Python 3",
301 | "language": "python",
302 | "name": "python3"
303 | },
304 | "language_info": {
305 | "codemirror_mode": {
306 | "name": "ipython",
307 | "version": 3
308 | },
309 | "file_extension": ".py",
310 | "mimetype": "text/x-python",
311 | "name": "python",
312 | "nbconvert_exporter": "python",
313 | "pygments_lexer": "ipython3",
314 | "version": "3.5.2"
315 | }
316 | },
317 | "nbformat": 4,
318 | "nbformat_minor": 1
319 | }
320 |
--------------------------------------------------------------------------------
/fig_permuted_mnist/data_path_int[omega_decay=sum,xi=0.1]_optadam_lr1.00e-03_bs256_ep20_tsks10.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ganguli-lab/pathint/86c5e07c5603d6a44c2f53585c19ee70773ef15c/fig_permuted_mnist/data_path_int[omega_decay=sum,xi=0.1]_optadam_lr1.00e-03_bs256_ep20_tsks10.pkl.gz
--------------------------------------------------------------------------------
/fig_split_mnist/Figure split MNIST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# Copyright (c) 2017 Ben Poole & Friedemann Zenke\n",
12 | "# MIT License -- see LICENSE for details\n",
13 | "# \n",
14 | "# This file is part of the code to reproduce the core results of:\n",
15 | "# Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through\n",
16 | "# Synaptic Intelligence. In Proceedings of the 34th International Conference on\n",
17 | "# Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention\n",
18 | "# Centre, Sydney, Australia: PMLR), pp. 3987–3995.\n",
19 | "# http://proceedings.mlr.press/v70/zenke17a.html"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 10,
25 | "metadata": {
26 | "collapsed": false,
27 | "scrolled": false
28 | },
29 | "outputs": [
30 | {
31 | "name": "stdout",
32 | "output_type": "stream",
33 | "text": [
34 | "The autoreload extension is already loaded. To reload it, use:\n",
35 | " %reload_ext autoreload\n",
36 | "Populating the interactive namespace from numpy and matplotlib\n"
37 | ]
38 | },
39 | {
40 | "name": "stderr",
41 | "output_type": "stream",
42 | "text": [
43 | "/usr/local/lib/python3.5/dist-packages/IPython/core/magics/pylab.py:160: UserWarning: pylab import has clobbered these variables: ['colors']\n",
44 | "`%matplotlib` prevents importing * from pylab and numpy\n",
45 | " \"\\n`%matplotlib` prevents importing * from pylab and numpy\"\n"
46 | ]
47 | }
48 | ],
49 | "source": [
50 | "%load_ext autoreload\n",
51 | "%autoreload 2\n",
52 | "%pylab inline\n",
53 | "\n",
54 | "import tensorflow as tf\n",
55 | "slim = tf.contrib.slim\n",
56 | "graph_replace = tf.contrib.graph_editor.graph_replace\n",
57 | "\n",
58 | "import sys, os\n",
59 | "sys.path.extend([os.path.expanduser('..')])\n",
60 | "from pathint import utils\n",
61 | "import seaborn as sns\n",
62 | "sns.set_style(\"ticks\")\n",
63 | "\n",
64 | "from tqdm import trange, tqdm\n",
65 | "\n",
66 | "# import operator\n",
67 | "import matplotlib.colors as colors\n",
68 | "import matplotlib.cm as cmx\n",
69 | "\n",
70 | "rcParams['pdf.fonttype'] = 42\n",
71 | "rcParams['ps.fonttype'] = 42"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "## Parameters"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 11,
84 | "metadata": {
85 | "collapsed": true
86 | },
87 | "outputs": [],
88 | "source": [
89 | "select = tf.select if hasattr(tf, 'select') else tf.where"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 12,
95 | "metadata": {
96 | "collapsed": true
97 | },
98 | "outputs": [],
99 | "source": [
100 | "# Data params\n",
101 | "input_dim = 784\n",
102 | "output_dim = 10\n",
103 | "\n",
104 | "# Network params\n",
105 | "n_hidden_units = 256\n",
106 | "activation_fn = tf.nn.relu\n",
107 | "\n",
108 | "# Optimization params\n",
109 | "batch_size = 64\n",
110 | "epochs_per_task = 10\n",
111 | "\n",
112 | "n_stats = 10\n",
113 | "\n",
114 | "# Reset optimizer after each age\n",
115 | "reset_optimizer = True"
116 | ]
117 | },
118 | {
119 | "cell_type": "markdown",
120 | "metadata": {},
121 | "source": [
122 | "## Construct datasets"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 13,
128 | "metadata": {
129 | "collapsed": true
130 | },
131 | "outputs": [],
132 | "source": [
133 | "task_labels = [[0,1], [2,3]]#, [4,5], [6,7], [8,9]]\n",
134 | "task_labels = [[0,1], [2,3], [4,5], [6,7], [8,9]]\n",
135 | "# task_labels = [[0,1,2,3,4], [5,6,7,8,9]]\n",
136 | "n_tasks = len(task_labels)\n",
137 | "training_datasets = utils.construct_split_mnist(task_labels, split='train')\n",
138 | "validation_datasets = utils.construct_split_mnist(task_labels, split='test')\n",
139 | "# training_datasets = utils.mk_training_validation_splits(full_datasets, split_fractions=(0.9, 0.1))"
140 | ]
141 | },
142 | {
143 | "cell_type": "markdown",
144 | "metadata": {},
145 | "source": [
146 | "## Construct network, loss, and updates"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 14,
152 | "metadata": {
153 | "collapsed": true
154 | },
155 | "outputs": [],
156 | "source": [
157 | "tf.reset_default_graph()"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 15,
163 | "metadata": {
164 | "collapsed": true
165 | },
166 | "outputs": [],
167 | "source": [
168 | "config = tf.ConfigProto()\n",
169 | "config.gpu_options.allow_growth=True\n",
170 | "sess = tf.InteractiveSession(config=config)\n",
171 | "sess.run(tf.global_variables_initializer())"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 9,
177 | "metadata": {
178 | "collapsed": true
179 | },
180 | "outputs": [],
181 | "source": [
182 | "# tf.equal(output_mask[None, :], 1.0)"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 16,
188 | "metadata": {
189 | "collapsed": true
190 | },
191 | "outputs": [],
192 | "source": [
193 | "import keras.backend as K\n",
194 | "import keras.activations as activations\n",
195 | "\n",
196 | "output_mask = tf.Variable(tf.zeros(output_dim), name=\"mask\", trainable=False)\n",
197 | "\n",
198 | "def masked_softmax(logits):\n",
199 | " # logits are [batch_size, output_dim]\n",
200 | " x = select(tf.tile(tf.equal(output_mask[None, :], 1.0), [tf.shape(logits)[0], 1]), logits, -1e32 * tf.ones_like(logits))\n",
201 | " return activations.softmax(x)\n",
202 | "\n",
203 | "def set_active_outputs(labels):\n",
204 | " new_mask = np.zeros(output_dim)\n",
205 | " for l in labels:\n",
206 | " new_mask[l] = 1.0\n",
207 | " sess.run(output_mask.assign(new_mask))\n",
208 | " print(sess.run(output_mask))\n",
209 | " \n",
210 | "def masked_predict(model, data, targets):\n",
211 | " pred = model.predict(data)\n",
212 | " print(pred)\n",
213 | " acc = np.argmax(pred,1)==np.argmax(targets,1)\n",
214 | " return acc.mean()"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 17,
220 | "metadata": {
221 | "collapsed": true
222 | },
223 | "outputs": [],
224 | "source": [
225 | "from keras.models import Sequential\n",
226 | "from keras.layers import Dense\n",
227 | "model = Sequential()\n",
228 | "model.add(Dense(n_hidden_units, activation=activation_fn, input_shape=(input_dim,)))\n",
229 | "model.add(Dense(n_hidden_units, activation=activation_fn))\n",
230 | "model.add(Dense(output_dim, kernel_initializer='zero', activation=masked_softmax))"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": 9,
236 | "metadata": {
237 | "collapsed": true
238 | },
239 | "outputs": [],
240 | "source": [
241 | "from pathint import protocols\n",
242 | "from pathint.optimizers import KOOptimizer\n",
243 | "from keras.optimizers import Adam, RMSprop,SGD\n",
244 | "from keras.callbacks import Callback\n",
245 | "from pathint.keras_utils import LossHistory\n",
246 | "\n",
247 | "#protocol_name, protocol = protocols.PATH_INT_PROTOCOL(omega_decay='sum',xi=1e-3)\n",
248 | "protocol_name, protocol = protocols.PATH_INT_PROTOCOL(omega_decay='sum',xi=1e-3)\n",
249 | "# protocol_name, protocol = protocols.SUM_FISHER_PROTOCOL('sum')\n",
250 | "opt = Adam(lr=1e-3, beta_1=0.9, beta_2=0.999)\n",
251 | "# opt = SGD(1e-3)\n",
252 | "# opt = RMSprop(lr=1e-3)\n",
253 | "oopt = KOOptimizer(opt, model=model, **protocol)\n",
254 | "model.compile(loss='categorical_crossentropy', optimizer=oopt, metrics=['accuracy'])\n",
255 | "model._make_train_function()\n",
256 | "saved_weights = model.get_weights()\n",
257 | "\n",
258 | "history = LossHistory()\n",
259 | "callbacks = [history]\n",
260 | "datafile_name = \"split_mnist_data_%s.pkl.gz\"%protocol_name"
261 | ]
262 | },
263 | {
264 | "cell_type": "markdown",
265 | "metadata": {},
266 | "source": [
267 | "## Train!"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 10,
273 | "metadata": {
274 | "collapsed": true,
275 | "scrolled": true
276 | },
277 | "outputs": [],
278 | "source": [
279 | "# diag_vals = dict()\n",
280 | "# all_evals = dict()\n",
281 | "# data = utils.load_zipped_pickle(\"comparison_data_%s.pkl.gz\"%protocol_name)\n",
282 | "# returns empty dict if file not found\n",
283 | "\n",
284 | "def run_fits(cvals, training_data, valid_data, eval_on_train_set=False, nstats=1):\n",
285 | " acc_mean = dict()\n",
286 | " acc_std = dict()\n",
287 | " for cidx, cval_ in enumerate(tqdm(cvals)):\n",
288 | " runs = []\n",
289 | " for runid in range(nstats):\n",
290 | " sess.run(tf.global_variables_initializer())\n",
291 | " # model.set_weights(saved_weights)\n",
292 | " cstuffs = []\n",
293 | " evals = []\n",
294 | " print(\"setting cval\")\n",
295 | " cval = cval_\n",
296 | " oopt.set_strength(cval)\n",
297 | " oopt.init_task_vars()\n",
298 | " print(\"cval is\", sess.run(oopt.lam))\n",
299 | " for age, tidx in enumerate(range(n_tasks)):\n",
300 | " print(\"Age %i, cval is=%f\"%(age,cval))\n",
301 | " print(\"setting output mask\")\n",
302 | " set_active_outputs(task_labels[age])\n",
303 | " stuffs = model.fit(training_data[tidx][0], training_data[tidx][1], batch_size, epochs_per_task, callbacks=callbacks)\n",
304 | " oopt.update_task_metrics(training_data[tidx][0], training_data[tidx][1], batch_size)\n",
305 | " oopt.update_task_vars()\n",
306 | " ftask = []\n",
307 | " for j in range(n_tasks):\n",
308 | " set_active_outputs(task_labels[j])\n",
309 | " if eval_on_train_set:\n",
310 | " f_ = masked_predict(model, training_data[j][0], training_data[j][1])\n",
311 | " else:\n",
312 | " f_ = masked_predict(model, valid_data[j][0], valid_data[j][1])\n",
313 | " ftask.append(np.mean(f_))\n",
314 | " evals.append(ftask)\n",
315 | " cstuffs.append(stuffs)\n",
316 | "\n",
317 | " # Re-initialize optimizer variables\n",
318 | " if reset_optimizer:\n",
319 | " oopt.reset_optimizer()\n",
320 | "\n",
321 | " evals = np.array(evals)\n",
322 | " runs.append(evals)\n",
323 | " \n",
324 | " runs = np.array(runs)\n",
325 | " acc_mean[cval_] = runs.mean(0)\n",
326 | " acc_std[cval_] = runs.std(0)\n",
327 | " return dict(mean=acc_mean, std=acc_std)\n"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": 11,
333 | "metadata": {
334 | "collapsed": false
335 | },
336 | "outputs": [
337 | {
338 | "name": "stdout",
339 | "output_type": "stream",
340 | "text": [
341 | "[0, 1.0]\n"
342 | ]
343 | }
344 | ],
345 | "source": [
346 | "# cvals = np.concatenate(([0], np.logspace(-2, 2, 10)))\n",
347 | "# cvals = np.concatenate(([0], np.logspace(-1, 2, 2)))\n",
348 | "# cvals = np.concatenate(([0], np.logspace(-2, 0, 3)))\n",
349 | "cvals = np.logspace(-3, 3, 7)#[0, 1.0, 2, 5, 10]\n",
350 | "cvals = [0, 1.0]\n",
351 | "print(cvals)\n"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": 12,
357 | "metadata": {
358 | "collapsed": true
359 | },
360 | "outputs": [],
361 | "source": [
362 | "%%capture\n",
363 | "\n",
364 | "recompute_data = True\n",
365 | "\n",
366 | "if recompute_data:\n",
367 | " data = run_fits(cvals, training_datasets, validation_datasets, eval_on_train_set=True, nstats=n_stats)\n",
368 | " utils.save_zipped_pickle(data, datafile_name)"
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": 13,
374 | "metadata": {
375 | "collapsed": false
376 | },
377 | "outputs": [
378 | {
379 | "name": "stdout",
380 | "output_type": "stream",
381 | "text": [
382 | "[0, 1.0]\n"
383 | ]
384 | }
385 | ],
386 | "source": [
387 | "data = utils.load_zipped_pickle(datafile_name)\n",
388 | "print(cvals)"
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "execution_count": 14,
394 | "metadata": {
395 | "collapsed": false
396 | },
397 | "outputs": [
398 | {
399 | "name": "stdout",
400 | "output_type": "stream",
401 | "text": [
402 | "(-5, 0.0)\n"
403 | ]
404 | }
405 | ],
406 | "source": [
407 | "cmap = plt.get_cmap('cool') \n",
408 | "cNorm = colors.Normalize(vmin=-5, vmax=np.log(np.max(list(data['mean'].keys()))))\n",
409 | "scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cmap)\n",
410 | "print(scalarMap.get_clim())"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "execution_count": 15,
416 | "metadata": {
417 | "collapsed": false
418 | },
419 | "outputs": [
420 | {
421 | "name": "stderr",
422 | "output_type": "stream",
423 | "text": [
424 | "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py:13: RuntimeWarning: divide by zero encountered in log\n",
425 | " del sys.path[0]\n",
426 | "/usr/local/lib/python3.5/dist-packages/matplotlib/axes/_axes.py:545: UserWarning: No labelled objects found. Use label='...' kwarg on individual plots.\n",
427 | " warnings.warn(\"No labelled objects found. \"\n"
428 | ]
429 | },
430 | {
431 | "data": {
432 | "image/png": "\n",
433 | "text/plain": [
434 | ""
435 | ]
436 | },
437 | "metadata": {},
438 | "output_type": "display_data"
439 | }
440 | ],
441 | "source": [
442 | "figure(figsize=(14, 2.5))\n",
443 | "axs = [subplot(1,n_tasks+1,1)]#, None, None]\n",
444 | "for i in range(1, n_tasks + 1):\n",
445 | " axs.append(subplot(1, n_tasks+1, i+1, sharex=axs[0], sharey=axs[0]))\n",
446 | " \n",
447 | "keys = list(data['mean'].keys())\n",
448 | "sorted_keys = np.sort(keys)\n",
449 | "\n",
450 | "for cval in sorted_keys:\n",
451 | " mean_vals = data['mean'][cval]\n",
452 | " std_vals = data['std'][cval]\n",
453 | " for j in range(n_tasks):\n",
454 | " colorVal = scalarMap.to_rgba(np.log(cval))\n",
455 | " # axs[j].plot(evals[:, j], c=colorVal)\n",
456 | " axs[j].errorbar(range(n_tasks), mean_vals[:, j], yerr=std_vals[:, j]/np.sqrt(n_stats), c=colorVal)\n",
457 | " label = \"c=%g\"%cval\n",
458 | " average = mean_vals.mean(1)\n",
459 | " axs[-1].plot(average, c=colorVal, label=label)\n",
460 | " \n",
461 | "for i, ax in enumerate(axs):\n",
462 | " ax.legend(loc='best')\n",
463 | " ax.set_title((['task %d'%j for j in range(n_tasks)] + ['average'])[i])\n",
464 | "gcf().tight_layout()\n",
465 | "sns.despine()"
466 | ]
467 | },
468 | {
469 | "cell_type": "code",
470 | "execution_count": 16,
471 | "metadata": {
472 | "collapsed": true
473 | },
474 | "outputs": [],
475 | "source": [
476 | "plt.rc('text', usetex=False)\n",
477 | "plt.rc('xtick', labelsize=8)\n",
478 | "plt.rc('ytick', labelsize=8)\n",
479 | "plt.rc('axes', labelsize=8)\n",
480 | "\n",
481 | "def simple_axis(ax):\n",
482 | " ax.spines['top'].set_visible(False)\n",
483 | " ax.spines['right'].set_visible(False)\n",
484 | " ax.get_xaxis().tick_bottom()\n",
485 | " ax.get_yaxis().tick_left()"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": 17,
491 | "metadata": {
492 | "collapsed": false
493 | },
494 | "outputs": [
495 | {
496 | "data": {
497 | "image/png": "\n",
498 | "text/plain": [
499 | ""
500 | ]
501 | },
502 | "metadata": {},
503 | "output_type": "display_data"
504 | }
505 | ],
506 | "source": [
507 | "fig = plt.figure(figsize=(3.3,2.5))\n",
508 | "ax = plt.subplot(111)\n",
509 | "\n",
510 | "for cval in sorted_keys:\n",
511 | " mean_stuff = []\n",
512 | " std_stuff = []\n",
513 | " for i in range(len(data['mean'][cval])):\n",
514 | " mean_stuff.append(data['mean'][cval][i][:i+1].mean())\n",
515 | " std_stuff.append(np.sqrt((data['std'][cval][i][:i+1]**2).sum())/(n_stats*np.sqrt(n_stats)))\n",
516 | " # plot(range(1,n_tasks+1), mean_stuff, 'o-', label=\"c=%g\"%cval)\n",
517 | " errorbar(range(1,n_tasks+1), mean_stuff, yerr=std_stuff, fmt='o-', label=\"c=%g\"%cval)\n",
518 | " \n",
519 | "axhline(data['mean'][cval][0][0], linestyle='--', color='k')\n",
520 | "xlabel('Number of tasks')\n",
521 | "ylabel('Fraction correct')\n",
522 | "legend(loc='best')\n",
523 | "xlim(0.5, 5.5)\n",
524 | "ylim(0.7, 1.02)\n",
525 | "# grid('on')\n",
526 | "# sns.despine()\n",
527 | "simple_axis(ax)"
528 | ]
529 | },
530 | {
531 | "cell_type": "code",
532 | "execution_count": 18,
533 | "metadata": {
534 | "collapsed": false
535 | },
536 | "outputs": [
537 | {
538 | "data": {
539 | "image/png": "\n",
540 | "text/plain": [
541 | ""
542 | ]
543 | },
544 | "metadata": {},
545 | "output_type": "display_data"
546 | }
547 | ],
548 | "source": [
549 | "fig = plt.figure(figsize=(3.3,2.0))\n",
550 | "ax = plt.subplot(111)\n",
551 | "\n",
552 | "plot_keys =sorted(data['mean'].keys())# [0,1]\n",
553 | "\n",
554 | "for cval in plot_keys:\n",
555 | " mean_stuff = []\n",
556 | " std_stuff = []\n",
557 | " for i in range(len(data['mean'][cval])):\n",
558 | " mean_stuff.append(data['mean'][cval][i][:i+1].mean())\n",
559 | " std_stuff.append(np.sqrt((data['std'][cval][i][:i+1]**2).sum())/(n_stats*np.sqrt(n_stats)))\n",
560 | " # plot(range(1,n_tasks+1), mean_stuff, 'o-', label=\"c=%g\"%cval)\n",
561 | " errorbar(range(1,n_tasks+1), mean_stuff, yerr=std_stuff, fmt='o-', label=\"c=%g\"%cval)\n",
562 | " \n",
563 | "axhline(data['mean'][cval][0][0], linestyle=':', color='k')\n",
564 | "xlabel('Number of tasks')\n",
565 | "ylabel('Fraction correct')\n",
566 | "legend(loc='best', fontsize=8)\n",
567 | "xlim(0.5, 5.5)\n",
568 | "plt.yticks([0.6,0.8,1.0])\n",
569 | "ylim(0.6, 1.02)\n",
570 | "# grid('on')\n",
571 | "# sns.despine()\n",
572 | "simple_axis(ax)\n",
573 | "plt.subplots_adjust(left=.15, bottom=.18, right=.99, top=.97)\n",
574 | "plt.savefig(\"split_mnist_accuracy.pdf\")"
575 | ]
576 | },
577 | {
578 | "cell_type": "code",
579 | "execution_count": 20,
580 | "metadata": {
581 | "collapsed": false,
582 | "scrolled": true
583 | },
584 | "outputs": [
585 | {
586 | "name": "stdout",
587 | "output_type": "stream",
588 | "text": [
589 | "[0, 1.0]\n"
590 | ]
591 | },
592 | {
593 | "data": {
594 | "image/png": "\n",
595 | "text/plain": [
596 | ""
597 | ]
598 | },
599 | "metadata": {},
600 | "output_type": "display_data"
601 | }
602 | ],
603 | "source": [
604 | "figure(figsize=(7, 1.8))\n",
605 | "axs = [subplot(1,n_tasks+1,1)]\n",
606 | "for i in range(1,n_tasks+1):\n",
607 | " axs.append(subplot(1, n_tasks+1, i+1, sharex=axs[0], sharey=axs[0]))\n",
608 | "fmts = ['o', 's']\n",
609 | "\n",
610 | "plot_keys =sorted(data['mean'].keys())\n",
611 | "# plot_keys = [0]\n",
612 | "print(plot_keys)\n",
613 | "\n",
614 | "for i, cval in enumerate(plot_keys):\n",
615 | " label = \"c=%g\"%cval\n",
616 | " mean_vals = data['mean'][cval]\n",
617 | " std_vals = data['std'][cval]\n",
618 | " for j in range(n_tasks+1):\n",
619 | " sca(axs[j])\n",
620 | " errorbar_kwargs = dict(fmt=\"%s-\"%fmts[i], markersize=5)\n",
621 | " if j < n_tasks:\n",
622 | " # print(j,mean_vals[:, j])\n",
623 | " norm= np.sqrt(n_stats) # np.sqrt(n_stats) for SEM or 1 for STDEV\n",
624 | " axs[j].errorbar(np.arange(n_tasks)+1, mean_vals[:, j], yerr=std_vals[:, j]/norm, label=label, **errorbar_kwargs)\n",
625 | " else:\n",
626 | " mean_stuff = []\n",
627 | " std_stuff = []\n",
628 | " for k in range(len(data['mean'][cval])):\n",
629 | " mean_stuff.append(data['mean'][cval][k][:k+1].mean())\n",
630 | " #std_stuff.append(data['mean'][cval][k][:k+1].std()/np.sqrt(n_stats))\n",
631 | " std_stuff.append(np.sqrt((data['std'][cval][k][:k+1]**2).sum())/(n_stats*np.sqrt(n_stats)))\n",
632 | " # plot(range(1,n_tasks+1), mean_stuff, 'o-', label=\"c=%g\"%cval)\n",
633 | " errorbar(range(1,n_tasks+1), mean_stuff, yerr=std_stuff, label=\"c=%g\"%cval, **errorbar_kwargs)\n",
634 | " plt.xticks(np.arange(5)+1)\n",
635 | " plt.xlim((1.0,5.5))\n",
636 | " if j == 0:\n",
637 | " axs[j].set_yticks([0.5,1])\n",
638 | " else:\n",
639 | " setp(axs[j].get_yticklabels(), visible=False)\n",
640 | " plt.ylim((0.45,1.1))\n",
641 | "\n",
642 | "for i, ax in enumerate(axs):\n",
643 | " if i < n_tasks:\n",
644 | " ax.set_title((['Task %d (%d or %d)'%(j+1,task_labels[j][0], task_labels[j][1]) for j in range(n_tasks)] + ['average'])[i], fontsize=8)\n",
645 | " else:\n",
646 | " ax.set_title(\"Average\", fontsize=8)\n",
647 | " #ax.set_title((['Task %d'%(j+1) for j in xrange(n_tasks)] + ['average'])[i], fontsize=8)\n",
648 | " # ax.axhline(0.5, linestyle=':', color='k')\n",
649 | " ax.axhline(0.5, color='k', linestyle=':', label=\"chance\", zorder=0)\n",
650 | "handles, labels = axs[-1].get_legend_handles_labels()\n",
651 | "# Reorder legend so chance is last\n",
652 | "axs[-1].legend([handles[j] for j in [1,2,0]], [labels[j] for j in [1,2,0]], loc='lower right', fontsize=8, bbox_to_anchor=(-1.3, -.7), ncol=3, frameon=True)\n",
653 | "# axs[-1].legend(loc='lower right', fontsize=8, bbox_to_anchor=(-1.3, -.7), ncol=3, frameon=True)\n",
654 | " \n",
655 | "axs[0].set_xlabel(\"Tasks\")\n",
656 | "axs[0].set_ylabel(\"Accuracy\")\n",
657 | "gcf().tight_layout()\n",
658 | "sns.despine()\n",
659 | "plt.savefig(\"split_mnist_tasks.pdf\")"
660 | ]
661 | },
662 | {
663 | "cell_type": "code",
664 | "execution_count": null,
665 | "metadata": {
666 | "collapsed": true
667 | },
668 | "outputs": [],
669 | "source": []
670 | },
671 | {
672 | "cell_type": "code",
673 | "execution_count": null,
674 | "metadata": {
675 | "collapsed": true
676 | },
677 | "outputs": [],
678 | "source": []
679 | }
680 | ],
681 | "metadata": {
682 | "kernelspec": {
683 | "display_name": "Python 3",
684 | "language": "python",
685 | "name": "python3"
686 | },
687 | "language_info": {
688 | "codemirror_mode": {
689 | "name": "ipython",
690 | "version": 3
691 | },
692 | "file_extension": ".py",
693 | "mimetype": "text/x-python",
694 | "name": "python",
695 | "nbconvert_exporter": "python",
696 | "pygments_lexer": "ipython3",
697 | "version": "3.5.2"
698 | }
699 | },
700 | "nbformat": 4,
701 | "nbformat_minor": 2
702 | }
703 |
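The split MNIST notebook restricts the softmax to the two output units of the active task. Its `masked_softmax` replaces the logits of inactive units with `-1e32` (via `tf.tile` and `select`/`tf.where`), and `set_active_outputs` sets the mask entries to 1 for the current task's labels. The sketch below reproduces the same idea in plain NumPy so it can be tried without TensorFlow or Keras. It is an illustration under that assumption, not the notebook's implementation. The function name, its `active_labels` argument, and the max-subtraction step for numerical stability are additions for this example.

```python
import numpy as np


def masked_softmax(logits, active_labels, n_outputs=10):
    """Softmax over the output units listed in `active_labels` only.

    Hypothetical NumPy analogue of the notebook's TensorFlow `masked_softmax`:
    logits of inactive units are replaced by a very large negative number,
    so after exponentiation those units get (numerically) zero probability.
    """
    logits = np.asarray(logits, dtype=float)
    mask = np.zeros(n_outputs, dtype=bool)
    mask[list(active_labels)] = True
    # Broadcasting the 1-D mask over the batch axis replaces the explicit tf.tile.
    x = np.where(mask[None, :], logits, -1e32)
    # Subtract the row-wise maximum (always an active logit) before exp for stability.
    x = x - x.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)


# Task 1 of split MNIST uses labels 0 and 1: all probability mass stays on those units,
# even though unit 2 has the largest raw logit.
logits = np.array([[2.0, 1.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
probs = masked_softmax(logits, [0, 1])
```

Because the masked units contribute exactly zero probability, `np.argmax` over the full 10-way output can only pick one of the active task's labels. This is what `masked_predict` relies on when it scores each task separately.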
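The "Number of tasks" plots in the split MNIST notebook average the accuracy over the tasks trained so far (`data['mean'][cval][i][:i+1].mean()`) and combine the per-task standard deviations across the `n_stats` runs into one error bar. The sketch below makes that computation explicit. It assumes the per-task estimates are independent, so the standard error of an average over k tasks is sqrt(sum_j s_j^2) / (k * sqrt(n_runs)). The function name and argument names are hypothetical. Note that the cells above divide by `n_stats*np.sqrt(n_stats)` instead of `k*np.sqrt(n_stats)`, so the two normalizations coincide only when k equals `n_stats`.

```python
import numpy as np


def running_average_accuracy(mean_acc, std_acc, n_runs):
    """Average accuracy over tasks seen so far, with a standard error.

    mean_acc[i, j], std_acc[i, j]: mean and std (over n_runs runs) of the
    accuracy on task j after sequentially training on tasks 0..i.
    Assumes independent per-task estimates (an assumption of this sketch).
    """
    mean_acc = np.asarray(mean_acc, dtype=float)
    std_acc = np.asarray(std_acc, dtype=float)
    n_tasks = mean_acc.shape[0]
    avg = np.empty(n_tasks)
    sem = np.empty(n_tasks)
    for i in range(n_tasks):
        k = i + 1  # number of tasks trained so far
        avg[i] = mean_acc[i, :k].mean()
        # SEM of each task mean is s_j / sqrt(n_runs); averaging k of them divides by k.
        sem[i] = np.sqrt((std_acc[i, :k] ** 2).sum()) / (k * np.sqrt(n_runs))
    return avg, sem


# Two tasks, four runs: after task 2 the average is (0.8 + 0.6) / 2 = 0.7.
avg, sem = running_average_accuracy([[1.0, 0.5], [0.8, 0.6]],
                                    [[0.2, 0.0], [0.3, 0.4]], n_runs=4)
```

Only the lower triangle of each matrix is used (`[:k]`), because accuracy on task j is not meaningful before task j has been trained.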
--------------------------------------------------------------------------------
/fig_transfer_cifar/Figure barplot.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# Copyright (c) 2017 Ben Poole & Friedemann Zenke\n",
12 | "# MIT License -- see LICENSE for details\n",
13 | "# \n",
14 | "# This file is part of the code to reproduce the core results of:\n",
15 | "# Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through\n",
16 | "# Synaptic Intelligence. In Proceedings of the 34th International Conference on\n",
17 | "# Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention\n",
18 | "# Centre, Sydney, Australia: PMLR), pp. 3987–3995.\n",
19 | "# http://proceedings.mlr.press/v70/zenke17a.html"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 19,
25 | "metadata": {
26 | "collapsed": false
27 | },
28 | "outputs": [
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "Populating the interactive namespace from numpy and matplotlib\n"
34 | ]
35 | },
36 | {
37 | "name": "stderr",
38 | "output_type": "stream",
39 | "text": [
40 | "/usr/local/lib/python3.5/dist-packages/IPython/core/magics/pylab.py:160: UserWarning: pylab import has clobbered these variables: ['colors']\n",
41 | "`%matplotlib` prevents importing * from pylab and numpy\n",
42 | " \"\\n`%matplotlib` prevents importing * from pylab and numpy\"\n"
43 | ]
44 | }
45 | ],
46 | "source": [
47 | "%pylab inline"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 20,
53 | "metadata": {
54 | "collapsed": true
55 | },
56 | "outputs": [],
57 | "source": [
58 | "import sys, os\n",
59 | "sys.path.extend([os.path.expanduser('..')])\n",
60 | "from pathint import utils\n",
61 | "import seaborn as sns\n",
62 | "\n",
63 | "rcParams['pdf.fonttype'] = 42\n",
64 | "rcParams['ps.fonttype'] = 42"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 21,
70 | "metadata": {
71 | "collapsed": true
72 | },
73 | "outputs": [],
74 | "source": [
75 | "datafile_name = \"split_cifar10_data_path_int[omega_decay=sum,xi=0.001]_lr1.00e-03_ep60.pkl.gz\"\n",
76 | "# load the precomputed split CIFAR-10 results\n",
77 | "# all_evals = dict() # uncomment to delete on disk\n",
78 | "data = utils.load_zipped_pickle( datafile_name )"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 22,
84 | "metadata": {
85 | "collapsed": true
86 | },
87 | "outputs": [],
88 | "source": [
89 | "colors = (sns.color_palette(\"deep\"))\n",
90 | "colors[2] = 'lightgray'"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 23,
96 | "metadata": {
97 | "collapsed": false
98 | },
99 | "outputs": [
100 | {
101 | "data": {
102 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPIAAACrCAYAAABc6cGbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XlAVOX+P/D3sAgom5CgYCZpGIIomV66XSVRspuaSSIg\njBmiAuIKKq5RQqDigksKYqiI5Ia50L1WoNJVw3tdAzVEhUBhMEEZEBiYOb8/+HG+jMxwDtsgw+f1\nF5zlWWb4cLbnfB4BwzAMCCGdmkZHN4AQ0noUyISoAQpkQtQABTIhaoACmRA1QIFMiBqgQCZEDVAg\nE6IGKJAJUQMUyBwyMjIwZswYCIVCCIVCiMVihIeHQyqVtqi8Y8eONXufdevWtagudaXoO3nVNPU9\nZ2RkYMuWLW1an1abltaBJgWdbNX+pzdNVrruk08+weLFi9nfV61a1eJ6jh8/jqlTpzZrnzVr1rS4\nvo6UmZnZqv3t7OyUrnv5O6knk8mgoaG645Oy+lryPbeG2gSyKgmFQsTHx2PXrl3Iz89HcXExLC0t\nER4ejpKSEqxcuRIVFRUYMGAAQkND2f0OHz6M7OxsCIVCrF69GqGhoUhKSkJBQQF27NiByMhITJs2\nDdbW1vj9998RFBSE0aNHw9PTE0lJSQgJCUG3bt1w7949vP/++wgMDMTNmzcRGhoKKysrPHz4ECdO\nnOi4D6YDCYVCDBkyBMXFxfjyyy8RHByM8vJy2NjYYPXq1UhOTsb58+dRVVUFqVQKZ2dn/Pjjj+jf\nvz/Cw8PZciQSCQIDA1FZWQkTExNER0fj6tWr2LhxI7S1teHp6QlTU1PEx8cDADw9PfHrr7/i7t27\nkMlkiIqKwp07d9jv2c/PD7q6uo32z8rKgp+fH549e4a9e/eiR48ereo/nVrzcOrUKQiFQqxYsaLR\nusGDB2Pfvn0oLCxEWVkZYmNjMXfuXCQkJKBHjx64fv06u627uzusra2RkJCAQYMGKazr+fPnWLx4\nMWJjY3H48OFG60eNGoWkpCRcuHABAPDtt99i165dCAsLw+PHj9uox68+Rd+Ji4sLoqKicPjwYfzz\nn/9EYmIiKisrcfPmTQCAmZkZYmNjYWFhgZqaGiQmJqKwsBDPnj1jyygqKoKJiQkSEhKwdetWAMDm\nzZvx7bffIiEhAR999BEAoKamBrt374aTkxOCgoJw8OBBBAYG4vDhwxg7diz7Pb///vsK99fW1mb3\nv3z5cqs/Dzoi86DsNA4A3nrrLQB1fyRisRj379/Hpk2bIBAIUFFRAXt7e87yG76A1rNnT5iamgIA\nysrKlNanq6sLACgvL0fv3r0BAP379+ffqU5O0Xdia2sLAPjzzz/h5OQEoO70PC8vDwBgbW0NoO67\navi9lZWVwdjYGADQr18/WFtbIygoCHZ2dvjiiy/AMAxMTEwAgD2Nrq8LAOLi4nD58mXU1tZiwIAB\njdqqaP/6tpibm7fJNT4FcisJBAL2Z4ZhYGVlhU8++YS9vqutrVW6fXV1NQAgOztb4Xqu+gBAX18f\nIpEIhoaG7B9sV1UfJP369UNWVhbeeustZGZmws3NDQ8ePJD77F7+3upJJBLMnDkTGhoa8PHxwaRJ\nkyAQCFBaWoqePXtCJpPJ7V9aWoorV67g0KFDuHjxIk6fPt2ofEX7N9QWbxJTILcxPz8/rFmzBmKx\nGBoaGggLC0Pfvn3Z9X369MH8+fOxaNEifPDBB/D09MTQoUNbXF9AQAD8/f3Rr18/9OnTpy260OlN\nmzYNQUFBOHLkCAYNGoRhw4bhwYMHvPZ99OgRVq5cCZlMhr59+8LU1BRLliyBv7+/3DVuPSMjI3Tv\n3h0zZsyQu1yyt7dHQEAAfHx8mty/rQgosUDnVltbCy0tLbx48QI+Pj74/vvvO7pJpAPQEbmTu3bt\nGrZt24aKigrMmzevo5tDOggdkQlRA/
T4iRA1QIFMiBqgQCZEDVAg83T58mUIhUJ4eXlh3rx5KC0t\nbbOyt2/fjkuXLuHOnTs4evSo3LqCggKEhIQo3bfh4PzWvMxBOje1uWs97bB/q/Y/4r5L6bqSkhLs\n3LkTu3fvhr6+Ph4+fIiamppW1aeIjY0NbGxsmrVPw8H5rXmZg3RuahPI7enChQuYPHky9PX1AQBW\nVla4e/cuFixYAJlMBi8vL0yePFnhSw2bN2/Gf//7X2hpaSEqKgpSqRQhISGQSCRwdnbGnDlz2Hoy\nMjJw6dIlLF68GNHR0cjIyMDAgQPZ9WFhYU0Ozt+9ezfi4+NRXFzcqI7t27c3esGDqA86tebhyZMn\n6NWrl9yyrVu3IioqComJiTh48CB7hH75pYZr164hMTERCQkJMDMzw549e7BgwQJ8//33yMjIgEgk\nalRfcXExbt26hUOHDmHEiBHscq7B+fWU1fHyCx5EfdARmYdevXqhuLhYbllZWRk79LJv374oKSkB\n0PilBl9fXyxfvhzGxsZYvHgx/vzzTwwePBhA3al0QUFBo/oeP37MDveztbXFxYsXAXAPzq+nrI6X\nX/AwNDRswadBXkV0RObByckJp06dQnl5OQAgLy8POjo6KCgoQE1NDfLz89m3W15+qcHR0REbN26E\nqakpzp8/zw7oB4A7d+7IjcOuZ2Fhwb5IcefOHQDyg/MXLlzIDrRX9JKFsjqUvShAOj86IvNgYmKC\ngIAA+Pn5gWEYGBkZISgoCMHBwZBKpfDy8oK2trbCfQMCAlBVVQUAiI6OxtChQxESEoKamhqMGTMG\n5ubmjfYxMzODra0tpk+fjrfffhsAv8H59Xx9fTnrIOqFhmgSogbo1JoQNUCBTIgaaLdAXrFiBd57\n7z1MnDhR4XqGYRAWFgYXFxdMmjSJvTlDCGk+zkD+448/WlSwq6sr4uLilK5PT09Hbm4ufvrpJ6xb\nt04u2yQhpHk4A/mrr77C1KlTkZiY2KwkYSNGjICRkZHS9ampqfj0008hEAgwbNgwlJWVNXpWSwjh\nhzOQDx06hKioKBQVFcHV1RVBQUHsAIXWEIlEbPZHAOjdu7fCUU7K1NbWoqCgoFFyO0K6Il7Pkfv3\n749FixbBzs4OYWFhuH37NhiGwZIlS/Dhhx+2dxtx+PDhRjmeJRIJ7t27h9TUVIWDKgjpSjgD+e7d\nu0hOTsaFCxfw97//Hbt374atrS1EIhE8PDxaHMjm5uYoKipify8qKlI6cMHd3R3u7u5yywoKCjB2\n7NgW1U2IuuE8tQ4LC8PgwYNx8uRJfPnll2xibnNzcyxcuLDFFTs7O+OHH34AwzC4ceMGDAwMYGZm\n1uLyCOnKOI/IMTEx0NXVhaamJoC6Sauqq6uhp6eHTz/9VOl+S5YswZUrV1BaWorRo0dj/vz57PWs\np6cnnJyccOHCBbi4uEBPTw/ffPNNG3WJkK6Hc4jmtGnTEB8fz04yVVFRgVmzZnV4/uT6U2u6RiaE\nx6l1dXW13ExxPXr0QGVlZbs2ihDSPJyBrKenJzfqKjMzk33XlhDyauC8Rl65ciUWLlwIMzMzMAyD\nv/76q81nWyeEtA5nINvb2+Nf//oXHj58CKAuX5Wyd28JIR2D14CQhw8fIicnBxKJBLdv3waAJu9Y\nE0JUizOQd+zYgYyMDNy/fx9OTk5IT0/H8OHDKZAJeYVw3uw6e/Ys9u/fj9deew0RERE4efJkm8yw\nTghpO5yBrKOjAw0NDWhpaaG8vBympqYoLCxURdsIITxxnlrb2dmhrKwMbm5ucHV1Rffu3eHg4KCK\nthFCeGoykBmGwdy5c2FoaAhPT0+MGjUK5eXlbGZHLunp6QgPD4dMJoObm5vcrApAXf7m5cuXQywW\nQyqVIjg4GE5OTi3vDSFdFcNh4sSJXJsoVFtby4wdO5b5888/merqambSpEnMvXv35LZZvXo1k5iY\nyD
AMw9y7d48ZM2YM7/Lz8/MZa2trJj8/v0XtI0SdcF4jDx48GLdu3Wr2P4hbt27hjTfewOuvv45u\n3bphwoQJSE1NldtGIBCwSd/FYjG9/URIC3FeI9+8eROnT5+GhYUF9PT02OWnT59ucr+XM4CYm5s3\n+ocQGBiIWbNm4eDBg6isrER8fHxz208IAY9A3rt3b7tVnpKSgilTpsDHxwfXr1/HsmXLcObMGWho\nyJ8oKMsQ0hqZmZlNrrezs2tV+aRzmhR0ssn1pzdNVlFLmoczkBXNLcTHyxlARCJRowwgx44dYzNt\nOjg4oLq6GqWlpTA1NZXbjjKEENI0zkCeO3cu+3N1dTUKCgpgZWWFlJSUJvcbMmQIcnNzkZ+fD3Nz\nc6SkpGDTpk1y2/Tp0weXL1+Gq6sr7t+/j+rqanYyNEIIf5yB/PK1cFZWFg4dOsRdsJYW1q5dC19f\nX0ilUnz22Wd46623EB0dDTs7O4wdOxYhISFYvXo19u3bB4FAgMjIyBafARDSlTV7NkZbW1ved7Gd\nnJwaPRdumOdr4MCBHZ5phLTOtMP+Ta4/4r5LRS1RjVe1v5yB3PBOskwmw+3bt+kxESGvGM5Arqio\nYH/W1NSEk5MTxo8f366NIoQ0D2cgBwYGqqIdbY7rFGit7TwVtaRro8d8qsEZyF988QWio6NhaGgI\nAHj+/DmWLFnSrs+XO7NX9RqKqDfOQC4pKWGDGACMjIzw9OnTdm0Uab6u9g+kq/WXC2cga2pq4vHj\nx7CwsAAAPHr0iB4RtQKdaqpGR33OHVUvZyAvWrQI06dPx4gRI8AwDK5evYqvv/66XRpDSD2uoZJ6\nI1XUkE6CM5BHjx6N5ORk3Lx5E0BdelwafUXIq4UzkH/++Wc4OjpizJgxAICysjL88ssvGDduXLs3\njnQ8OjJ2DryyaLq4uLC/GxoaYseOHRTInQxdm6s3zsQCMpms0TKpVMqr8PT0dIwfPx4uLi6IjY1V\nuM2PP/6Ijz/+GBMmTEBQUBCvcgkh8ngl34uIiICXlxcAIDExkZ0juSlSqRRff/014uPjYW5ujqlT\np8LZ2RkDBw5kt8nNzUVsbCySkpLosRYhrcB5RF6zZg20tbWxaNEiLFq0CN26dcPatWs5C+aT6ufI\nkSPw8vKCkZERADR6D5kQwg/nEbl79+4IDg5udsF8Uv3k5uYCADw8PCCTyRAYGIjRo0c3Kqs9MoQQ\nok54jezas2cPcnJyUF1dzS4/cOBAqyuXSqXIy8tDQkICioqK4O3tjdOnT8uNJAMoQwhAd49J0zhP\nrYODg/Hmm2+ioKAAgYGBsLS0xJAhQzgL5pPqx9zcHM7OztDW1sbrr7+O/v37s0dpQgh/nIH87Nkz\nuLm5QUtLCyNHjkRERAR+++03zoIbpvqRSCRISUmBs7Oz3Dbjxo3DlStXANQd+XNzc/H666+3sCuE\ndF2cp9ZaWnWbmJmZ4fz58zAzM8Pz58+5C+aR6mfUqFG4ePEiPv74Y2hqamLZsmXo2bNn63tFSBfD\nGcj+/v4Qi8VYvnw51q1bh4qKCqxYsYJX4VypfgQCAVasWMG7PEKIYpyBXD8008DAAAkJCe3eIEJI\n8zU7+V5XR3ePyauI82YXIeTVR4FMiBrgPLWWSCQ4e/YsHj16hNraWnZ5Z03KR4g64nXX2sDAALa2\ntujWrZsq2kQIaSbOQBaJRJQxk5BXHGcgOzg44I8//sCgQYNU0Z5m8w3/GdrdG6ceorvHpCvhDOSr\nV6/ixIkTsLS0lDu15pronBCiOpyBvGfPnhYXnp6ejvDwcMhkMri5uWHOnDkKtzt79iwWLFiAY8eO\n8XohgxAij/Pxk6WlJcRiMc6dO4dz585BLBbD0tKSs+D6DCFxcXFISUnBmTNnkJOT02i78vJyHDhw\nAEOHDm1ZDwgh3IG8f/9+BAcH4+nTp3j69CmWLl3Ka6gmnwwhABAd
HY3Zs2dDR0enZT0ghHCfWh87\ndgxHjhxB9+7dAQCzZ8+Gu7s7hEJhk/vxyRCSlZWFoqIifPDBB03eGacMIYQ0jddYa01NTYU/t4ZM\nJkNkZCQiIiI4t6UMIYQ0jTOQXV1d4ebmxua2/uWXX/DZZ59xFsyVIaSiogLZ2dmYMWMGAODJkyfw\n9/fHrl276IYXIc3Ea1rVkSNH4urVqwCAiIgIDB48mLPghhlCzM3NkZKSgk2bNrHrDQwMkJGRwf4u\nFAqxbNkyCmJCWkBpIJeXl0NfXx/Pnj2DpaWl3J3qZ8+ewdjYuOmCeWQIIYS0DaWBHBQUhJiYGLi6\nuspNo8owDAQCgcI70C/jyhDSECUtIKTllAZyTEwMACAtLU1ljSGEtAznc+TPP/+c1zJCSMdRekSu\nrq5GZWUlSktL8fz5czAMA6Du2lkkEqmsgYQQbkoD+fvvv8f+/ftRXFwMV1dXNpD19fXh7e2tsgYS\nQrgpDeTPP/8cn3/+ORISEjhHcRFCOhbnc2ShUIjs7Gzk5OTIDYv89NNP27VhhBD+OAN5x44dyMjI\nwP379+Hk5IT09HQMHz6cApmQVwjnXeuzZ89i//79eO211xAREYGTJ09CLBarom2EEJ44A1lHRwca\nGhrQ0tJCeXk5TE1NUVhYqIq2EUJ44jy1trOzQ1lZGdzc3ODq6oru3bvDwcGBV+FcGULi4+Nx9OhR\naGpqwsTEBN988w2vpAWEEHmcgRwaGgoA8PT0xKhRo1BeXo63336bs+D6DCHx8fEwNzfH1KlT4ezs\njIEDB7Lb2NjY4Pjx49DT08OhQ4ewceNGbN26teW9IaSLUhrIWVlZSnfKysqCra1tkwU3zBACgM0Q\n0jCQHR0d2Z+HDRuGU6dO8W44IeT/KA3kyMhIAHWZODIzM9l0uH/88Qfs7OwaZex4GZ8MIQ0dO3YM\no0ePVriOMoQQ0jSlgVz/NlJgYCCSk5PZQM7OzsaOHTvatBEnT55EZmYmDh48qHA9ZQghpGmc18gP\nHz6US05vbW2N+/fvcxbMlSGk3qVLl7B7924cPHiQpqQhpIU4A3nQoEFYtWoVPvnkEwB1ien5zDrB\nlSEEAG7fvo21a9ciLi4OpqamLewCIYQzkCMiIpCUlIQDBw4AAEaMGAFPT0/ugnlkCNmwYQNevHjB\nJhvo06cPdu/e3couEdL1cAayjo4OZs6ciZkzZza7cK4MIfv27Wt2mYSQxpQG8sKFCxEdHY1JkyYp\nXE9zPxHy6lAayKtWrQIAOtUlpBNQGshmZmYAQEMmCekElAayg4ODXPbMevVZNK9du9auDSOE8Kc0\nkK9fv67KdhBCWoHX3E8A8PTpU1RXV7O/W1hYtEuDCCHNxxnIqampWL9+PYqLi2FiYoLHjx9jwIAB\nSElJUUX7CCE8cCYWiI6OxuHDh9G/f3+kpaVh3759NCk5Ia8YzkDW0tJCz549IZPJIJPJ4OjoiMzM\nTFW0jRDCE+eptaGhISoqKjBixAgEBwfDxMSEnfScC1eGEIlEgmXLliErKwvGxsbYsmUL+vbt27Ke\nENKFcR6Rv/32W+jq6mLFihUYNWoU+vXrh127dnEWXJ8hJC4uDikpKThz5gxycnLktjl69CgMDQ3x\n888/Y+bMmYiKimp5TwjpwpQekb/66itMnDgRw4cPZ5dNmTKFd8F8MoSkpaUhMDAQADB+/Hh8/fXX\n7HNqLlKpFABQU/lM4XrN0som9y8uLm5yfUFBgcLlNS9KmtyP6qV6W1KvIr1794aWFr8HS0q36t+/\nPzZs2IAnT57go48+wsSJE3lNcF6PT4YQkUiEPn361DVESwsGBgYoLS2FiYmJ3HaKMoRUVFQAAAou\nKxlCyjGJpD8u8+lG81G9VG8bSU1N5X2pyTllzKNHj5CSkoKVK1eiqqoKEydOxIQJE2BlZdVmDeai\nKENIVVUVMjMz0atXL2hqaja7
TD8/vw4ZR071Ur18NTwQcuE8bltaWmLOnDmYM2cObt++jZUrV2Ln\nzp24c+dOk/vxyRBibm6OwsJC9O7dG7W1tRCLxejZsyevhuvq6uLdd9/lta0i3bp165Aba1Qv1dse\nOG921dbWIi0tDUFBQZg9ezasrKywfft2zoIbZgiRSCRISUmBs7Oz3DbOzs44ceIEgLoZLRwdHXld\nHxNC5Ck9Il+8eBFnzpxBeno6hgwZggkTJmDdunW8Hz3xyRAydepULF26FC4uLjAyMsKWLVvarGOE\ndCVKAzkmJgaTJk1CSEgIjIyMWlQ4V4YQHR0dbNu2rUVlE0L+j2Zo/VQSL5kyZQpsbW2hq6ur4iap\njp2dHdVL9apFvQKGYRiV10oIaVOcN7sIIa8+CmRC1ADvxAKvKrFYjICAAAB1Ce8HDx6Mvn37IiIi\nosn96qdzdXV1bbROIpHA29sb9+7dw+nTpxU+F2yPevPy8rBixQoIBAJYWFggMjKy0WCX9qj3r7/+\nQmBgILS0tGBoaIgtW7ZAR0en3eut9+OPPyIqKgppaY2HTbVHvbW1tXB0dISNjQ2AuvcJDAwMVNLf\n9PR0xMXFgWEYrFy5km1DqzFqxMPDg/e2R44cYY4fP65wnVQqZf766y8mKCiIyc/PV1m9paWljFgs\nZhiGYTZs2MCcP39eJfXW1tYyUqmUYRiG2bJlC/PTTz+ppF6GYRiZTMYsWbKEcXd35yyrreqtqalh\nvL29eZfVVvVWVFQw8+fPZ2pra3mXx5danlqfP38eQqEQrq6u7FStBw4cwLRp0yAUCnH37l1228LC\nQsyZMwdPnjxhl2loaLRoCpvW1mtsbAx9fX0Adc/h+Q49bW29mpqa0NCo+1NgGAb9+vVTSb1A3Ysz\no0aNatZAoLaoNzs7G9OnT2/W2IXW1nvt2jVoaGjA19cXy5cvR2Vl0y9gNEub/2voQPX/OV+8eMEw\nDMNIJBJ22cyZM5mqqiqGYeqOAkeOHGF27tzJzJ49mxGJRArLa+4Rua3qLSwsZDw8PDj/c7dlvdeu\nXWOmTJnCTJ8+nT0rUEW9CxYsYGpqangd9dqy3mfPnjEymYxZsWIF7zOf1tZ74sQJxtvbm6mtrWUO\nHDjA7N+/n7PPfHX6a2RFfv/9d+zcuRNSqRQPHjwAUDc97Nq1a6Gjo4NFixYBAA4dOoTg4GA2h/er\nUG9VVRVCQkIQFhbG+4jcFvU6ODggOTkZsbGxOHHiBIRCYbvX+5///AcjRozg/apeW/a3fpDTuHHj\nkJ2d3WjgUnvUa2BggOHDh0NTUxOOjo5KpxFuCbU8tY6NjcX69evx3Xffsaeqtra2WL9+Pd555x38\n8MMPAOq+hH//+99NTsCu6npXr16NGTNmYMCAASqrt+Gk8fr6+rwHAbW23pycHPz888+YNWsWcnJy\neI/ya229L168gEwmA1B3usv3UqK19drb27NTEt+5c6dNX65QyyPyhx9+iLlz58LGxgaGhoYA6gKk\nsLAQEokEkZGRuHbtGrp164aoqCgsWLAAq1evxptvvsmWMX/+fFy/fh1Lly7FnDlzMGbMmHav93//\n+x/S0tIgEokQHx+PmTNn8prMvbX1ZmVlISoqChoaGujZsyc2bNigks+54eSAnp6eWLBggUrqffDg\nAdasWQM9PT288cYbcHFxUUm9vXr1wtChQ+Ht7Q1dXV1s3ryZV7180MguQtSAWp5aE9LVUCATogYo\nkAlRA2oVyP369cM//vEPldQ1depUhYNG7OzsOuz1OaK+bty4wQ7aUUSt7lpXVlbi+fPnKqmrpKRE\nblK7euXl5Sqpn3QtIpEITd2XVqsjMiFdlVodkYG6kVFtNcCjKWKxWOk6iUSikjaQrqN+JJkyavUc\n2cbGRm7genuzsLDAo0eP5Jb97W9/w5UrV1TWBtJ1aGhosDOsvEytArmlBAIBiouL0atXr45uSp
fi\n4+ODpKSktn0LqIuiQEbdPFX29vYd3Ywup7KyEvn5+bC2tu7opnR6FMiEqIFOe7NLT08PVVVVHd0M\nhXR1dZt9uvgq94e0j5b8nSjTaY/IAoGgyedqHaklbXuV+0PaR1t+5/QcmRA1oDaBfPnyZQiFQnh5\neWHevHnw8/NDXl4ekpOTMX78eAiFQvj6+rLb79y5U+73htsFBASwL9uHhobC0dERR48elavL3d0d\nQqFQbsbJ9upLaWkpQkJCVNofvn309/fHu+++i0uXLrHLTp06BQ8PD8ydO5fXSDdPT0/+H04HO3bs\nWLO2r//e2l2bJQ1SsYZNf/r0KePl5cXmmnrw4AHj4+PD5ObmMsePH2eOHDnSaH9fX18mMDCQKSsr\nYxiGkdtu586dTFpaGsMwDCMSiRqV4e3tzYjFYubGjRtMaGhok21rbn8U9UUkEjHLly9XaX+4+lhP\nJBIx27ZtYy5evMgwTF0+K09PT6ampoZJSUlh9uzZw9n35mSpVLX67KL1mtvW+u9NkbYMP7U4Il+4\ncAGTJ09m069YWVk1+Uw4Pz8fffv2xbhx43D+/PlG6xuO2no571JlZSV0dXWhr6+PoUOHIicnp206\n8f8p6gtXTrG27k9z+vhyeXl5ebC2toaWlhbee+893LhxQ279kydP4OvrC6FQiE2bNsmti4mJgbe3\nN9zc3HD79m0AwPLly+Ht7Q2hUAiZTIbNmzfD09MTQqEQIpEIJSUl8PPzg1AoRP00ZomJiWxmy6ys\nLLk6pk2bhpCQELi6uuLcuXMAgJs3b0IoFMLDwwPHjx8HAAiFQmzYsAHLli1j901NTUV2djaEQiEu\nXrzIq731fvvtNwQFBaGmpkbpZ9kanfaudUNPnjxp8llkXFwcTp06hWHDhiEoKAi//PILxo8fDzs7\nO3z11VeYNGkSu92BAwdgaGiIJUuWKCyrrKyMDTIASkfatFdf6tvZnv1pTR8b7mtgYICysjK59TEx\nMZg5cyb+8Y9/yP2hA8CMGTMwd+5c5OXlYdu2bYiMjERRUREOHjwIhmEgEAhw7do1JCYmQkNDAwzD\nYP369Zg7dy4cHBywceNGXL9+HampqThw4AB0dXUb3UwqKSnB1q1bYWxsDB8fH4wZMwbbtm3Drl27\n0KNHD3zxxRfs5+fi4gIHBwd237Fjx8La2hoJCQkAgHfeeYezvQBw5coV/Pbbb4iMjIS2tjbvz7I5\n1CKQe/XqheLiYqXrfX194ebmxv5+4cIF/PrrrxAIBMjLy2PfYvL19YWrqyvmzZuH58+f47XXXmtU\nloGBgdx2b/n4AAADDklEQVR1H99Ml3xx9aW+ne3Zn9b0seG+5eXlbG6rerm5uWxwvPxa3smTJ3H6\n9Gl2uba2NqZMmYLg4GBYWlpi4cKFbE5oY2NjLF68GPfv38emTZsgEAhQUVEBe3t7zJ8/H6GhodDW\n1sbChQvl+m1sbAwLCwu5ft29exf+/v4AgNLSUpSWlgKoS6zXFD7tBeruX+zbt6/dghhQk5tdTk5O\nOHXqFPsHlJeX1ygheb0nT56gd+/e+O6777B3717MmjULFy9eZNdramrCy8sL+/btU7h/9+7dUVVV\nhYqKCty6datZ2S5b2pemArs9+qOsjyKRiLP9/fv3x7179yCVSnHp0iUMHTpUbr2VlRVu3rwJAI2O\nyIcOHUJCQgLWrVsHoO5MYMKECYiKikJJSQl+//13ODo6YuPGjTA1NcX58+dhZWWFkJAQJCQkIDk5\nGWPHjoWNjQ0iIyMxcuRIJCcny9Xx/PlzFBUVobKykj3TsLGxQUxMDBISEnDixAmYm5sDaPyPBoBc\nIn0+7QWAiIgIhIaGoqSkhPPzaym1OCKbmJggICAAfn5+YBgGRkZGSv/7paamYvjw4ezvI0eORFxc\nHEaMGMEue//997F161ZIJBLs3bsXZ86cAcMwEIlECAwMhL
+/P3x8fNCtWzesX7++3fsSHh6udPv2\n6o+iZcuWLUN8fLzcH3hYWBjOnTuHtLQ0eHh4wN3dHW5ubvDy8oKhoWGj6+A5c+YgJCQEu3btgoOD\ng9wpv729Pby8vNi2V1RUwN/fH1KpFPr6+rC2tkZAQAA7cCY6OhqOjo5Ys2YNxGIxNDQ0EBYWhu3b\nt6OgoAASiaTRXE09e/bE9u3bcefOHcybNw8AsGDBAvbzNjY2xvbt25V+3vb29ggICICPjw+v9gJ1\nL9esWrUKS5cuxbZt29CjRw+l5bcUDQhpB+o4IEQmkyE8PBxr1qzp6Ka0iqenJ5KSkjq6GQBoQAjp\nABoaGp0+iNUZHZHbgToekUnboyMyIUROp73Zpaur26ypOFWJ79xJL+/zqvaHtI+W/J0o8/8ADAEZ\nnqZHJdcAAAAASUVORK5CYII=\n",
103 | "text/plain": [
104 | ""
105 | ]
106 | },
107 | "metadata": {},
108 | "output_type": "display_data"
109 | }
110 | ],
111 | "source": [
112 | "sns.set_style('ticks')\n",
113 | "figure(figsize=(3.3, 2.2))\n",
114 | "ax = axes()\n",
115 | "cval = 0.01 # orig value\n",
116 | "cvals = [0, cval, 'scratch']\n",
117 | "# cvals = [cval]\n",
118 | "# cvals = [0, cval]\n",
119 | "n_tasks = 6\n",
120 | "group_width = 0.8\n",
121 | "bar_width = group_width/len(cvals)\n",
122 | "bar_width = group_width/3\n",
123 | "index = np.arange(n_tasks)\n",
124 | "xtick_labels = ['Task %i'%(i+1) for i in range(n_tasks)]\n",
125 | "# xtick_labels[0] = 'CIFAR10'\n",
126 | "\n",
127 | "def do_plot(eval_type=0, age=-1):\n",
128 | "    for k,cv in enumerate(cvals):\n",
129 | "        means = []\n",
130 | "        stdevs = []\n",
131 | "        # print(cv)\n",
132 | "        for tid in range(n_tasks):\n",
133 | "            if cv=='scratch':\n",
134 | "                a = tid\n",
135 | "            else:\n",
136 | "                a = age\n",
137 | "            means.append( data['mean'][cv][a, tid, eval_type] )\n",
138 | "            stdevs.append( data['std'][cv][a, tid, eval_type] )\n",
139 | "        # print(means)\n",
140 | "        \n",
141 | "        bar(index+k*bar_width, means, width=bar_width, yerr=stdevs, color=colors[k], ecolor='gray', align='edge')\n",
142 | "    xticks(index)\n",
143 | "    # gca().set_xticklabels()\n",
144 | "    if eval_type==0:\n",
145 | "        ylabel('Validation accuracy')\n",
146 | "    else:\n",
147 | "        ylabel('Training accuracy')\n",
148 | "    xticks(index+group_width/2, xtick_labels, fontsize=8)\n",
149 | "    xlim(-0.1, 6.0)\n",
150 | "    # ylim(0.5, 1.0)\n",
151 | "    yticks(np.arange(0.0, 1.1, 0.2))\n",
152 | "    legend(('Fine tuning', 'Consolidation', 'From scratch'), bbox_to_anchor=(0., 1.02, 1., .102), loc=3,\n",
153 | "           ncol=2, mode='expand', borderaxespad=0., fontsize=8)\n",
154 | "    \n",
155 | "    ax.annotate('CIFAR10', xy=(0.27, 0.12), xytext=(0.27, 0.02), xycoords='figure fraction', \n",
156 | "                fontsize=8, ha='center', va='bottom',\n",
157 | "                bbox=dict(boxstyle='square', fc='white'),\n",
158 | "                arrowprops=dict(arrowstyle='-[, widthB=1.7, lengthB=0.5', lw=1.0))\n",
159 | "\n",
160 | "    ax.annotate('CIFAR100, 10 classes per task', xy=(0.66, 0.12), xytext=(0.66, 0.02), xycoords='figure fraction', \n",
161 | "                fontsize=8, ha='center', va='bottom',\n",
162 | "                bbox=dict(boxstyle='square', fc='white'),\n",
163 | "                arrowprops=dict(arrowstyle='-[, widthB=9.4, lengthB=0.5', lw=1.0))\n",
164 | "\n",
165 | "do_plot(eval_type=0)\n",
166 | "sns.despine()\n",
167 | "subplots_adjust(left=.21, bottom=.25, right=.98, top=.82)\n",
168 | "savefig(\"cifar10_cifar100_transfer_valid.pdf\")"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 24,
174 | "metadata": {
175 | "collapsed": false
176 | },
177 | "outputs": [
178 | {
179 | "data": {
180 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPIAAACrCAYAAABc6cGbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XlYVGX7B/DvsJvshIihYhqKGGhmL72aJka+ZmqiCAhj\nhqSAiAsuYFqkkCi4YYYhhgmIG2gora+KGiqW+4KXCkKiMNALySLrzPn9wY9zMTIz5zAwgMP9ua6u\nZM45zzIz95wzZ57nfgQMwzAghLzQNDq7AYSQtqNAJkQNUCATogYokAlRAxTIhKgBCmRC1AAFMiFq\ngAKZEDVAgUyIGqBA5pCVlYXx48dDKBRCKBSioqIC4eHhEIvFSpV35MiRVh+zfv16pepSV7Jek65G\n0euclZWFrVu3tmt9Wu1aWieaEvRDm44/vnma3G1Tp07F0qVL2b8/++wzpetJSUnBzJkzW3XM2rVr\nla6vM926datNxw8bNkzutudfkyYSiQQaGh13fpJXnzKvc1uoTSB3JKFQiPj4eMTExODRo0coLi7G\nK6+8gvDwcJSWlmL16tWoqqrCwIEDERoayh538OBB3Lt3D0KhEGvWrEFoaCiSk5NRUFCAr7/+GhER\nEZg1axZsbGxw8+ZNBAUFYezYsfDw8EBycjKCg4Oho6OD+/fvY/To0QgICMD169cRGhqKAQMG4OHD\nhzh69GjnPTGdSCgU4vXXX0dxcTG++OILLF++HJWVlbC1tcWaNWuQmpqKjIwM1NTUQCwWw8nJCT/+\n+COsra0RHh7OllNXV4eAgABUV1fD1NQU27dvx+XLlxEZGQltbW14eHjAzMwM8fHxAAAPDw+cO3cO\nd+/ehUQiQVRUFLKzs9nX2dfXF3p6ei2Ov337Nnx9ffHPP/9gz5496NmzZ5v6T5fWPKSlpUEoFCIk\nJKTFtqFDh2Lv3r0oLCxEeXk5YmNjsWDBAiQkJKBnz564evUqu6+bmxtsbGyQkJCAwYMHy6zr6dOn\nWLp0KWJjY3Hw4MEW29955x0kJyfjzJkzAIBvvvkGMTExCAsLw5MnT9qpx12frNfE2dkZUVFROHjw\nICZNmoSkpCRUV1fj+vXrAIBevXohNjYWffr0QX19PZKSklBYWIh//vmHLaOoqAimpqZISEjAtm3b\nAABbtmzBN998g4SEBPznP/8BANTX12PXrl0YN24cgoKCkJiYiICAABw8eBATJkxgX+fRo0fLPF5b\nW5s9/sKFC21+PuiMzIO8yzgAeO211wA0vkkqKiqQk5ODzZs3QyAQoKqqCvb29pzlN5+AZmJiAjMz\nMwBAeXm53Pr09PQAAJWVlejduzcAwNramn+nXnCyXhM7OzsAwF9//YVx48YBaLw8z8/PBwDY2NgA\naHytmr9u5eXlMDY2BgD069cPNjY2CAoKwrBhw/DJJ5+AYRiYmpoCAHsZ3VQXAMTFxeHChQtoaGjA\nwIEDW7RV1vFNbbGwsGiX7/gUyG0kEAjYfzMMgwEDBmDq1Kns97uGhga5+9fW1gIA7t27J3M7V30A\noK+vD5FIBENDQ/YN2101BUm/fv1w+/ZtvPbaa7h16xZcXV2Rm5sr9dw9/7o1qaurw9y5c6GhoQFv\nb29MmTIFAoEAZWVlMDExgUQikTq+rKwMly5dwv79+5GZmYnjx4+3KF/W8c21x0xiCuR25uvri7Vr\n16KiogIaGhoICwuDlZUVu93S0hKLFi3CkiVL8O6778LDwwMODg5K1+fv7w8/Pz/069cPlpaW7dGF\nF96sWbMQFBSEQ4cOYfDgwRg+fDhyc3N5Hfv48WOsXr0aEokEVlZWMDMzw7Jly+Dn5yf1HbeJkZER\nXnrpJcyZM0fq65K9vT38/f3h7e2t8Pj2IqDEAi+2hoYGaGlp4dmzZ/D29saBAwc6u0mkE9AZ+QV3\n5coVREdHo6qqCgsXLuzs5pBOQmdkQt
QA/fxEiBqgQCZEDVAgE6IGKJB5unDhAoRCITw9PbFw4UKU\nlZW1W9k7duzA+fPnkZ2djcOHD0ttKygoQHBwsNxjmw/Ob8tkDvJiU5u71rMO+rXp+ENuMXK3lZaW\nYufOndi1axf09fXx8OFD1NfXt6k+WWxtbWFra9uqY5oPzm/LZA7yYlObQFalM2fOYNq0adDX1wcA\nDBgwAHfv3kVgYCAkEgk8PT0xbdo0mZMatmzZgj/++ANaWlqIioqCWCxGcHAw6urq4OTkhPnz57P1\nZGVl4fz581i6dCm2b9+OrKwsDBo0iN0eFhamcHD+rl27EB8fj+Li4hZ17Nixo8UED6I+6NKah5KS\nEpibm0s9tm3bNkRFRSEpKQmJiYnsGfr5SQ1XrlxBUlISEhIS0KtXL+zevRuBgYE4cOAAsrKyIBKJ\nWtRXXFyMGzduYP/+/Rg1ahT7ONfg/Cby6nh+ggdRH3RG5sHc3BzFxcVSj5WXl7NDL62srFBaWgqg\n5aQGHx8frFq1CsbGxli6dCn++usvDB06FEDjpXRBQUGL+p48ecIO97Ozs0NmZiYA7sH5TeTV8fwE\nD0NDQyWeDdIV0RmZh3HjxiEtLQ2VlZUAgPz8fOjq6qKgoAD19fV49OgRO7vl+UkNjo6OiIyMhJmZ\nGTIyMtgB/QCQnZ0tNQ67SZ8+fdiJFNnZ2QCkB+cvXryYHWgva5KFvDrkTRQgLz46I/NgamoKf39/\n+Pr6gmEYGBkZISgoCMuXL4dYLIanpye0tbVlHuvv74+amhoAwPbt2+Hg4IDg4GDU19dj/PjxsLCw\naHFMr169YGdnh9mzZ2PIkCEA+A3Ob+Lj48NZB1EvNESTEDVAl9aEqAEKZELUAAUyIWqAApkQNUCB\nTIgaeGEDuaGhAQUFBS2S2xHSHb2wgVxUVIQJEyagqKios5tCSKfjDGRZ6Tv5CAkJwdtvv40PP/xQ\n5naGYRAWFgZnZ2dMmTKFHYlECGk9zkB+//33sXnzZjx8+LBVBbu4uCAuLk7u9rNnzyIvLw+//vor\n1q9fL7W0CiGkdTgDOTU1FZaWllixYgU8PDxw5MgRVFVVcRY8atQoGBkZyd1+8uRJfPTRRxAIBBg+\nfDjKy8tbTEwghPDDOdba0NAQs2fPxuzZs5GVlYXly5cjPDwckyZNgp+fH/r27atUxSKRiF3qBAB6\n9+4NkUiEXr16tdj34MGDLdZBqqurU6rezsa1QqGiFQipXtXjWtVT0aqdnYkzkCUSCc6dO4eUlBTk\n5eVhzpw5mDp1Kv7880/4+Pjgl19+UXkj3dzc4ObmJvVYQUEBJkyYoHSZL9objLROd/vg4gzk999/\nHyNHjoRQKJSa5D558mT88ccfSldsYWEhdce5qKiIZukQoiTOQD527Bib4uZ5bblB5eTkhMTEREye\nPBnXr1+HgYGBzMtqQgg3zptdX331lVRamKdPn2LNmjWcBS9btgzu7u54+PAhxo4di8OHDyM5ORnJ\nyckAGifr9+3bF87Ozli7di2++OKLNnSDkO6N84x8584dqZQwRkZGvH7z3bJli8LtAoGAgpeQdsLr\nZldFRQUMDAwANJ6RaVhk18OVDvhzO1rgrT101eeZM5A//vhjuLm54YMPPgDDMPjpp5/g4+PTEW0j\nL4Cu+sbubjgDecaMGbCzs8PFixcBNF4yN+WRIi3RG5t0Bl7J94YMGYLevXujtrYWQONgDvqpiJCu\ngzOQMzIysGHDBhQVFcHExAQikQj9+/fHzz//3BHtI93UizrCqrNwBvLWrVuRnJwMb29vHDt2DJmZ\nmfjpp586om2EyEVfYaRxBrKmpiZMTU0hkUjAMAxGjx6NjRs3dkTb2oReaNKdcAaygYEBqqqqMHLk\nSKxcuRJmZmbsciiEkK6Bc2TXzp07oaenh9WrV+Ott96ChYUFdu3axavws2fPYuLEiXB2dkZsbGyL\n7U
+ePIFQKMRHH32EKVOmsAufEUJaR+EZWSwWY9GiRYiPj4empiZcXV15FywWi7Fu3TrEx8fDwsIC\nM2fOhJOTk9QyoTExMZg0aRJmz56NBw8eYP78+Th16pTyvSGkm1J4RtbU1IRYLGYXL2uNGzduoH//\n/ujbty90dHQwefJknDx5UmofgUDAll1RUUGTJghREq/vyFOnTsWYMWPQo0cP9vGQkBCFxz2fOMDC\nwgI3btyQ2icgIADz5s1DYmIiqqurER8f39r2ExXj+hmox1sd1BCiEGcgv/vuu3j33XdVUnl6ejqm\nT58Ob29vXL16FStXrsSJEyegoSF9oaBOGUIIUQXOQG7N9+Lmnk8cIGs02JEjR9gEfSNGjEBtbS3K\nyspgZmYmtZ8qMoQQok54ZQiRtZg2V4qf119/HXl5eXj06BEsLCyQnp6OzZs3S+1jaWmJCxcuwMXF\nBTk5OaitrWUXDCeE8McZyPv372f/XVtbi59//hkVFRXcBWtp4fPPP4ePjw/EYjFmzJiB1157Ddu3\nb8ewYcMwYcIEBAcHY82aNdi7dy8EAgEiIiJkfmgQQhTjDOSXX35Z6u958+bBxcUFS5Ys4Sx83Lhx\nGDdunNRjixcvZv89aNAgHDhwgG9bCSFycAby3bt32X8zDINbt26hvr5epY0ihLQOZyCvW7eO/bem\npiasrKywdetWlTaKENI6rfqOTAjpmjjHWm/btq1FFs3o6GiVNooQ0jq8Egs0v7FlZGSE06dPIzAw\nUKUNI9JohBVRhPOMLBaLpUZR1dbW0s0uQroYzjPy5MmT4e3tjRkzZgAAUlJS5K55TAjpHJyB7Ovr\ni8GDB+P8+fMAAB8fH5WNvSaEKIczkJ88eYJ///vfGD9+PACgpqYGhYWFsLS05Cz87NmzCA8Ph0Qi\ngaurK+bPn99inx9//BFff/01BAIBhgwZ0mIYJyGEG+d35ICAAKlhkxoaGli0aBFnwU2JBeLi4pCe\nno4TJ07gwYMHUvvk5eUhNjYWycnJSE9Px+rVq5XoAiGE180uHR0d9m8dHR1eUwj5JBY4dOgQPD09\nYWRkBAAtZj0RQvjhDGRjY2OpXFqnT59mA08RWYkFRCKR1D55eXl4+PAh3N3dMWvWLJw9e7Y1bSeE\n/D/O78hffvklli1bhi+//BIAYGpqisjIyHapXCwWIz8/HwkJCSgqKoKXlxeOHz8utfoj0LUSC9Dv\nuaQr4gxka2trpKamsqO7ng8yefgkFrCwsICDgwO0tbXRt29fWFtbIy8vD/b29lL7UWIBQhTjtfbT\nuXPn8ODBA3btJ6DxZylF+CQWeO+995Ceno4ZM2agtLQUeXl56Nu3rxLdIKR74wzk0NBQVFRU4I8/\n/oCLiwt+/fVXODg4cBfMI7HAO++8g8zMTHzwwQfQ1NTEypUrYWJi0i4dI6Q74Qzky5cv4/jx45g6\ndSqWLFkCHx8fLFiwgFfhXIkFBAIBQkJCODNyEkIU47xr3bQ8jK6uLkpKSqCrq4vi4mKVN4wQwh/n\nGXns2LEoLy+Ht7c3PvroI2hoaGD69Okd0TZCCE+cgdw0imvSpEkYP348ampqYGxsrPKGEUL443XX\nuomenh6txEhIF8T5HZkQ0vVRIBOiBlqVDreJgYEBLC0tW6zR1Bl8wn+D9kstV6egoZKkO+EM5M8+\n+wx3797FoEGDwDAMcnNzMXDgQFRVVWH9+vV4++23O6KdhBAFOE+pr7zyClJSUvDDDz8gLS0NKSkp\nsLa2xp49exAREaHw2LNnz2LixIlwdnZGbGys3P1++eUXDB48GDdv3mx9Dwgh3IGcm5uLIUOGsH8P\nHjwYOTk56N+/v8Lj+CQWAIDKykrs27eP17BPQohsnIH86quvYv369bh8+TIuX76MsLAwvPrqq6ir\nq4Ompqbc4/gkFgCA7du349NPP4Wurm7bekJIN8YZyBs3boSFhQV2
796N3bt3o1evXoiIiICmpia+\n//57ucfxSSxw+/ZtFBUVUTI/QtqI82ZXjx49ZCbNAxrvXitLIpEgIiICGzZs4Ny3KyUWIKQr4gzk\na9eu4euvv8aTJ08gFovZx7kWOudKLFBVVYV79+5hzpw5AICSkhL4+fkhJiYGr7/+ulRZlFiAEMU4\nAzkkJAQrVqyAnZ2dwu/Ez+NKLGBgYICsrCz2b6FQiJUrV7YIYkIIN85A1tfXh5OTU+sL5pFYgBDS\nPjgD2dHREZs3b4azs7NUWtzmP0nJw5VYoLmEhATO8gghsvHKENL8/0BjZo+kpCTVtYoQ0iq00Dkh\nakBuIJ84cQIffvgh9u3bJ3N7091mQkjnkxvIT58+BQCUlpZ2WGMIIcqRG8ienp4AgCVLlnRYYwgh\nyuH8jlxaWoqUlBQ8fvxYakDI+vXrVdowQgh/nIHs7++P4cOHY+TIka0aEEII6TicgVxdXY3g4OCO\naAshREmcs5/Gjh2L33//XanCuRILxMfH44MPPsCUKVPw8ccf4/Hjx0rVQ0h3x3lGPnDgAHbv3o0e\nPXpAW1sbDMNAIBDg0qVLCo9rSiwQHx8PCwsLzJw5E05OThg0aBC7j62tLVJSUtCjRw/s378fkZGR\n2LZtW9t7RUg3wxnIFy9eVKrg5okFALCJBZoHsqOjI/vv4cOHIy0tTam6COnu5AZyXl4erK2tcf/+\nfZnbucZay0oscOPGDbn7HzlyBGPHjuVqLyFEBrmBHBsbi6+++grr1q1rsa29x1r/8MMPuHXrFhIT\nE2Vup8QChCgmN5C/+uorAMqPteZKLNDk/Pnz2LVrFxITE6VmVzVHiQUIUYzX2k85OTnIyclBbW0t\n+9iUKVMUHsOVWAAA7ty5g88//xxxcXEwMzNTovmEEIBHIH/zzTfIzMxEbm4uxowZg99//x0jR47k\nDGQ+iQU2bdqEZ8+esXOULS0tsWvXrvbpGSHdCGcg//TTTzh27BimT5+OyMhIFBcXIyQkhFfhXIkF\n9u7d27rWEkJk4hwQoqurC01NTWhpaaGyshLm5uY0cIOQLobzjDx06FCUl5djxowZmDFjBvT19WFv\nb98RbSOE8KQwkBmGQUBAAAwNDeHp6YkxY8agsrISdnZ2HdU+QggPCi+tBQIBvL292b/79+9PQUxI\nF8T5HXnIkCG4c+dOR7SFEKIkuZfWDQ0N0NLSQnZ2NmbOnIm+ffvipZdeYidNHD16tCPbSQhRQG4g\nu7q64ujRo4iJienI9hBClCA3kBmGAQD069evwxpDCFGO3EAuLS1FfHy83AM/+eQTlTSIENJ6cm92\nSSQSVFVVyf2PD64MIXV1dViyZAmcnZ3h6uqKgoIC5XtCSDcm94xsbm6OgIAApQvmkyHk8OHDMDQ0\nxG+//Yb09HRERUVRhhBClMD5HVlZfDKEnDp1iv2wmDhxItatW8feFefSlJq3vvofmds1y6oVHl9c\nXKxwu7yrg/pnihP2U71UrzL1ytK7d29oafGaoCg/kNs6oYFPhhCRSARLS8vGhmhpwcDAAGVlZTA1\nNZXaT1ZigabL+4ILcmZLnVLcPj9c4NON1qN6qd52cvLkSVhZWfHaV24gGxsbt1uD2kpWYoGamhrc\nunUL5ubmSuXb9vX17ZQpk1Qv1ctX8xMhF37nbSXwyRBiYWGBwsJC9O7dGw0NDaioqICJiQmv8vX0\n9PDmm28q3T4dHR3en3btieqlelWBc4imsppnCKmrq0N6ejqcnJyk9nFycmJHiP3yyy9wdHTk9f2Y\nECJNZWdkPhlCZs6ciRUrVsDZ2RlGRkbYunWrqppDiFpTWSAD3BlCdHV1ER0drcomENItaIaGhoZ2\ndiM6y7Bhw6heqlct6hUwbf3BmBDS6VR2s4sQ0nEokAlRAyq92dURKioq4O/vD6Ax4f3QoUNhZWWF\nDRs2KDzu8OHD0NTUhIuLS4tt
dXV18PLywv3793H8+HGZvwuqot78/HyEhIRAIBCgT58+iIiIaDHY\nRRX1/v333wgICICWlhYMDQ2xdetW6OrqqrzeJj/++COioqJw6lTLYVOqqLehoQGOjo6wtbUF0Ji7\n3cDAoEP6e/bsWcTFxYFhGKxevZptQ5sxasTd3Z33vocOHWJSUlJkbhOLxczff//NBAUFMY8ePeqw\nesvKypiKigqGYRhm06ZNTEZGRofU29DQwIjFYoZhGGbr1q3Mr7/+2iH1MgzDSCQSZtmyZYybmxtn\nWe1Vb319PePl5cW7rPaqt6qqilm0aBHT0NDAuzy+1PLSOiMjA0KhEC4uLuxSrfv27cOsWbMgFApx\n9+5ddt/CwkLMnz8fJSUl7GMaGhpKLWHT1nqNjY2hr68PoPF3eL5DT9tar6amJjQ0Gt8KDMPwTibR\n1nqBxokz77zzTqsGArVHvffu3cPs2bNbNXahrfVeuXIFGhoa8PHxwapVq1BdrXgCRqu0+0dDJ2r6\n5Hz27BnDMAxTV1fHPjZ37lympqaGYZjGs8ChQ4eYnTt3Mp9++ikjEolkltfaM3J71VtYWMi4u7tz\nfnK3Z71Xrlxhpk+fzsyePZu9KuiIegMDA5n6+npeZ732rPeff/5hJBIJExISwvvKp631Hj16lPHy\n8mIaGhqYffv2Md9//z1nn/l64b8jy3Lz5k3s3LkTYrEYubm5AICAgAB8/vnn0NXVxZIlSwA0rjS5\nfPly9OrVq8vUW1NTg+DgYISFhfE+I7dHvSNGjEBqaipiY2Nx9OhRCIVCldf7+++/Y9SoUbyn6rVn\nf42MjAAA7733Hu7du9di4JIq6jUwMMDIkSOhqakJR0dHucsIK0MtL61jY2OxceNGfPfdd+ylqp2d\nHTZu3Ig33ngDx44dA9D4Ivz8888KF2Dv6HrXrFmDOXPmYODAgR1Wb/O1pvX19aGnp9ch9T548AC/\n/fYb5s2bhwcPHvAe5dfWep89ewaJRAKg8XKX71eJttZrb2+PnJwcAEB2dna7Tq5QyzPy+++/jwUL\nFsDW1haGhoYAGgOksLAQdXV1iIiIwJUrV6Cjo4OoqCgEBgZizZo1ePXVV9kyFi1ahKtXr2LFihWY\nP38+xo8fr/J6//zzT5w6dQoikQjx8fGYO3curzWg21rv7du3ERUVBQ0NDZiYmGDTpk0d8jzPnTsX\nc+fOBQB4eHggMDCwQ+rNzc3F2rVr0aNHD/Tv3x/Ozs4dUq+5uTkcHBzg5eUFPT09bNmyhVe9fNDI\nLkLUgFpeWhPS3VAgE6IGKJAJUQNqFcj9+vXDmDFjOqSumTNnyhw0MmzYsE6bPkfU17Vr19hBO7Ko\n1V3r6upqPH36tEPqKi0tRW1tbYvHKysrO6R+0r2IRCKFKarV6oxMSHelVmdkoHFkVHsN8FCkoqJC\n7ra6uroOaQPpPppGksmjVr8j29raSg1cV7U+ffrg8ePHUo/961//wqVLlzqsDaT70NDQYFdYeZ5a\nBbKyBAIBiouLYW5u3tlN6Va8vb2RnJzcvrOAuikKZDSuU2Vvb9/Zzeh2qqur8ejRI9jY2HR2U154\nFMiEqIEX9mZXjx49UFNT09nNkElPT6/Vl4tduT9ENZR5n8jzwp6RBQJBm5d+VRVl2taV+0NUoz1f\nc/odmRA1oDaBfOHCBQiFQnh6emLhwoXw9fVFfn4+UlNTMXHiRAiFQvj4+LD779y5U+rv5vv5+/uz\nk+1DQ0Ph6OiIw4cPS9Xl5uYGoVAoteKkqvpSVlaG4ODgDu0P3z76+fnhzTffxPnz59nH0tLS4O7u\njgULFvAa6ebh4cH/yelkR44cadX+Ta+byrVb0qAO1rzp//vf/xhPT08211Rubi7j7e3N5OXlMSkp\nKcyhQ4daHO/j48MEBAQw5eXlDMMwUvvt3LmTOXXqFMMwDCMSiVqU4eXlxVRUVDDXrl1jQkNDFb
at\ntf2R1ReRSMSsWrWqQ/vD1ccmIpGIiY6OZjIzMxmGacxn5eHhwdTX1zPp6enM7t27OfvemiyVHa0p\nu2iT1ra16XWTpT3DTy3OyGfOnMG0adPY9CsDBgxQ+Jvwo0ePYGVlhffeew8ZGRkttjcftfV83qXq\n6mro6elBX18fDg4OePDgQft04v/J6gtXTrH27k9r+vh8efn5+bCxsYGWlhbefvttXLt2TWp7SUkJ\nfHx8IBQKsXnzZqlt3377Lby8vODq6oo7d+4AAFatWgUvLy8IhUJIJBJs2bIFHh4eEAqFEIlEKC0t\nha+vL4RCIZqWMUtKSmIzW96+fVuqjlmzZiE4OBguLi44ffo0AOD69esQCoVwd3dHSkoKAEAoFGLT\npk1YuXIle+zJkydx7949CIVCZGZm8mpvk4sXLyIoKAj19fVyn8u2eGHvWjdXUlKi8LfIuLg4pKWl\nYfjw4QgKCsJ///tfTJw4EcOGDcOXX36JKVOmsPvt27cPhoaGWLZsmcyyysvL2SADIHekjar60tRO\nVfanLX1sfqyBgQHKy8ultn/77beYO3cuxowZI/VGB4A5c+ZgwYIFyM/PR3R0NCIiIlBUVITExEQw\nDAOBQIArV64gKSkJGhoaYBgGGzduxIIFCzBixAhERkbi6tWrOHnyJPbt2wc9Pb0WN5NKS0uxbds2\nGBsbw9vbG+PHj0d0dDRiYmLQs2dPfPLJJ+zz5+zsjBEjRrDHTpgwATY2NkhISAAAvPHGG5ztBYBL\nly7h4sWLiIiIgLa2Nu/nsjXUIpDNzc1RXFwsd7uPjw9cXV3Zv8+cOYNz585BIBAgPz+fncXk4+MD\nFxcXLFy4EE+fPsXLL7/coiwDAwOp7318M13yxdWXpnaqsj9t6WPzYysrK9ncVk3y8vLY4Hh+Wt4P\nP/yA48ePs49ra2tj+vTpWL58OV555RUsXryYzQltbGyMpUuXIicnB5s3b4ZAIEBVVRXs7e2xaNEi\nhIaGQltbG4sXL5bqt7GxMfr06SPVr7t378LPzw8AUFZWhrKyMgCNifUU4dNeoPH+xd69e1UWxICa\n3OwaN24c0tLS2DdQfn5+i4TkTUpKStC7d29899132LNnD+bNm4fMzEx2u6amJjw9PbF3716Zx7/0\n0kuoqalBVVUVbty40apsl8r2RVFgq6I/8vooEok4229tbY379+9DLBbj/PnzcHBwkNo+YMAAXL9+\nHQBanJH379+PhIQErF+/HkDjlcDkyZMRFRWF0tJS3Lx5E46OjoiMjISZmRkyMjIwYMAABAcHIyEh\nAampqZgY+SSpAAAByElEQVQwYQJsbW0RERGBt956C6mpqVJ1PH36FEVFRaiurmavNGxtbfHtt98i\nISEBR48ehYWFBYCWHzQApBLp82kvAGzYsAGhoaEoLS3lfP6UpRZnZFNTU/j7+8PX1xcMw8DIyEju\np9/JkycxcuRI9u+33noLcXFxGDVqFPvY6NGjsW3bNtTV1WHPnj04ceIEGIaBSCRCQEAA/Pz84O3t\nDR0dHWzcuFHlfQkPD5e7v6r6I+uxlStXIj4+XuoNHhYWhtOnT+PUqVNwd3eHm5sbXF1d4enpCUND\nwxbfg+fPn4/g4GDExMRgxIgRUpf89vb28PT0ZNteVVUFPz8/iMVi6Ovrw8bGBv7+/uzAme3bt8PR\n0RFr165FRUUFNDQ0EBYWhh07dqCgoAB1dXUt1moyMTHBjh07kJ2djYULFwIAAgMD2efb2NgYO3bs\nkPt829vbw9/fH97e3rzaCzROrvnss8+wYsUKREdHo2fPnnLLVxYNCFEBdRwQIpFIEB4ejrVr13Z2\nU9rEw8MDycnJnd0MADQghHQCDQ2NFz6I1RmdkVVAHc/IpP3RGZkQIuWFvdmlp6fXqqU4OxLftZOe\nP6ar9oeohjLvE3n+D7fxBgwelo6gAAAAAElFTkSuQmCC\n
",
181 | "text/plain": [
182 | ""
183 | ]
184 | },
185 | "metadata": {},
186 | "output_type": "display_data"
187 | }
188 | ],
189 | "source": [
190 | "sns.set_style('ticks')\n",
191 | "figure(figsize=(3.3, 2.2))\n",
192 | "ax = axes()\n",
193 | "do_plot(eval_type=1)\n",
194 | "sns.despine()\n",
195 | "subplots_adjust(left=.21, bottom=.25, right=.98, top=.82)\n",
196 | "savefig(\"cifar10_cifar100_transfer_train.pdf\")"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "metadata": {
203 | "collapsed": true
204 | },
205 | "outputs": [],
206 | "source": []
207 | }
208 | ],
209 | "metadata": {
210 | "kernelspec": {
211 | "display_name": "Python 3",
212 | "language": "python",
213 | "name": "python3"
214 | },
215 | "language_info": {
216 | "codemirror_mode": {
217 | "name": "ipython",
218 | "version": 3
219 | },
220 | "file_extension": ".py",
221 | "mimetype": "text/x-python",
222 | "name": "python",
223 | "nbconvert_exporter": "python",
224 | "pygments_lexer": "ipython3",
225 | "version": "3.5.2"
226 | }
227 | },
228 | "nbformat": 4,
229 | "nbformat_minor": 2
230 | }
231 |
--------------------------------------------------------------------------------
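In `do_plot` above, `data['mean'][cv]` and `data['std'][cv]` are indexed as `[age, task, eval_type]`. Going by how `run_fits` in `train.py` assembles them, each array should have shape `(n_tasks, n_tasks, 2)`: one row per training age, one column per evaluated task, and a last axis holding (validation, training) accuracy. Fine-tuning and consolidation bars read a single row (`age=-1`, i.e. accuracy on every task after the last one was trained), while the "scratch" bars read the diagonal (each task evaluated right after a freshly initialized network learned it). The sketch below reproduces that selection on synthetic numbers; `bar_heights` and the random array are illustrative only and are not results from the paper.

```python
import numpy as np

n_tasks = 6
VALID, TRAIN = 0, 1  # last-axis layout, matching ftask.append((valid, train)) in train.py

# Synthetic accuracies with layout [age, task, eval_type]; the values are made up.
rng = np.random.default_rng(0)
acc = rng.uniform(0.5, 1.0, size=(n_tasks, n_tasks, 2))

def bar_heights(acc, mode, eval_type=VALID, age=-1):
    """Return one accuracy per task, selected the way do_plot does (hypothetical helper)."""
    if mode == 'scratch':
        # Diagonal: task tid evaluated at age tid, right after it was trained.
        return np.array([acc[tid, tid, eval_type] for tid in range(acc.shape[1])])
    # Otherwise: all tasks evaluated after the given age (default: the final one).
    return acc[age, :, eval_type]

final = bar_heights(acc, 'consolidation')
scratch = bar_heights(acc, 'scratch')
print(final.shape, scratch.shape)  # (6,) (6,)
```

Reading the diagonal for "scratch" makes sense because that network is reinitialized at every age, so later rows say nothing about earlier tasks.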
/fig_transfer_cifar/split_cifar10_data_path_int[omega_decay=sum,xi=0.001]_lr1.00e-03_ep60.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ganguli-lab/pathint/86c5e07c5603d6a44c2f53585c19ee70773ef15c/fig_transfer_cifar/split_cifar10_data_path_int[omega_decay=sum,xi=0.001]_lr1.00e-03_ep60.pkl.gz
--------------------------------------------------------------------------------
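The `.pkl.gz` file listed above is a results file; its name matches the `datafile_name` pattern in `train.py`, which writes it with `utils.save_zipped_pickle` and reads old results back with `utils.load_zipped_pickle`. `pathint/utils.py` is not shown in this section, so the pair below is only a guess at what such helpers can look like, built on the standard-library `gzip` and `pickle` modules. The "return `{}` when the file is missing" behaviour follows the comment in `train.py`; the function bodies themselves are assumptions.

```python
import gzip
import os
import pickle
import tempfile

def save_zipped_pickle(obj, filename, protocol=-1):
    """Pickle obj into a gzip-compressed file (hypothetical stand-in for pathint.utils)."""
    with gzip.open(filename, 'wb') as f:
        pickle.dump(obj, f, protocol)  # -1 selects the highest pickle protocol

def load_zipped_pickle(filename):
    """Load a gzip-compressed pickle; return {} if the file does not exist."""
    try:
        with gzip.open(filename, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        return {}

# Round trip using the {'mean': {cval: ...}, 'std': {cval: ...}} layout that train.py saves.
path = os.path.join(tempfile.mkdtemp(), 'demo.pkl.gz')
save_zipped_pickle({'mean': {0.01: [0.9]}, 'std': {0.01: [0.0]}}, path)
print(load_zipped_pickle(path)['mean'][0.01])  # [0.9]
```

Unpickling can execute arbitrary code, so only load result files from sources you trust.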
/fig_transfer_cifar/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke
3 | # MIT License -- see LICENSE for details
4 | #
5 | # This file is part of the code to reproduce the core results of:
6 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through
7 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on
8 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention
9 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995.
10 | # http://proceedings.mlr.press/v70/zenke17a.html
11 | #
12 |
13 | import sys, os
14 | sys.path.extend([os.path.expanduser('..')])
15 |
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | import matplotlib.colors as colors
19 | import matplotlib.cm as cmx
20 |
21 | import seaborn as sns
22 | sns.set_style("white")
23 |
24 | from tqdm import trange, tqdm
25 |
26 | import tensorflow as tf
27 |
28 | from pathint import protocols
29 | from pathint.optimizers import KOOptimizer
30 | from keras.optimizers import Adam, RMSprop, SGD
31 | from keras.callbacks import Callback
32 | import keras.backend as K
33 | import keras.activations as activations
34 | from keras.models import Sequential
35 | from keras.layers import Dense, Dropout, Activation, Flatten
36 | from keras.layers import Conv2D, MaxPooling2D
37 |
38 | from pathint import utils
39 | from pathint.keras_utils import LossHistory
40 |
41 | # ## Parameters
42 |
43 | # Data params
44 |
45 | # Network params
46 | input_shape = (3,32,32) # not used below; the first Conv2D takes its shape from the data
47 |
48 | # size of pooling area for max pooling
49 | pool_size = (2, 2)
50 |
51 | # convolution kernel size
52 | kernel_size = (3, 3)
53 |
54 | # Optimization parameters
55 | batch_size = 256
56 | epochs_per_task = 60
57 | learning_rate = 1e-3
58 | nstats = 1 # repeats of experiment to compute stdev
59 |
60 |
61 | add_evals=False # Keep evals across runs by merging with those already saved to datafile_name
62 | cvals = ['scratch', 0, 0.01, 0.05, 0.1, 0.2]
63 | cvals = ['scratch', 0, 0.01, 0.1]
64 | print("cvals %s"%cvals)
65 |
66 |
67 | debug=False
68 | if debug:
69 |     cvals = [0.1]
70 |     epochs_per_task = 1
71 |     nstats = 1
72 |
73 | # Reset optimizer after each age
74 | reset_optimizer = True
75 |
76 |
77 | # ## Construct datasets
78 | n_tasks = 6
79 | output_dim = 10*n_tasks
80 | nb_classes = output_dim
81 | # task_labels = [ range(i*10,(i+1)*10) for i in range(n_tasks) ]
82 |
83 | task_labels, training_datasets = utils.construct_transfer_cifar10_cifar100(n_tasks, split='train')
84 | _, validation_datasets = utils.construct_transfer_cifar10_cifar100(n_tasks, split='test')
85 | print(task_labels)
86 |
87 | # training_datasets = utils.construct_split_cifar10(task_labels, split='train')
88 | # validation_datasets = utils.construct_split_cifar10(task_labels, split='test')
89 |
90 | # ## Construct network, loss, and updates
91 | tf.reset_default_graph()
92 |
93 | config = tf.ConfigProto()
94 | config.gpu_options.allow_growth=True
95 | sess = tf.InteractiveSession(config=config)
96 | sess.run(tf.global_variables_initializer())
97 |
98 |
99 | # Instantiate masking functions
100 | output_mask = tf.Variable(tf.zeros(output_dim), name="mask", trainable=False)
101 |
102 | select = tf.select if hasattr(tf, 'select') else tf.where
103 |
104 | def masked_softmax(logits):
105 |     # logits are [batch_size, output_dim]
106 |     x = select(tf.tile(tf.equal(output_mask[None, :], 1.0), [tf.shape(logits)[0], 1]), logits, -1e32 * tf.ones_like(logits))
107 |     return activations.softmax(x)
108 |
109 | def set_active_outputs(labels):
110 |     new_mask = np.zeros(output_dim)
111 |     for l in labels:
112 |         new_mask[l] = 1.0
113 |     sess.run(output_mask.assign(new_mask))
114 |     # print("setting output mask")
115 |     # print(sess.run(output_mask))
116 |
117 | def masked_predict(model, data, targets):
118 |     pred = model.predict(data)
119 |     # print(pred)
120 |     acc = np.argmax(pred,1)==np.argmax(targets,1)
121 |     return acc.mean()
122 |
123 | # Assemble the network model
124 | model = Sequential()
125 |
126 | model.add(Conv2D(32, (3, 3), padding='same',
127 |                  input_shape=training_datasets[0][0].shape[1:]))
128 | model.add(Activation('relu'))
129 | model.add(Conv2D(32, (3, 3)))
130 | model.add(Activation('relu'))
131 | model.add(MaxPooling2D(pool_size=(2, 2)))
132 | model.add(Dropout(0.25))
133 |
134 | model.add(Conv2D(64, (3, 3), padding='same'))
135 | model.add(Activation('relu'))
136 | model.add(Conv2D(64, (3, 3)))
137 | model.add(Activation('relu'))
138 | model.add(MaxPooling2D(pool_size=(2, 2)))
139 | model.add(Dropout(0.25))
140 |
141 | model.add(Flatten())
142 | model.add(Dense(512))
143 | model.add(Activation('relu'))
144 | model.add(Dropout(0.5))
145 | # model.add(Dense(nb_classes))
146 | model.add(Dense(nb_classes, kernel_initializer='zero', activation=masked_softmax))
147 |
148 |
149 | # Define our training protocol
150 | protocol_name, protocol = protocols.PATH_INT_PROTOCOL(omega_decay='sum', xi=1e-3 )
151 | opt = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999)
152 | # opt = RMSprop(lr=1e-3)
153 | # opt = SGD(1e-3)
154 | oopt = KOOptimizer(opt, model=model, **protocol)
155 | model.compile(loss='categorical_crossentropy', optimizer=oopt, metrics=['accuracy'])
156 | model._make_train_function()
157 |
158 | history = LossHistory()
159 | callbacks = [history]
160 | datafile_name = "split_cifar10_data_%s_lr%.2e_ep%i.pkl.gz"%(protocol_name, learning_rate, epochs_per_task)
161 |
162 |
163 |
164 | def run_fits(cvals, training_data, valid_data, nstats=1):
165 |     acc_mean = dict()
166 |     acc_std = dict()
167 |     for cidx, cval_ in enumerate(cvals):
168 |         runs = []
169 |         for runid in range(nstats):
170 |             evals = []
171 |             sess.run(tf.global_variables_initializer())
172 |             # model.set_weights(saved_weights)
173 |             cstuffs = []
174 |             if cval_=='scratch':
175 |                 print("Scratch mode -- inits net before each age")
176 |                 cval = 0
177 |             else:
178 |                 print("setting cval")
179 |                 cval = cval_
180 |             oopt.set_strength(cval)
181 |             oopt.init_task_vars()
182 |             print("cval is %f"%sess.run(oopt.lam))
183 |             for age, tidx in enumerate(range(n_tasks)):
184 |                 if cval_=='scratch':
185 |                     sess.run(tf.global_variables_initializer())
186 |                     oopt.reset_optimizer()
187 |                 print("Age %i, cval is=%f"%(age,cval))
188 |                 set_active_outputs(task_labels[age])
189 |                 stuffs = model.fit(training_data[tidx][0], training_data[tidx][1], batch_size, epochs_per_task, callbacks=callbacks, verbose=0)
190 |                 oopt.update_task_metrics(training_data[tidx][0], training_data[tidx][1], batch_size)
191 |                 oopt.update_task_vars()
192 |                 ftask = []
193 |                 for j in range(n_tasks):
194 |                     set_active_outputs(task_labels[j])
195 |                     train_acc = masked_predict(model, training_data[j][0], training_data[j][1])
196 |                     valid_acc = masked_predict(model, valid_data[j][0], valid_data[j][1])
197 |                     ftask.append( (np.mean(valid_acc), np.mean(train_acc)) )
198 |                 evals.append(ftask)
199 |                 cstuffs.append(stuffs)
200 |
201 |                 # Re-initialize optimizer variables
202 |                 if reset_optimizer:
203 |                     oopt.reset_optimizer()
204 |
205 |             evals = np.array(evals)
206 |             runs.append(evals)
207 |
208 |         runs = np.array(runs)
209 |         acc_mean[cval_] = runs.mean(0)
210 |         acc_std[cval_] = runs.std(0)
211 |     return dict(mean=acc_mean, std=acc_std)
212 |
213 |
214 | # Run the sim
215 | data = run_fits(cvals, training_datasets, validation_datasets, nstats=nstats)
216 |
217 |
218 | # data = dict(mean={0.1:0.0}, std={0.1:0.0})
219 | # print(data)
220 | if add_evals:
221 |     old_data = utils.load_zipped_pickle(datafile_name)
222 |     # returns empty dict if file not found
223 |     for k in old_data.keys():
224 |         for l in old_data[k].keys():
225 |             data[k].setdefault(l, old_data[k][l])  # keep this run's results; only add cvals it did not compute
226 |
227 | # Save the data
228 | utils.save_zipped_pickle(data, datafile_name)
229 |
230 | # To overwrite the data in the file uncomment this
231 | # all_evals = dict() # uncomment to delete on disk
232 | # utils.save_zipped_pickle(data, datafile_name)
233 |
--------------------------------------------------------------------------------
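`masked_softmax` in `train.py` above gives each task its own "head" on one shared 60-unit output layer: logits of classes outside the current task's label set are replaced by `-1e32` before the softmax, so those classes get (numerically) zero probability and the softmax renormalizes over the active classes only. `set_active_outputs` builds the 0/1 mask that decides which classes are active. The NumPy sketch below reproduces the idea outside TensorFlow; `np.where` with broadcasting stands in for the `tf.tile` + `select` combination, and the helper names are illustrative rather than part of `pathint`.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def make_mask(labels, output_dim):
    # Same idea as set_active_outputs: 1.0 for the current task's classes, 0.0 elsewhere.
    mask = np.zeros(output_dim)
    mask[list(labels)] = 1.0
    return mask

def masked_softmax(logits, mask):
    # Inactive logits become a huge negative number, so exp() underflows to exactly 0.
    masked = np.where(mask[None, :] == 1.0, logits, -1e32)
    return softmax(masked)

output_dim = 60                                # 10 classes x 6 tasks, as in train.py
logits = np.random.default_rng(0).normal(size=(4, output_dim))
mask = make_mask(range(10, 20), output_dim)    # e.g. the labels of the second task
p = masked_softmax(logits, mask)
print(p[:, :10].max(), p[:, 20:].max())        # 0.0 0.0
print(np.allclose(p.sum(axis=1), 1.0))         # True
```

Because inactive classes end up with zero probability, the `argmax` in `masked_predict` can only pick one of the current task's classes, which is why the script calls `set_active_outputs(task_labels[j])` before evaluating task `j`.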
/pathint/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ganguli-lab/pathint/86c5e07c5603d6a44c2f53585c19ee70773ef15c/pathint/__init__.py
--------------------------------------------------------------------------------
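`compute_fishers` in `pathint/keras_utils.py` below avoids building one weight gradient per example. For a Dense layer h = XW + b, the gradient of y with respect to W for example i is the outer product x_i (dy/dh_i)^T, so its elementwise square is x_i^2 ((dy/dh_i)^2)^T, and summing over the batch collapses into a single matrix product (X∘X)^T (dY∘dY), which is what `tf.matmul(tf.transpose(a)**2, dh**2)` computes. The bias gradient is dy/dh itself, matching `tf.reduce_sum(dh**2, 0)`. The NumPy check below compares the shortcut with the explicit per-example sum; the shapes and random data are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_in, n_out = 32, 5, 3
X = rng.normal(size=(batch, n_in))    # previous-layer activations, one row x_i per example
dH = rng.normal(size=(batch, n_out))  # dy/dh for each example, one row per example

# Explicit route: build every per-example gradient x_i (dy/dh_i)^T, square it, then sum.
explicit = sum(np.outer(X[i], dH[i]) ** 2 for i in range(batch))

# Shortcut: square first, then one matmul over the batch.
shortcut = (X ** 2).T @ (dH ** 2)

# Bias gradients equal dy/dh, so their squared sum is a reduction over the batch axis.
bias_sq = (dH ** 2).sum(axis=0)

print(shortcut.shape, np.allclose(explicit, shortcut))  # (5, 3) True
```

The shortcut only holds because each per-example gradient is rank one; it does not apply to a gradient that has already been summed over the batch.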
/pathint/keras_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke
2 | # MIT License -- see LICENSE for details
3 | #
4 | # This file is part of the code to reproduce the core results of:
5 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through
6 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on
7 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention
8 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995.
9 | # http://proceedings.mlr.press/v70/zenke17a.html
10 | #
11 |
12 | # Keras-specific functions and utils
13 | from keras.callbacks import Callback
14 | from keras.models import Model
15 | import keras.backend as K
16 | from keras.layers import Dense
17 | import tensorflow as tf
18 |
19 | class LossHistory(Callback):
20 |     def __init__(self, *args, **kwargs):
21 |         super(LossHistory, self).__init__(*args, **kwargs)
22 |         self.losses = []
23 |         self.regs = []
24 |
25 |     def on_batch_end(self, batch, logs=None):
26 |         self.losses.append((logs or {}).get('loss'))
27 |         self.regs.append(K.get_session().run(self.model.optimizer.regularizer))
28 |
29 |
30 | # Create a callback that tracks FisherInformation
31 | from keras.models import Model
32 | import keras.backend as K
33 | def compute_fishers(model):
34 | # Check that model only contains Dense layers
35 | for l in model.layers:
36 | if not isinstance(l, Dense):
37 | raise ValueError("All layers of the model must be Dense, got %s"%l)
38 | # Create new model used to extract activations at each layer
39 | new_model = Model(inputs=model.input, outputs=[l.output for l in model.layers])
40 | acts = new_model(model.input)
41 |
42 | out = acts[-1]
43 | out_dim = out.get_shape().as_list()[-1]
44 | #assert len(model.weights) %2 == 0
45 | n_weights = len(model.weights) // 2
46 |
47 |
48 | def _get_fishers():
49 | fisher_weights = [0.0] * n_weights
50 | fisher_biases = [0.0] * n_weights
51 | for idx in range(out_dim):
52 | # Clips output
53 | # https://github.com/fchollet/keras/blob/master/keras/backend/tensorflow_backend.py#L2743
54 | output = out[:, idx]
55 | epsilon = tf.convert_to_tensor(x)
56 | if epsilon.dtype != output.dtype.base_dtype:
57 | epsilon = tf.cast(epsilon, output.dtype.base_dtype)
58 |
59 | output = tf.clip_by_value(output, epsilon, 1. - epsilon)
60 | y = K.log(output)
61 | # From the post-nonlinearity outputs of each layer, we walk back up through the graph
62 | # to find the linear activation corresponding to h=XW + b.
63 | # Then we identify the weights W, biases b, and previous activation X.
64 | # 1. TensorFlow can compute dy/dh, giving us a [batch_size, n_neurons] matrix
65 | # 2. We manually comute dy/dW=dy/dh X and dy/db=dy/dh
66 | # 3. We sum the squared Jacobians
67 |
68 | # Identify pre-nonlinearity activation corresponding to h=XW+b
69 | def _walk_up_until_add(x):
70 | if x.op.type == 'BiasAdd':
71 | return x
72 | elif x.op.type == 'Select':
73 | return x.op.inputs[1]
74 | else:
75 | return _walk_up_until_add(x.op.inputs[0])
76 |
77 | linear = [_walk_up_until_add(a) for a in acts]
78 | # Identify previous activation, X
79 | prev_acts = [l.op.inputs[0].op.inputs[0] for l in linear]
80 | # Compute dy/dh
81 | dy_dlinear = [tf.gradients(y,l)[0] for l in linear]
82 |
83 | # Figure out which Jacobians correspond to which weights
84 | if idx == 0:
85 | val_to_var = {v.value():v for v in model.weights}
86 | weight_vars = [val_to_var[l.op.inputs[0].op.inputs[1]] for l in linear]
87 | bias_vars = [val_to_var[l.op.inputs[1]] for l in linear]
88 |
89 | # Compute the sum of the Jacobian squared
90 | # Because each per-example weight Jacobian is rank-1 (an outer product), we can square its factors first and then sum over the batch with a matmul:
91 | # \sum_i (u_i v_i^T)^2 = \sum_i (u_i^2 (v_i^2)^T)
92 | weights_sum_jacobian_squared = [tf.matmul(tf.transpose(a)**2, dh**2) for dh,a in zip(dy_dlinear, prev_acts)]
93 | bias_sum_jacobian_squared = [tf.reduce_sum(dh**2, 0) for dh in dy_dlinear]
94 | # Keep track of aggregate across outputs
95 | for jj in range(n_weights):
96 | fisher_weights[jj] += weights_sum_jacobian_squared[jj]
97 | fisher_biases[jj] += bias_sum_jacobian_squared[jj]
98 | var_to_fisher = dict(zip(weight_vars+bias_vars, fisher_weights+fisher_biases))
99 | return {w: var_to_fisher[w] for w in model.weights}
100 |
101 | fishers = _get_fishers()
102 | return fishers
103 |
104 | # Allocate space for accumulated Fisher
105 | avg_fishers = [K.zeros(w.get_shape().as_list()) for w in model.weights]
106 | # Create updates to reset avg fisher, update, etc.
107 | update_fishers = tf.group(*[tf.assign_add(avg_f, f) for avg_f, f in zip(avg_fishers, fishers)])
108 | zero_fishers = tf.group(*[tf.assign(avg_f, 0.0 * avg_f) for avg_f in avg_fishers])
109 | return fishers, avg_fishers, update_fishers, zero_fishers
110 |
111 |
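The rank-1 shortcut in `compute_fishers` can be checked numerically. Below is a minimal NumPy sketch in which NumPy stands in for TensorFlow and the sizes and random data are arbitrary assumptions. It compares the naive sum of squared per-example weight Jacobians with the squared-factors matmul (`tf.matmul(tf.transpose(a)**2, dh**2)`) used above.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_in, n_out = 5, 4, 3
X = rng.normal(size=(batch, n_in))       # previous-layer activations, one row per example
dy_dh = rng.normal(size=(batch, n_out))  # dy/dh (pre-nonlinearity gradient) for each example

# Naive: build each per-example weight Jacobian (outer product x_i dy_dh_i^T,
# shape [n_in, n_out]), square it elementwise, and sum over the batch.
naive_w = sum(np.outer(X[i], dy_dh[i]) ** 2 for i in range(batch))

# Shortcut: square the factors first, then a single matmul does the batch sum.
fast_w = (X ** 2).T @ (dy_dh ** 2)

# Bias Jacobians are dy/dh itself, so their squared sum is a column sum.
naive_b = sum(dy_dh[i] ** 2 for i in range(batch))
fast_b = np.sum(dy_dh ** 2, axis=0)
```

Both forms give identical results up to floating-point rounding. The shortcut avoids materializing a `[batch, n_in, n_out]` tensor of per-example Jacobians.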
--------------------------------------------------------------------------------
/pathint/optimizers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke
2 | # MIT License -- see LICENSE for details
3 | #
4 | # This file is part of the code to reproduce the core results of:
5 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through
6 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on
7 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention
8 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995.
9 | # http://proceedings.mlr.press/v70/zenke17a.html
10 | #
11 | """Optimization algorithms."""
12 |
13 | import tensorflow as tf
14 |
15 | import numpy as np
16 | import keras
17 | from keras import backend as K
18 | from keras.optimizers import Optimizer
19 | from keras.callbacks import Callback
20 | from pathint.utils import extract_weight_changes, compute_updates
21 | from pathint.regularizers import quadratic_regularizer
22 | from collections import OrderedDict
23 |
24 |
25 | class KOOptimizer(Optimizer):
26 | """An optimizer whose loss depends on its own updates."""
27 |
28 | def _allocate_var(self, name=None):
29 | return {w: K.zeros(w.get_shape(), name=name) for w in self.weights}
30 |
31 | def _allocate_vars(self, names):
32 | #TODO: add names, better shape/init checking
33 | self.vars = {name: self._allocate_var(name=name) for name in names}
34 |
35 | def __init__(self, opt, step_updates=[], task_updates=[], init_updates=[], task_metrics = {}, regularizer_fn=quadratic_regularizer,
36 | lam=1.0, model=None, compute_average_loss=False, compute_average_weights=False, **kwargs):
37 | """Instantiate an optimizer that depends on its own updates.
38 |
39 | Args:
40 | opt: Keras optimizer
41 | step_updates: OrderedDict or List of tuples
42 | Contains variable names and updates to be run at each step:
43 | (name, lambda vars, weight, prev_val: new_val). See below for details.
44 | task_updates: same as step_updates but run after each task
45 | init_updates: updates to be run before using the optimizer
46 | task_metrics: dict mapping metric names to functions of the optimizer; the per-weight values they return are accumulated over the full data after a task
47 | regularizer_fn (optional): function that takes the weights and the vars dict and returns a scalar;
48 | defaults to quadratic_regularizer (an EWC-style quadratic penalty)
49 | lam: scalar penalty that multiplies the regularization term
50 | model: Keras model to be optimized. Needed to compute Fisher information
51 | compute_average_loss: compute EMA of the loss, default: False
52 | compute_average_weights: compute EMA of the weights, default: False
53 |
54 | Variables are created for each name in the step updates, task updates and task metrics. Note that you cannot
55 | use the name 'grads', 'unreg_grads' or 'deltas' as those are reserved to contain the gradients
56 | of the full loss, loss without regularization, and the weight updates at each step.
57 | You can access them in the vars dict, e.g.: oopt.vars['grads']
58 |
59 | The step and task update functions have the signature:
60 | def update_fn(vars, weight, prev_val):
61 | '''Compute the new value for a variable.
62 | Args:
63 | vars: optimization variables (KOOptimizer.vars)
64 | weight: weight Variable in model that this variable is associated with.
65 | prev_val: previous value of this variable
66 | Returns:
67 | Tensor representing the new value'''
68 |
69 | You can run both task and step updates on the same variable, allowing you to reset
70 | step variables after each task.
71 | """
72 | super(KOOptimizer, self).__init__(**kwargs)
73 | if not isinstance(opt, keras.optimizers.Optimizer):
74 | raise ValueError("opt must be an instance of keras.optimizers.Optimizer but got %s"%type(opt))
75 | if not isinstance(step_updates, OrderedDict):
76 | step_updates = OrderedDict(step_updates)
77 | if not isinstance(task_updates, OrderedDict): task_updates = OrderedDict(task_updates)
78 | if not isinstance(init_updates, OrderedDict): init_updates = OrderedDict(init_updates)
79 | # task_metrics
80 | self.names = set().union(step_updates.keys(), task_updates.keys(), task_metrics.keys())
81 | if self.names & {'grads', 'unreg_grads', 'deltas'}:
82 | raise ValueError("Optimization variables cannot be named 'grads', 'unreg_grads' or 'deltas'")
83 | self.step_updates = step_updates
84 | self.task_updates = task_updates
85 | self.init_updates = init_updates
86 | self.compute_average_loss = compute_average_loss
87 | self.regularizer_fn = regularizer_fn
88 | # Regularization strength and dataset size as backend variables, adjustable via set_strength and set_nb_data
89 | self.lam = K.variable(value=lam, dtype=tf.float32, name="lam")
90 | self.nb_data = K.variable(value=1.0, dtype=tf.float32, name="nb_data")
91 | self.opt = opt
92 | #self.compute_fisher = compute_fisher
93 | #if compute_fisher and model is None:
94 | # raise ValueError("To compute Fisher information, you need to pass in a Keras model object ")
95 | self.model = model
96 | self.task_metrics = task_metrics
97 | self.compute_average_weights = compute_average_weights
98 |
99 | def set_strength(self, val):
100 | K.set_value(self.lam, val)
101 |
102 | def set_nb_data(self, nb):
103 | K.set_value(self.nb_data, nb)
104 |
105 | def get_updates(self, params,loss,model=None):
106 | self.weights = params
107 | # Allocate variables
108 | with tf.variable_scope("KOOptimizer"):
109 | self._allocate_vars(self.names)
110 |
111 | #grads = self.get_gradients(loss, params)
112 |
113 | # Compute loss and gradients
114 | self.regularizer = 0.0 if self.regularizer_fn is None else self.regularizer_fn(params, self.vars)
115 | self.initial_loss = loss
116 | self.loss = loss + self.lam * self.regularizer
117 | with tf.variable_scope("wrapped_optimizer"):
118 | self._weight_update_op, self._grads, self._deltas = compute_updates(self.opt, self.loss, params)
119 |
120 | wrapped_opt_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "wrapped_optimizer")
121 | self.init_opt_vars = tf.variables_initializer(wrapped_opt_vars)
122 |
123 | self.vars['unreg_grads'] = dict(zip(params, tf.gradients(self.initial_loss, params)))
124 | # Compute updates
125 | self.vars['grads'] = dict(zip(params, self._grads))
126 | self.vars['deltas'] = dict(zip(params, self._deltas))
127 | # Keep a pointer to self in vars so we can use it in the updates
128 | self.vars['oopt'] = self
129 | # Keep number of data samples handy for normalization purposes
130 | self.vars['nb_data'] = self.nb_data
131 |
132 | if self.compute_average_weights:
133 | with tf.variable_scope("weight_emga") as scope:
134 | weight_ema = tf.train.ExponentialMovingAverage(decay=0.99, zero_debias=True)
135 | self.maintain_weight_averages_op = weight_ema.apply(self.weights)
136 | self.vars['average_weights'] = {w: weight_ema.average(w) for w in self.weights}
137 | self.weight_ema_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name)
138 | self.init_weight_ema_vars = tf.variables_initializer(self.weight_ema_vars)
139 | # Initialize the weight EMA variables right away
140 | K.get_session().run(self.init_weight_ema_vars)
141 | if self.compute_average_loss:
142 | with tf.variable_scope("ema") as scope:
143 | ema = tf.train.ExponentialMovingAverage(decay=0.99, zero_debias=True)
144 | self.maintain_averages_op = ema.apply([self.initial_loss])
145 | self.ema_loss = ema.average(self.initial_loss)
146 | self.prev_loss = tf.Variable(0.0, trainable=False, name="prev_loss")
147 | self.delta_loss = tf.Variable(0.0, trainable=False, name="delta_loss")
148 | self.ema_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name)
149 | self.init_ema_vars = tf.variables_initializer(self.ema_vars)
150 | # if self.compute_fisher:
151 | # self._fishers, _, _, _ = compute_fishers(self.model)
152 | # #fishers = compute_fisher_information(model)
153 | # self.vars['fishers'] = dict(zip(weights, self._fishers))
154 | # #fishers, avg_fishers, update_fishers, zero_fishers = compute_fisher_information(model)
155 |
156 | def _var_update(vars, update_fn):
157 | updates = []
158 | for w in params:
159 | updates.append(tf.assign(vars[w], update_fn(self.vars, w, vars[w])))
160 | return tf.group(*updates)
161 |
162 | def _compute_vars_update_op(updates):
163 | # Force task updates to happen sequentially
164 | update_op = tf.no_op()
165 | for name, update_fn in updates.items():
166 | with tf.control_dependencies([update_op]):
167 | update_op = _var_update(self.vars[name], update_fn)
168 | return update_op
169 |
170 | self._vars_step_update_op = _compute_vars_update_op(self.step_updates)
171 | self._vars_task_update_op = _compute_vars_update_op(self.task_updates)
172 | self._vars_init_update_op = _compute_vars_update_op(self.init_updates)
173 |
174 | # Create task-relevant update ops
175 | reset_ops = []
176 | update_ops = []
177 | for name, metric_fn in self.task_metrics.items():
178 | metric = metric_fn(self)
179 | for w in params:
180 | reset_ops.append(tf.assign(self.vars[name][w], 0*self.vars[name][w]))
181 | update_ops.append(tf.assign_add(self.vars[name][w], metric[w]))
182 | self._reset_task_metrics_op = tf.group(*reset_ops)
183 | self._update_task_metrics_op = tf.group(*update_ops)
184 |
185 | # Each step we update the weights using the optimizer as well as the step-specific variables
186 | self.step_op = tf.group(self._weight_update_op, self._vars_step_update_op)
187 | self.updates.append(self.step_op)
188 | # After each task, run task-specific variable updates
189 | self.task_op = self._vars_task_update_op
190 | self.init_op = self._vars_init_update_op
191 |
192 | if self.compute_average_weights:
193 | self.updates.append(self.maintain_weight_averages_op)
194 |
195 | if self.compute_average_loss:
196 | self.update_loss_op = tf.assign(self.prev_loss, self.ema_loss)
197 | bupdates = self.updates
198 | with tf.control_dependencies(bupdates + [self.update_loss_op]):
199 | self.updates = [tf.group(*[self.maintain_averages_op])]
200 | self.delta_loss = self.prev_loss - self.ema_loss
201 |
202 | return self.updates
203 |
204 | def init_task_vars(self):
205 | K.get_session().run([self.init_op])
206 |
207 | def init_acc_vars(self):
208 | K.get_session().run(self.init_ema_vars)
209 |
210 | def init_loss(self, X, y, batch_size):
211 | pass
212 | #sess = K.get_session()
213 | #xi, yi, sample_weights = self.model.model._standardize_user_data(X[:batch_size], y[:batch_size], batch_size=batch_size)
214 | #sess.run(tf.assign(self.prev_loss, self.initial_loss), {self.model.input:xi[0], self.model.model.targets[0]:yi[0], self.model.model.sample_weights[0]:sample_weights[0], K.learning_phase():1})
215 |
216 | def update_task_vars(self):
217 | K.get_session().run(self.task_op)
218 |
219 | def update_task_metrics(self, X, y, batch_size):
220 | # Reset metric accumulators
221 | n_batch = len(X) // batch_size
222 |
223 | sess = K.get_session()
224 | sess.run(self._reset_task_metrics_op)
225 | for i in range(n_batch):
226 | xi, yi, sample_weights = self.model._standardize_user_data(X[i * batch_size:(i+1) * batch_size], y[i*batch_size:(i+1)*batch_size], batch_size=batch_size)
227 | sess.run(self._update_task_metrics_op, {self.model.input:xi[0], self.model.targets[0]:yi[0], self.model.sample_weights[0]:sample_weights[0]})
228 |
229 |
230 | def reset_optimizer(self):
231 | """Reset the optimizer variables"""
232 | K.get_session().run(self.init_opt_vars)
233 |
234 | def get_config(self):
235 | raise NotImplementedError("get_config is not implemented for KOOptimizer")
236 |
237 | def get_numvals_list(self, key='omega'):
238 | """Returns a list of flattened numerical values (e.g., omegas), one array per weight, in reproducible order."""
239 | variables = self.vars[key]
240 | numvals = []
241 | for p in self.weights:
242 | numval = K.get_value(tf.reshape(variables[p],(-1,)))
243 | numvals.append(numval)
244 | return numvals
245 |
246 | def get_numvals(self, key='omega'):
247 | """Returns the values of get_numvals_list (e.g., omegas) concatenated into a single array, in reproducible order."""
248 | conc = np.concatenate(self.get_numvals_list(key))
249 | return conc
250 |
251 | def get_state(self):
252 | state = []
253 | vs = self.vars
254 | for key in vs.keys():
255 | if key in ('oopt', 'nb_data', 'grads', 'unreg_grads', 'deltas'): continue # not dicts of stored Variables
256 | v = vs[key]
257 | for p in v.values():
258 | state.append(K.get_value(p))
259 | return state
260 |
261 | def set_state(self, state):
262 | c = 0
263 | vs = self.vars
264 | for key in vs.keys():
265 | if key in ('oopt', 'nb_data', 'grads', 'unreg_grads', 'deltas'): continue # must match get_state
266 | v = vs[key]
267 | for p in v.values():
268 | K.set_value(p,state[c])
269 | c += 1
270 |
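`KOOptimizer` chains the updates for each name with control dependencies. Within `step_updates`, `task_updates` or `init_updates`, later entries therefore see the results of earlier ones, so their order matters. Below is a minimal framework-free sketch of that behaviour. It assumes plain dicts of NumPy arrays keyed by weight *name*. The helper `apply_updates` and the variable names `running_sum` and `snapshot` are hypothetical.

```python
from collections import OrderedDict
import numpy as np

def apply_updates(vars, updates, weight_names):
    """Run named updates one after another (hypothetical helper).

    vars: dict variable name -> {weight name: np.ndarray}
    updates: OrderedDict variable name -> update_fn(vars, weight_name, prev_val)
    """
    for name, update_fn in updates.items():
        # Compute new values for all weights, then store them, so the next
        # update in the sequence sees this one's results.
        new_vals = {w: update_fn(vars, w, vars[name][w]) for w in weight_names}
        vars[name].update(new_vals)
    return vars

weight_names = ['W']
task_updates = OrderedDict([
    # 1. copy the running sum accumulated during the task
    ('snapshot', lambda vars, w, prev_val: vars['running_sum'][w].copy()),
    # 2. then reset the running sum for the next task
    ('running_sum', lambda vars, w, prev_val: np.zeros_like(prev_val)),
])

def fresh_state():
    return {'running_sum': {'W': np.array([3.0, -1.0])},
            'snapshot': {'W': np.zeros(2)}}

state = apply_updates(fresh_state(), task_updates, weight_names)

# Same updates in the opposite order: the reset now happens before the copy.
reversed_updates = OrderedDict(reversed(list(task_updates.items())))
state_rev = apply_updates(fresh_state(), reversed_updates, weight_names)
```

With the original order, the snapshot keeps `[3, -1]`. Reversed, the running sum is zeroed before it is copied, so the snapshot ends up zero. For the same reason, `PATH_INT_PROTOCOL` lists `omega` before the `cweights` and `grads2` updates. In `KOOptimizer` the inner dicts are keyed by the Keras weight Variables themselves. The sketch uses names because NumPy arrays are not hashable.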
--------------------------------------------------------------------------------
/pathint/protocols.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke
2 | # MIT License -- see LICENSE for details
3 | #
4 | # This file is part of the code to reproduce the core results of:
5 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through
6 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on
7 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention
8 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995.
9 | # http://proceedings.mlr.press/v70/zenke17a.html
10 | #
11 | from pathint.utils import ema
12 | from pathint.regularizers import quadratic_regularizer, get_power_regularizer
13 | from pathint.keras_utils import compute_fishers
14 | import tensorflow as tf
15 | import numpy as np
16 |
17 | """
18 | A protocol is a function that takes as input some parameters and returns a tuple:
19 | (protocol_name, optimizer_kwargs)
20 | The protocol name is just a string that describes the protocol.
21 | The optimizer_kwargs is a dictionary that will get passed to KOOptimizer. It typically contains:
22 | init_updates, step_updates, task_updates, task_metrics, regularizer_fn
23 | """
24 |
25 |
26 |
27 | PATH_INT_PROTOCOL = lambda omega_decay, xi: (
28 | 'path_int[omega_decay=%s,xi=%s]'%(omega_decay,xi),
29 | {
30 | 'init_updates': [
31 | ('cweights', lambda vars, w, prev_val: w.value() ),
32 | ],
33 | 'step_updates': [
34 | ('grads2', lambda vars, w, prev_val: prev_val -vars['unreg_grads'][w] * vars['deltas'][w] ),
35 | ],
36 | 'task_updates': [
37 | ('omega', lambda vars, w, prev_val: tf.nn.relu( ema(omega_decay, prev_val, vars['grads2'][w]/((vars['cweights'][w]-w.value())**2+xi)) ) ),
38 | #('cached_grads2', lambda vars, w, prev_val: vars['grads2'][w]),
39 | #('cached_cweights', lambda vars, w, prev_val: vars['cweights'][w]),
40 | ('cweights', lambda vars, w, prev_val: w.value()),
41 | ('grads2', lambda vars, w, prev_val: prev_val*0.0 ),
42 | ],
43 | 'regularizer_fn': quadratic_regularizer,
44 | })
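The path-integral importance can be illustrated on a toy problem. This sketch reproduces, in NumPy, the `grads2` step update and the `omega_decay='sum'` task updates of `PATH_INT_PROTOCOL`. It assumes plain SGD on a separable quadratic loss; the repository instead wraps a Keras optimizer, and the curvature, target, learning rate and step count below are made up. For small learning rates, the summed path integral closely tracks the actual decrease in the loss.

```python
import numpy as np

# Toy task (assumed): L(w) = 0.5 * sum(curvature * (w - target)**2)
curvature = np.array([1.0, 0.5, 2.0])
target = np.array([1.0, -2.0, 0.5])

def loss(w):
    return 0.5 * np.sum(curvature * (w - target) ** 2)

def train_task(w, grads2, lr=0.05, n_steps=200):
    """Plain SGD; accumulates the per-parameter path integral like the
    'grads2' step update: grads2 <- grads2 - unreg_grad * delta."""
    for _ in range(n_steps):
        g = curvature * (w - target)  # gradient of the unregularized loss
        delta = -lr * g               # weight change made by the optimizer
        grads2 = grads2 - g * delta
        w = w + delta
    return w, grads2

def end_of_task(omega, grads2, w, cweights, xi=0.1):
    """Task updates with omega_decay='sum', in the protocol's order."""
    omega = np.maximum(omega + grads2 / ((cweights - w) ** 2 + xi), 0.0)  # relu
    cweights = w.copy()             # consolidate the current weights
    grads2 = np.zeros_like(grads2)  # reset the path integral
    return omega, cweights, grads2

w = np.zeros(3)
cweights, omega, grads2 = w.copy(), np.zeros(3), np.zeros(3)

loss_before = loss(w)
w, grads2 = train_task(w, grads2)
loss_drop = loss_before - loss(w)
path_integral = grads2.sum()  # slightly larger than loss_drop for small lr
omega, cweights, grads2 = end_of_task(omega, grads2, w, cweights)
```

With SGD on this loss, each step's contribution `lr * g**2` exceeds the true per-step loss decrease by a factor of `1 / (1 - 0.5 * lr * curvature)`. The ratio is therefore only a few percent above one here.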
45 |
46 |
47 | FISHER_PROTOCOL = lambda omega_decay:(
48 | 'fisher[omega_decay=%s]'%omega_decay,
49 | {
50 | 'task_updates': [
51 | ('omega', lambda vars, w, prev_val: ema(omega_decay, prev_val, vars['task_fisher'][w]/vars['nb_data'])),
52 | ('cweights', lambda vars, w, prev_val: w.value()),
53 | ],
54 | 'task_metrics': {
55 | 'task_fisher': lambda opt: compute_fishers(opt.model),
56 | },
57 | 'regularizer_fn': quadratic_regularizer,
58 | })
59 |
60 | def sum_regularizer_fn(weights, vars):
61 | reg = 0.0
62 | for w in weights:
63 | reg += tf.reduce_sum(vars['sum_omega'][w] * w**2
64 | - 2 * vars['sum_omega_cweights'][w] * w
65 | + vars['sum_omega_cweights_squared'][w])
66 |
67 | return reg
68 |
69 |
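`sum_regularizer_fn` appears to use an expansion over past tasks k:

sum_k omega_k (w - c_k)^2 = (sum_k omega_k) w^2 - 2 (sum_k omega_k c_k) w + sum_k omega_k c_k^2

Only three running sums then need to be stored, however many tasks have been seen. The protocols shown here do not say how those sums are maintained. The random importances and anchors below are therefore assumptions, used only to check the identity numerically.

```python
import numpy as np

rng = np.random.default_rng(1)
n_tasks, n_params = 3, 5
omegas = rng.uniform(0.0, 2.0, size=(n_tasks, n_params))  # per-task importances
anchors = rng.normal(size=(n_tasks, n_params))            # per-task consolidated weights
w = rng.normal(size=n_params)                             # current weights

# Direct form: one quadratic penalty per past task.
direct = np.sum(omegas * (w - anchors) ** 2)

# Expanded form, using the same names as the vars keys in sum_regularizer_fn.
sum_omega = omegas.sum(axis=0)
sum_omega_cweights = (omegas * anchors).sum(axis=0)
sum_omega_cweights_squared = (omegas * anchors ** 2).sum(axis=0)
expanded = np.sum(sum_omega * w ** 2
                  - 2 * sum_omega_cweights * w
                  + sum_omega_cweights_squared)
```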
--------------------------------------------------------------------------------
/pathint/regularizers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke
2 | # MIT License -- see LICENSE for details
3 | #
4 | # This file is part of the code to reproduce the core results of:
5 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through
6 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on
7 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention
8 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995.
9 | # http://proceedings.mlr.press/v70/zenke17a.html
10 | #
11 | import tensorflow as tf
12 |
13 | def quadratic_regularizer(weights, vars, norm=2):
14 | """Compute the regularization term.
15 |
16 | Args:
17 | weights: list of Variables
18 | vars: dict from variable name to dictionary containing the variables.
19 | Each set of variables is stored as a dictionary mapping from weights to variables.
20 | For example, vars['grads'][w] would retrieve the 'grads' variable for weight w
21 | norm: power for the norm of the (weights - consolidated weight)
22 |
23 | Returns:
24 | scalar Tensor regularization term
25 | """
26 | reg = 0.0
27 | for w in weights:
28 | reg += tf.reduce_sum(vars['omega'][w] * tf.abs(w - vars['cweights'][w])**norm)
29 | return reg
30 |
31 | def get_power_regularizer(power=2.0):
32 | """Return a regularizer_fn that calls quadratic_regularizer with norm=power."""
33 | def _regularizer_fn(weights, vars):
34 | return quadratic_regularizer(weights, vars, norm=power)
35 | return _regularizer_fn
36 |
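The penalty computed by `quadratic_regularizer`, a sum over weights of `omega * (w - cweights)**norm`, can be mirrored in NumPy. So can the closure returned by `get_power_regularizer`. This sketch makes a few choices of its own:

* weights are keyed by name instead of by TensorFlow Variable;
* the absolute difference is raised to `norm`, so odd or fractional powers stay well-defined (for `norm=2` this is identical to squaring);
* the example values are made up.

`KOOptimizer.get_updates` multiplies the penalty by `lam` and adds it to the task loss.

```python
import numpy as np

def quadratic_penalty(weights, vars, norm=2):
    """NumPy analogue of quadratic_regularizer.

    weights: dict weight name -> current values (names instead of Variables)
    vars: dict with 'omega' and 'cweights', each mapping weight name -> array
    """
    reg = 0.0
    for name, w in weights.items():
        # |w - cweights|**norm; identical to (w - cweights)**2 when norm == 2
        reg += np.sum(vars['omega'][name] * np.abs(w - vars['cweights'][name]) ** norm)
    return reg

def power_penalty(power=2.0):
    """Factory mirroring get_power_regularizer."""
    def _penalty(weights, vars):
        return quadratic_penalty(weights, vars, norm=power)
    return _penalty

weights = {'W': np.array([1.0, 2.0]), 'b': np.array([0.5])}
vars = {
    'omega': {'W': np.array([1.0, 3.0]), 'b': np.array([4.0])},
    'cweights': {'W': np.array([0.0, 2.0]), 'b': np.array([0.0])},
}
# norm=2: 1*1**2 + 3*0**2 + 4*0.5**2 = 2.0
# norm=1: 1*1 + 3*0 + 4*0.5 = 3.0
```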
--------------------------------------------------------------------------------
/pathint/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # Copyright (c) 2017 Ben Poole & Friedemann Zenke
3 | # MIT License -- see LICENSE for details
4 | #
5 | # This file is part of the code to reproduce the core results of:
6 | # Zenke, F., Poole, B., and Ganguli, S. (2017). Continual Learning Through
7 | # Synaptic Intelligence. In Proceedings of the 34th International Conference on
8 | # Machine Learning, D. Precup, and Y.W. Teh, eds. (International Convention
9 | # Centre, Sydney, Australia: PMLR), pp. 3987-3995.
10 | # http://proceedings.mlr.press/v70/zenke17a.html
11 | #
12 | """Utility functions for benchmarking online learning"""
13 | from __future__ import division
14 | import numpy as np
15 | import keras
16 | from keras.utils import np_utils
17 |
18 | from keras.datasets import mnist, cifar10, cifar100
19 | from keras.optimizers import Adam, RMSprop, SGD
20 | import keras.backend as K
21 |
22 | import pickle
23 | import gzip
24 |
25 | import tensorflow as tf
26 |
27 | def ema(decay, prev_val, new_val):
28 | """Compute exponential moving average.
29 |
30 | Args:
31 | decay: 'sum' to sum up values, otherwise decay in [0, 1]
32 | prev_val: previous value of accumulator
33 | new_val: new value
34 | Returns:
35 | updated accumulator
36 | """
37 | if decay == 'sum':
38 | return prev_val + new_val
39 | return decay * prev_val + (1.0 - decay) * new_val
40 |
41 | def leak(decay, prev_val, new_val):
42 | """Compute leaky integrator.
43 |
44 | Like ema, but new_val is not scaled by (1 - decay), so for a constant input the accumulator converges to new_val / (1 - decay) rather than new_val.
45 |
46 | Args:
47 | decay: 'sum' to sum up values, otherwise decay in [0, 1]
48 | prev_val: previous value of accumulator
49 | new_val: new value
50 | Returns:
51 | updated accumulator
52 | """
53 | if decay == 'sum':
54 | return prev_val + new_val
55 | return decay * prev_val + new_val
56 |
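The difference between `ema` and `leak` shows up with a constant input:

* `ema` converges to the input;
* `leak` converges to input / (1 - decay);
* `decay='sum'` just accumulates.

The two functions are copied from above so the snippet runs without this module's Keras/TensorFlow imports. The decay of 0.9 and the 500 steps are arbitrary choices.

```python
def ema(decay, prev_val, new_val):
    """Copied from pathint.utils.ema."""
    if decay == 'sum':
        return prev_val + new_val
    return decay * prev_val + (1.0 - decay) * new_val

def leak(decay, prev_val, new_val):
    """Copied from pathint.utils.leak."""
    if decay == 'sum':
        return prev_val + new_val
    return decay * prev_val + new_val

e = lk = s = 0.0
for _ in range(500):
    e = ema(0.9, e, 1.0)     # converges to 1.0
    lk = leak(0.9, lk, 1.0)  # converges to 1.0 / (1 - 0.9) = 10.0
    s = ema('sum', s, 1.0)   # plain running sum
```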
57 | def extract_weight_changes(weights, update_ops):
58 | """Given a list of weights and Assign ops, identify the change in weights.
59 |
60 | Args:
61 | weights: list of Variables
62 | update_ops: list of Assign ops, typically computed using Keras' opt.get_updates()
63 |
64 | Returns:
65 | list of Tensors containing the weight update for each variable
66 | """
67 | name_to_var = {v.name: v.value() for v in weights}
68 | weight_update_ops = list(filter(lambda x: x.op.inputs[0].name in name_to_var, update_ops))
69 | nonweight_update_ops = list(filter(lambda x: x.op.inputs[0].name not in name_to_var, update_ops))
70 | # Make sure that all the weight update ops are Assign ops
71 | for weight in weight_update_ops:
72 | if weight.op.type != 'Assign':
73 | raise ValueError('Update op for weight %s is not of type Assign.'%weight.op.inputs[0].name)
74 | weight_changes = [(new_w.op.inputs[1] - name_to_var[new_w.op.inputs[0].name]) for new_w, old_w in zip(weight_update_ops, weights)]
75 | # Recreate the update ops, ensuring that we compute the weight changes before updating the weights
76 | with tf.control_dependencies(weight_changes):
77 | new_weight_update_ops = [tf.assign(new_w.op.inputs[0], new_w.op.inputs[1]) for new_w in weight_update_ops]
78 | return weight_changes, tf.group(*(nonweight_update_ops + new_weight_update_ops))
79 |
80 |
81 | def compute_updates(opt, loss, weights):
82 | update_ops = opt.get_updates(weights, [], loss)
83 | deltas, new_update_op = extract_weight_changes(weights, update_ops)
84 | grads = tf.gradients(loss, weights)
85 | # Make sure that deltas are computed _before_ the weight is updated
86 | return new_update_op, grads, deltas
87 |
88 |
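`compute_updates` reads each weight change off the wrapped optimizer's Assign ops (via `extract_weight_changes`) rather than assuming delta = -lr * grad. The distinction matters for adaptive optimizers; the permuted-MNIST result file in this repository has `optadam` in its name. The minimal NumPy sketch below uses one textbook Adam step. This is an assumption: Keras's own Adam may differ in details such as where epsilon enters. It shows that on the first step each weight moves by roughly `lr`, regardless of the magnitude of its gradient.

```python
import numpy as np

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam update; returns (new_w, m, v)."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    return w - lr * m_hat / (np.sqrt(v_hat) + eps), m, v

lr = 1e-3
w_old = np.array([0.3, -1.2, 0.7])
g = np.array([2.0, -0.01, 0.5])  # gradients of very different magnitudes
w_new, m, v = adam_step(w_old, g, np.zeros(3), np.zeros(3), t=1, lr=lr)

delta = w_new - w_old  # the change an Assign-based extraction would recover
sgd_delta = -lr * g    # what a gradient-only formula would assume
```

Because `delta` differs from `sgd_delta`, the path-integral term `-grad * delta` has to use the optimizer's actual update.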
89 | def split_dataset_by_labels(X, y, task_labels, nb_classes=None, multihead=False):
90 | """Split dataset by labels.
91 |
92 | Args:
93 | X: data
94 | y: labels
95 | task_labels: list of list of labels, one for each dataset
96 | nb_classes: number of classes (used to convert to one-hot)
97 | Returns:
98 | List of (X, y) tuples representing each dataset
99 | """
100 | if nb_classes is None:
101 | nb_classes = len(np.unique(y))
102 | datasets = []
103 | for labels in task_labels:
104 | idx = np.in1d(y, labels)
105 | if multihead:
106 | label_map = np.arange(nb_classes)
107 | label_map[labels] = np.arange(len(labels))
108 | data = X[idx], np_utils.to_categorical(label_map[y[idx]], len(labels))
109 | else:
110 | data = X[idx], np_utils.to_categorical(y[idx], nb_classes)
111 | datasets.append(data)
112 | return datasets
113 |
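Here is a NumPy-only sketch of the `multihead` branch of `split_dataset_by_labels`, run on a tiny made-up label vector. `one_hot` is a hypothetical stand-in for `np_utils.to_categorical`. `np.isin` replaces `np.in1d` and gives the same result for these 1-D inputs.

```python
import numpy as np

def one_hot(labels, n_classes):
    """Hypothetical stand-in for np_utils.to_categorical."""
    return np.eye(n_classes, dtype=np.float32)[labels]

def split_by_labels(X, y, task_labels, nb_classes, multihead=False):
    """Sketch of split_dataset_by_labels using NumPy only."""
    datasets = []
    for labels in task_labels:
        idx = np.isin(y, labels)  # rows whose label belongs to this task
        if multihead:
            # Map this task's labels onto 0..len(labels)-1; other entries are unused.
            label_map = np.arange(nb_classes)
            label_map[labels] = np.arange(len(labels))
            datasets.append((X[idx], one_hot(label_map[y[idx]], len(labels))))
        else:
            datasets.append((X[idx], one_hot(y[idx], nb_classes)))
    return datasets

y = np.array([0, 3, 1, 2, 3, 0])
X = np.arange(len(y), dtype=np.float32).reshape(-1, 1)
task_labels = [[0, 1], [2, 3]]

single = split_by_labels(X, y, task_labels, nb_classes=4)
multi = split_by_labels(X, y, task_labels, nb_classes=4, multihead=True)
```

With `multihead=True`, each task gets its own one-hot target of width `len(labels)`. Otherwise all tasks share the full `nb_classes`-wide output.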
114 | def split_dataset_randomly(X, y, nb_splits, nb_classes=None):
115 | """Split dataset randomly into nb_splits equally sized parts.
116 |
117 | Args:
118 | X: data
119 | y: labels
120 | nb_splits: number of splits to return
121 | (the last len(y) % nb_splits shuffled examples are dropped)
122 | nb_classes: number of classes (used to convert to one-hot)
123 | Returns:
124 | List of (X, y) tuples representing each dataset
125 | """
126 | if nb_classes is None:
127 | nb_classes = len(np.unique(y))
128 | datasets = []
129 | idx = np.arange(len(y))
130 | np.random.shuffle(idx)
131 | split_size = len(y)//nb_splits
132 | for i in range(nb_splits):
133 | data = X[idx[split_size*i:split_size*(i+1)]], np_utils.to_categorical(y[idx[split_size*i:split_size*(i+1)]], nb_classes)
134 | datasets.append(data)
135 | return datasets
136 |
137 | def get_mnist_variations(dsetnames=['MNIST_Rotated', 'MNIST_Basic'], datashape=(-1,1,28,28), validationset_fraction=0.1, multihead=False):
138 | """ Uses skdata package to import some MNIST variations
139 |
140 | The following dataset names exist in skdata:
141 | all = ['MNIST_Basic',
142 | 'MNIST_BackgroundImages',
143 | 'MNIST_BackgroundRandom',
144 | 'MNIST_Rotated',
145 | 'MNIST_Noise1',
146 | 'MNIST_Noise2',
147 | 'MNIST_Noise3',
148 | 'MNIST_Noise4',
149 | 'MNIST_Noise5',
150 | 'MNIST_Noise6' ]
151 |
152 | args:
153 | dsetnames: the names of the data sets from above list
154 | datashape: tuple with shape of the data (default (-1,1,28,28))
155 | validationset_fraction: the fraction of data to hold out
156 | multihead: whether to generate a multihead dataset or a single head one
157 |
158 | returns:
159 | a pair (training_datasets, validation_datasets), each a list of tasks given as (X, y) tuples
160 | """
161 |
162 | from skdata import larochelle_etal_2007 as L2007
163 | def dset(name):
164 | rval = getattr(L2007, name)()
165 | return rval
166 |
167 | n_tasks = len(dsetnames)
168 | training_datasets = []
169 | validation_datasets = []
170 |
171 | for i, dsname in enumerate(dsetnames):
172 | aa = dset(dsname)
173 | task = aa.classification_task()
174 | raw_data, raw_labels = task
175 | nb_datapoints = len(raw_data)
176 | label_offset = 0
177 | if multihead:
178 | nb_classes = 10*n_tasks
179 | label_offset = i*10
180 | else:
181 | nb_classes = 10
182 | nb_training_examples = int(nb_datapoints*(1.0-validationset_fraction))
183 | data = raw_data.reshape(datashape)
184 | labels = np_utils.to_categorical(raw_labels+label_offset, nb_classes)
185 | training_datasets.append( (data[:nb_training_examples], labels[:nb_training_examples]) )
186 | validation_datasets.append( (data[nb_training_examples:], labels[nb_training_examples:]) )
187 |
188 | return training_datasets, validation_datasets
189 |
190 | def load_mnist(split='train'):
191 | (X_train, y_train), (X_test, y_test) = mnist.load_data()
192 | X_train = X_train.reshape(-1, 784)
193 | X_test = X_test.reshape(-1, 784)
194 | X_train = X_train.astype('float32')
195 | X_test = X_test.astype('float32')
196 | X_train /= 255
197 | X_test /= 255
198 |
199 | if split == 'train':
200 | X, y = X_train, y_train
201 | else:
202 | X, y = X_test, y_test
203 | nb_classes = 10
204 | y = np_utils.to_categorical(y, nb_classes)
205 | return X, y
206 |
207 | def construct_split_mnist(task_labels, split='train', multihead=False):
208 | """Split MNIST dataset by labels.
209 |
210 | Args:
211 | task_labels: list of list of labels, one for each dataset
212 | split: whether to use train or testing data
213 |
214 | Returns:
215 | List of (X, y) tuples representing each dataset
216 | """
217 | # Load MNIST data and normalize
218 | nb_classes = 10
219 | (X_train, y_train), (X_test, y_test) = mnist.load_data()
220 | X_train = X_train.reshape(-1, 784)
221 | X_test = X_test.reshape(-1, 784)
222 | X_train = X_train.astype('float32')
223 | X_test = X_test.astype('float32')
224 | X_train /= 255
225 | X_test /= 255
226 |
227 | if split == 'train':
228 | X, y = X_train, y_train
229 | else:
230 | X, y = X_test, y_test
231 |
232 | return split_dataset_by_labels(X, y, task_labels, nb_classes, multihead)
233 |
234 |
235 | def construct_randomly_split_mnist(nb_splits=10, mode='train'):
236 | """Randomly split MNIST into nb_splits tasks.
237 |
238 | Args:
239 | nb_splits: number of splits
240 | mode: whether to use train or testing data
241 |
242 | Returns:
243 | List of (X, y) tuples representing each dataset
244 | """
245 | # Load MNIST data and normalize
246 | nb_classes = 10
247 | (X_train, y_train), (X_test, y_test) = mnist.load_data()
248 | X_train = X_train.reshape(-1, 784)
249 | X_test = X_test.reshape(-1, 784)
250 | X_train = X_train.astype('float32')
251 | X_test = X_test.astype('float32')
252 | X_train /= 255
253 | X_test /= 255
254 |
255 | if mode == 'train':
256 | X, y = X_train, y_train
257 | else:
258 | X, y = X_test, y_test
259 |
260 | return split_dataset_randomly(X, y, nb_splits, nb_classes)
261 |
262 | def construct_transfer_cifar10_cifar100(nb_tasks=4, split='train'):
263 | """
264 | Returns a dataset of nb_tasks tasks: the first task is the full CIFAR10 dataset and each further task
265 | contains the next 10 classes of the CIFAR100 dataset, with labels offset by 10 so they follow the CIFAR10 labels.
266 |
267 | params:
268 | nb_tasks The total number of tasks
269 | split Whether to return training or test data
270 |
271 | returns:
272 | A tuple (all_task_labels, datasets): the label range of each task and a list of (X, y) tuples, one per task
273 | """
274 | (X_train, y_train), (X_test, y_test) = cifar10.load_data()
275 | # X_train = X_train.reshape(-1, 3, 32, 32)
276 | # X_test = X_test.reshape(-1, 32**2)
277 | X_train = X_train.astype('float32')
278 | X_test = X_test.astype('float32')
279 | no = X_train.max()
280 | X_train /= no
281 | X_test /= no
282 |
283 | if split == 'train':
284 | X, y = X_train, y_train
285 | else:
286 | X, y = X_test, y_test
287 |
288 | nb_classes = nb_tasks*10
289 | datasets = [(X,np_utils.to_categorical(y, nb_classes))]
290 |
291 | # Load CIFAR100 data and normalize
292 | (X_train, y_train), (X_test, y_test) = cifar100.load_data()
293 | X_train = X_train.astype('float32')
294 | X_test = X_test.astype('float32')
295 | m = np.max( (np.max(X_train), np.max(X_test) ) )
296 | X_train /= m
297 | X_test /= m
298 |
299 | if split == 'train':
300 | X, y = X_train, y_train
301 | else:
302 | X, y = X_test, y_test
303 |
304 | # split dataset by labels
305 | task_labels = [ range(10*i,10*(i+1)) for i in range(1,nb_tasks) ]
306 | for labels in task_labels:
307 | idx = np.in1d(y+10, labels)
308 | data = X[idx], np_utils.to_categorical(y[idx]+10, nb_classes)
309 | datasets.append(data)
310 |
311 |
312 | all_task_labels = [range(10)]
313 | all_task_labels.extend(task_labels)
314 | return all_task_labels, datasets
315 |
316 | def construct_split_cifar100(num_tasks=3, num_classes=10):
317 | """Split the CIFAR100 training set into num_tasks tasks of num_classes consecutive classes each, relabeled to 0..num_classes-1.
318 |
319 | Args:
320 | num_tasks: the number of tasks
321 | num_classes: the number of classes per task
322 |
323 | Returns:
324 | List of (X, y) tuples representing each dataset
325 | """
326 | # Load CIFAR100 data and normalize
327 | (X_train, y_train), (X_test, y_test) = cifar100.load_data()
328 | X_train = X_train.astype('float32')
329 | X_test = X_test.astype('float32')
330 | m = np.max( (np.max(X_train), np.max(X_test) ) )
331 | X_train /= m
332 | X_test /= m
333 |
334 | X, y = X_train, y_train
335 |
336 | # split dataset by labels
337 | # here we also flatten the labels of cifar100 to match num_classes via modulus operation
338 | task_labels = [ range(num_classes*i,num_classes*(i+1)) for i in range(num_tasks) ]
339 | datasets = []
340 | for labels in task_labels:
341 | idx = np.in1d(y, labels)
342 | data = X[idx], np_utils.to_categorical(y[idx]%num_classes, num_classes)
343 | datasets.append(data)
344 |
345 | return datasets
346 |
347 | def construct_permute_mnist(num_tasks=2, split='train', permute_all=False, subsample=1):
348 | """Create permuted MNIST tasks.
349 |
350 | Args:
351 | num_tasks: Number of tasks
352 | split: unused; both the training and the test split are returned
353 | permute_all: if True, the first task is permuted as well; otherwise it is standard MNIST
354 | subsample: keep only every subsample-th example
355 |
356 | Returns:
357 | List [train_datasets, test_datasets], each a list of (X, y) tuples with one entry per task
358 | """
359 | # Load MNIST data and normalize
360 | nb_classes = 10
361 | (X_train, y_train), (X_test, y_test) = mnist.load_data()
362 | X_train = X_train.reshape(-1, 784)
363 | X_test = X_test.reshape(-1, 784)
364 | X_train = X_train.astype('float32')
365 | X_test = X_test.astype('float32')
366 | X_train /= 255
367 | X_test /= 255
368 |
369 | X_train, y_train = X_train[::subsample], y_train[::subsample]
370 | X_test, y_test = X_test[::subsample], y_test[::subsample]
371 |
372 | permutations = []
373 | # Generate random permutations
374 | for i in range(num_tasks):
375 | idx = np.arange(X_train.shape[1],dtype=int)
376 | if permute_all or i>0:
377 | np.random.shuffle(idx)
378 | permutations.append(idx)
379 |
380 | both_datasets = []
381 | for (X, y) in ((X_train, y_train), (X_test, y_test)):
382 | datasets = []
383 | for perm in permutations:
384 | data = X[:,perm], np_utils.to_categorical(y, nb_classes)
385 | datasets.append(data)
386 | both_datasets.append(datasets)
387 | return both_datasets
388 |
389 |
390 | def construct_split_cifar10(task_labels, split='train'):
391 |     """Split CIFAR10 dataset by labels.
392 |
393 |     Args:
394 |         task_labels: list of lists of labels, one for each dataset
395 |         split: 'train' selects the training data; any other value selects the test data
396 |
397 |     Returns:
398 |         List of (X, y) tuples representing each dataset
399 |     """
400 |     # Load CIFAR10 data and normalize
401 |     nb_classes = 10
402 |     (X_train, y_train), (X_test, y_test) = cifar10.load_data()
403 |     # X_train = X_train.reshape(-1, 3, 32, 32)
404 |     # X_test = X_test.reshape(-1, 32**2)
405 |     X_train = X_train.astype('float32')
406 |     X_test = X_test.astype('float32')
407 |     no = X_train.max()
408 |     X_train /= no
409 |     X_test /= no
410 |
411 |     if split == 'train':
412 |         X, y = X_train, y_train
413 |     else:
414 |         X, y = X_test, y_test
415 |
416 |     return split_dataset_by_labels(X, y, task_labels, nb_classes)
417 |
418 |
419 | def online_benchmark(datasets, model, loss, optimizer, epochs_per_dataset=1,
420 |                      ages=1, batch_size=256, callbacks=None, **kwargs):
421 |     """Benchmark online learning.
422 |
423 |     Sequentially optimize a set of tasks, and compute
424 |     the predictions for each task over time.
425 |
426 |     Args:
427 |         datasets: list of (inputs, labels) tuples
428 |         model: Keras model
429 |         loss: string or function
430 |         optimizer: string or Keras Optimizer object
431 |         epochs_per_dataset: number of passes through an individual dataset
432 |         ages: number of passes over datasets
433 |         batch_size: batch size
434 |         callbacks: list of functions to call with the model at each iteration
435 |
436 |     Returns:
437 |         dict with 'labels' and 'predictions' (per task; one entry appended after training on each task),
438 |         'callback_outputs' (None if no callbacks) and 'optimization_data' (Adam m, v per weight tensor)
439 |     """
440 |
441 |     # Build the model
442 |     model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
443 |
444 |     ndataset = len(datasets)
445 |     predictions = [[] for i in range(ndataset)]
446 |     labels = [[] for i in range(ndataset)]
447 |     if callbacks is not None:
448 |         callback_outputs = [[] for i in range(len(callbacks))]
449 |         for cidx, callback in enumerate(callbacks):
450 |             callback_outputs[cidx].append(callback(model))
451 |
452 |     optimization_data = [[] for i in range(len(model.get_weights()))]
453 |     for age in range(ages):
454 |         for didx, dataset in enumerate(datasets):
455 |
456 |             model.fit(*dataset, batch_size=batch_size, nb_epoch=epochs_per_dataset, verbose=1)
457 |             # If the optimizer is an Adam instance, log its first- and second-moment
458 |             # estimates (m, v) for every weight tensor after training on this task
459 |             if isinstance(optimizer, Adam):
460 |                 weights = model.get_weights()
461 |                 opt_vars = optimizer.weights[1:]
462 |                 ms = opt_vars[:len(opt_vars)//2]
463 |                 vs = opt_vars[len(opt_vars)//2:]
464 |                 sess = K.get_session()
465 |                 stuff = sess.run([ms, vs])
466 |                 for i in range(len(model.get_weights())):
467 |                     optimization_data[i].append([stuff[0][i], stuff[1][i]])
468 |                 #optimization_data.append(stuff)
469 |
470 |             # Evaluate on all datasets
471 |             for eval_didx, eval_dataset in enumerate(datasets):
472 |                 # Evaluate model on dataset
473 |                 preds = model.predict(eval_dataset[0])
474 |                 predictions[eval_didx].append(preds)
475 |                 # Convert from 1-hot back to categorical
476 |                 labels[eval_didx].append(np.argmax(eval_dataset[1], 1))
477 |                 print(model.evaluate(*eval_dataset))
478 |             print("")
479 |             if callbacks is not None:
480 |                 for cidx, callback in enumerate(callbacks):
481 |                     callback_outputs[cidx].append(callback(model))
482 |
483 |     if callbacks is None:
484 |         callback_outputs = None
485 |         # TODO(ben): returning None here might break downstream code
486 |
487 |
488 |     return dict(labels=labels, predictions=predictions,
489 |                 callback_outputs=callback_outputs,
490 |                 optimization_data=optimization_data)
491 |
492 |
493 | def save_zipped_pickle(obj, filename, protocol=-1):
494 |     with gzip.open(filename, 'wb') as f:
495 |         pickle.dump(obj, f, protocol)
496 |
497 |
498 | def load_zipped_pickle(filename):
499 |     try:
500 |         with gzip.open(filename, 'rb') as f:
501 |             loaded_object = pickle.load(f)
502 |             return loaded_object
503 |     except IOError:
504 |         print("Warning: IO error, returning empty dict.")
505 |         return dict()
506 |
507 |
508 | def split_dataset(ds, split_sizes, permute_data=True):
509 |     """Helper function to split a single dataset into train, valid and test sets.
510 |
511 |     Args:
512 |         ds: the dataset, a tuple of (data, labels); shuffled first if permute_data is True
513 |         split_sizes: a list of fractional split sizes specifying how to divide up the dataset
514 |
515 |     Returns:
516 |         a list of datasets with the respective split ratios
517 |     """
518 |     raw_data, raw_labels = ds
519 |     if permute_data:
520 |         idx = np.arange(len(raw_data))  # a range object cannot be shuffled in place
521 |         np.random.shuffle(idx)
522 |         data = raw_data[idx]
523 |         labels = raw_labels[idx]
524 |     else:
525 |         data = raw_data
526 |         labels = raw_labels
527 |     nelems = len(labels)
528 |     nbegin = 0
529 |     splits = []
530 |     for split in split_sizes:
531 |         nend = nbegin + int(split*nelems)
532 |         splits.append((data[nbegin:nend], labels[nbegin:nend]))
533 |         nbegin = nend
534 |     return splits
535 |
536 | def mk_training_validation_splits(full_datasets, split_fractions=(0.8, 0.1, 0.1)):
537 |     """Split each task in a list of tasks into training, validation and test sets.
538 |
539 |     Args:
540 |         full_datasets: the full dataset as a list of tasks, each of the form (data, labels)
541 |         split_fractions: a list of split fractions, which should sum to 1.0
542 |
543 |     Returns:
544 |         a list of length len(split_fractions), each entry containing a list of tasks
545 |     """
546 |     results = [[] for i in range(len(split_fractions))]
547 |     for ds in full_datasets:
548 |         splits = split_dataset(ds, split_fractions)
549 |         for i, sp in enumerate(splits):
550 |             results[i].append(sp)
551 |     return results
552 |
553 | def mk_joined_dataset(full_datasets, split_fractions=(0.9, 0.1)):
554 |     """Join the datasets of multiple tasks into a single dataset as a baseline control and return its training and validation splits."""
555 |     l = len(full_datasets)
556 |     data = np.concatenate([full_datasets[i][0] for i in range(l)], 0)
557 |     labels = np.concatenate([full_datasets[i][1] for i in range(l)], 0)
558 |     return split_dataset((data, labels), split_fractions)
559 |
560 |
561 |
562 | def main():
563 |     """Test code for the split CIFAR100 tasks.
564 |
565 |     Plots the first image of the first two tasks."""
566 |     import matplotlib.pyplot as plt
567 |     ds = construct_split_cifar100()
568 |     plt.subplot(121)
569 |     plt.imshow(ds[0][0][0].transpose((1, 2, 0)), interpolation='nearest')  # assumes channels-first images
570 |     plt.subplot(122)
571 |     plt.imshow(ds[1][0][0].transpose((1, 2, 0)), interpolation='nearest')
572 |     plt.show()
573 |
574 | if __name__ == "__main__":
575 |     main()
576 |
577 |
--------------------------------------------------------------------------------
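The permutation step in `construct_permute_mnist` can be illustrated without downloading MNIST. The sketch below is a hypothetical, NumPy-only re-implementation on synthetic data; the function name `make_permuted_tasks` and the seeded `np.random.default_rng` generator are assumptions for illustration, not part of `pathint`. It shows that each task reorders the pixel columns with one fixed permutation, and that task 0 is left unpermuted unless `permute_all` is set.

```python
import numpy as np


def make_permuted_tasks(X, num_tasks, permute_all=False, seed=0):
    """Hypothetical helper mirroring the permutation step of construct_permute_mnist.

    X: array of shape (n_examples, n_pixels) holding flattened images.
    Returns a list of (X_task, perm) pairs, one per task.
    """
    rng = np.random.default_rng(seed)  # seeded so the tasks are reproducible
    n_pixels = X.shape[1]
    tasks = []
    for i in range(num_tasks):
        perm = np.arange(n_pixels)
        if permute_all or i > 0:
            rng.shuffle(perm)  # task 0 stays unpermuted unless permute_all is set
        # Reorder the pixel columns; every example in a task uses the same permutation
        tasks.append((X[:, perm], perm))
    return tasks


# Synthetic stand-in for flattened MNIST: 5 "images" of 16 pixels each
X = np.arange(5 * 16, dtype=np.float32).reshape(5, 16) / 80.0
tasks = make_permuted_tasks(X, num_tasks=3)
print([np.array_equal(Xt, X) for Xt, _ in tasks])
```

Because a permutation only reorders pixels, every task contains the same information; only the input layout changes. In the original function, the same `permutations` list is applied to both the training and the test split, which is what makes the test set of each task meaningful.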
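`split_dataset` sizes each split with `int(split*nelems)`, so fractions that do not divide the dataset evenly silently drop the remainder. The hypothetical NumPy-only sketch below reproduces that logic on synthetic data so the split sizes are visible; the name `split_dataset_sketch` and the optional `np.random.Generator` argument are assumptions, not `pathint` API.

```python
import numpy as np


def split_dataset_sketch(ds, split_sizes, rng=None):
    """Hypothetical stand-in for split_dataset: optionally shuffle, then cut into fractional splits."""
    data, labels = ds
    if rng is not None:
        # Shuffle with an index array (a Python 3 range cannot be shuffled in place)
        idx = rng.permutation(len(data))
        data, labels = data[idx], labels[idx]
    n = len(labels)
    splits, begin = [], 0
    for frac in split_sizes:
        end = begin + int(frac * n)  # int() truncates, so a remainder may be left out
        splits.append((data[begin:end], labels[begin:end]))
        begin = end
    return splits


X = np.arange(20).reshape(20, 1)
y = np.arange(20) % 2
train, valid, test = split_dataset_sketch((X, y), (0.8, 0.1, 0.1), rng=np.random.default_rng(0))
print(len(train[0]), len(valid[0]), len(test[0]))  # → 16 2 2
```

With 7 examples and `(0.5, 0.5)`, each split gets `int(3.5) == 3` examples, so one example is left out of both splits.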