├── .gitignore ├── Create Animations.ipynb ├── Datasets ├── Bars.ipynb ├── Corners.ipynb ├── MNIST+Shape.ipynb ├── Multi-MNIST.ipynb ├── Shapes.ipynb ├── Simple Superposition.ipynb └── plot_tools.py ├── Get Results.ipynb ├── Networks ├── best_bars_dae.h5 ├── best_bars_dae_train_multi.h5 ├── best_corners_dae.h5 ├── best_corners_dae_train_multi.h5 ├── best_mnist_shape_dae.h5 ├── best_mnist_shape_dae_train_multi.h5 ├── best_multi_mnist_dae.h5 ├── best_multi_mnist_dae_train_multi.h5 ├── best_shapes_dae.h5 ├── best_shapes_dae_train_multi.h5 └── best_simple_superpos_dae.h5 ├── Plots.ipynb ├── README.md ├── Run Random Search.ipynb ├── animations ├── RC.gif ├── bars.gif ├── bars_train_multi.gif ├── corners.gif ├── corners_train_multi.gif ├── mnist_shape.gif ├── mnist_shape_train_multi.gif ├── multi_mnist.gif ├── multi_mnist_train_multi.gif ├── shapes.gif ├── shapes_train_multi.gif └── simple_superpos.gif ├── dae.py ├── dump.zip ├── extra imgs ├── DAE.png ├── FTW.png ├── NNFTW.png ├── Tiefighter.png ├── circles.png ├── interlocked.png ├── interrupted_lines.png └── split_lines.png ├── run_best_nets.py ├── run_evaluation.py └── run_random_search.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by http://www.gitignore.io 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | bin/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # Installer logs 29 | pip-log.txt 30 | pip-delete-this-directory.txt 31 | 32 | # Unit test / coverage reports 33 | htmlcov/ 34 | .tox/ 35 | .coverage 36 | .cache 37 | nosetests.xml 38 | coverage.xml 39 | 40 | # Translations 41 | *.mo 42 | 43 | # Mr Developer 44 | .mr.developer.cfg 45 | .project 46 | .pydevproject 47 | 48 | # Rope 49 | .ropeproject 50 | 51 | # Django stuff: 52 | *.log 53 | *.pot 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # IPython Notebooks 59 | .ipynb_checkpoints 60 | 61 | 62 | ### PyCharm ### 63 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm 64 | 65 | ## Directory-based project format 66 | .idea/ 67 | # if you remove the above rule, at least ignore user-specific stuff: 68 | # .idea/workspace.xml 69 | # .idea/tasks.xml 70 | # and these sensitive or high-churn files: 71 | # .idea/dataSources.ids 72 | # .idea/dataSources.xml 73 | # .idea/sqlDataSources.xml 74 | # .idea/dynamic.xml 75 | 76 | ## File-based project format 77 | *.ipr 78 | *.iws 79 | *.iml 80 | 81 | ## Additional for IntelliJ 82 | out/ 83 | 84 | # generated by mpeltonen/sbt-idea plugin 85 | .idea_modules/ 86 | 87 | # generated by JIRA plugin 88 | atlassian-ide-plugin.xml 89 | 90 | # generated by Crashlytics plugin (for Android Studio and Intellij) 91 | com_crashlytics_export_strings.xml 92 | 93 | # GEdit temporary files 94 | *~ 95 | 96 | -------------------------------------------------------------------------------- /Create Animations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import matplotlib.pyplot as plt\n", 12 | "import h5py\n", 13 | "import pickle\n", 14 | "from matplotlib.colors import hsv_to_rgb\n", 15 | "import numpy as np\n", 16 | "from dae import ex\n", 17 | "import os.path" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": { 24 | "collapsed": false, 25 | "scrolled": true 26 | }, 27 | "outputs": [ 28 | { 29 | "name": "stderr", 30 | "output_type": "stream", 31 | "text": [ 32 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 33 | "INFO - binding_dae - Running command 'evaluate'\n", 34 | "WARNING - binding_dae - No observers have been added to this run\n", 35 | "INFO - binding_dae - Started\n" 36 | ] 37 | }, 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "Average Score: 0.6794\n", 43 | "Average Confidence: 0.8952\n", 44 | "wrote the results to Results/multi_mnist_30_3.pickle" 45 | ] 46 | }, 47 | { 48 | "name": "stderr", 49 | "output_type": "stream", 50 | "text": [ 51 | "INFO - binding_dae - Result: 0.6794270494870049\n", 52 | "INFO - binding_dae - Completed after 0:03:08\n", 53 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 54 | " max_open_warning, RuntimeWarning)\n", 55 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 56 | "INFO - binding_dae - Running command 'evaluate'\n", 57 | "WARNING - binding_dae - No observers have been added to this run\n", 58 | "INFO - binding_dae - Started\n" 59 | ] 60 | }, 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "\n", 66 | "Average Score: 0.8916" 67 | ] 68 | }, 69 | { 70 | "name": "stderr", 71 | "output_type": "stream", 72 | "text": [ 73 | "INFO - binding_dae - Result: 0.8916011324539921\n", 74 | "INFO - binding_dae - Completed after 0:00:04\n", 75 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 76 | " max_open_warning, RuntimeWarning)\n", 77 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 78 | "INFO - binding_dae - Running command 'evaluate'\n", 79 | "WARNING - binding_dae - No observers have been added to this run\n", 80 | "INFO - binding_dae - Started\n" 81 | ] 82 | }, 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "\n", 88 | "Average Confidence: 0.9474\n", 89 | "wrote the results to Results/simple_superpos_30_2.pickle\n", 90 | "Average Score: 0.9537\n", 91 | "Average Confidence: 0.9470\n", 92 | "wrote the results to Results/shapes_30_3.pickle" 93 | ] 94 | }, 95 | { 96 | "name": "stderr", 97 | "output_type": "stream", 98 | "text": [ 99 | "INFO - binding_dae - Result: 0.9537316145962019\n", 100 | "INFO - binding_dae - Completed after 0:01:00\n", 101 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 102 | " max_open_warning, RuntimeWarning)\n", 103 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 104 | "INFO - binding_dae - Running command 'evaluate'\n", 105 | "WARNING - binding_dae - No observers have been added to this run\n", 106 | "INFO - binding_dae - Started\n" 107 | ] 108 | }, 109 | { 110 | "name": "stdout", 111 | "output_type": "stream", 112 | "text": [ 113 | "\n", 114 | "Average Score: 0.5882\n", 115 | "Average Confidence: 0.9480\n", 116 | "wrote the results to Results/mnist_shape_30_2.pickle" 117 | ] 118 | }, 119 | { 120 | "name": "stderr", 121 | "output_type": "stream", 122 | "text": [ 123 | "INFO - binding_dae - Result: 0.588232954116256\n", 124 | "INFO - binding_dae - Completed after 0:00:50\n", 125 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 126 | " max_open_warning, RuntimeWarning)\n", 127 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 128 | "INFO - binding_dae - Running command 'evaluate'\n", 129 | "WARNING - binding_dae - No observers have been added to this run\n", 130 | "INFO - binding_dae - Started\n" 131 | ] 132 | }, 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "\n", 138 | "Average Score: 0.9853\n", 139 | "Average Confidence: 0.9785\n", 140 | "wrote the results to Results/bars_30_12.pickle" 141 | ] 142 | }, 143 | { 144 | "name": "stderr", 145 | "output_type": "stream", 146 | "text": [ 147 | "INFO - binding_dae - Result: 0.9853389285150709\n", 148 | "INFO - binding_dae - Completed after 0:02:33\n", 149 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 150 | " max_open_warning, RuntimeWarning)\n", 151 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 152 | "INFO - binding_dae - Running command 'evaluate'\n", 153 | "WARNING - binding_dae - No observers have been added to this run\n", 154 | "INFO - binding_dae - Started\n" 155 | ] 156 | }, 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "\n", 162 | "Average Score: 0.8972\n", 163 | "Average Confidence: 0.9830\n", 164 | "wrote the results to Results/corners_30_5.pickle" 165 | ] 166 | }, 167 | { 168 | "name": "stderr", 169 | "output_type": "stream", 170 | "text": [ 171 | "INFO - binding_dae - Result: 0.8971806767597252\n", 172 | "INFO - binding_dae - Completed after 0:01:31\n" 173 | ] 174 | }, 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "\n" 180 | ] 181 | }, 182 | { 183 | "name": "stderr", 184 | "output_type": "stream", 185 | "text": [ 186 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 187 | " max_open_warning, RuntimeWarning)\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "# run a longer (30 iterations) evaluation for visualization\n", 193 | "datasets = {\n", 194 | " 'bars': 12, \n", 195 | " 'corners': 5,\n", 196 | " 'shapes': 3,\n", 197 | " 'multi_mnist': 3,\n", 198 | " 'mnist_shape': 2,\n", 199 | " 'simple_superpos':2\n", 200 | "}\n", 201 | "nr_iters = 30\n", 202 | "nrows = 10\n", 203 | "ncols = 12\n", 204 | "\n", 205 | "\n", 206 | "\n", 207 | "for ds, k in datasets.items():\n", 208 | " results_filename = 'Results/{}_{}_{}.pickle'.format(ds, nr_iters, k)\n", 209 | " animation_dir = 'animations/{}'.format(ds)\n", 210 | " if not os.path.exists(animation_dir):\n", 211 | " os.makedirs(animation_dir)\n", 212 | " \n", 213 | " ex.run_command('evaluate', config_updates={\n", 214 | " 'dataset.name': ds,\n", 215 | " 'net_filename': 'Networks/best_{}_dae.h5'.format(ds),\n", 216 | " 'em.k': k,\n", 217 | " 'em.nr_iters': 30,\n", 218 | " 'em.dump_results': results_filename,\n", 219 | " 'em.nr_samples': nrows * ncols,\n", 220 | " 'seed': 42}) \n", 221 | " \n", 222 | " with h5py.File('/home/greff/Datasets/{}.h5'.format(ds)) as f:\n", 223 | " true_groups = f['test']['groups'][:]\n", 224 | " with open(results_filename, 'rb') as f:\n", 225 | " scores, likelihoods, results = pickle.load(f)\n", 226 | " \n", 227 | " if results.shape[-1] != 3:\n", 228 | " nr_colors = results.shape[-1]\n", 229 | " hsv_colors = np.ones((nr_colors, 3))\n", 230 | " hsv_colors[:, 0] = (np.linspace(0, 1, nr_colors, endpoint=False) + 2/3) % 1.0\n", 231 | " color_conv = hsv_to_rgb(hsv_colors)\n", 232 | " results = results.reshape(-1, nr_colors).dot(color_conv).reshape(results.shape[:-1] + (3,))\n", 233 | " \n", 234 | " for it in range(nr_iters+1):\n", 235 | " fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols, nrows))\n", 236 | " for r in range(nrows):\n", 237 | " for c in range(ncols):\n", 238 | " axes[r, c].imshow(results[ncols*r + c, it, 0, :, :, 0, :], interpolation='nearest')\n", 239 | " axes[r, c].set_xticks([])\n", 240 | " axes[r, c].set_yticks([])\n", 241 | " plt.subplots_adjust(wspace=0, hspace=0)\n", 242 | " fig.savefig(animation_dir + '/img_{:02d}.png'.format(it), bbox_inches='tight', pad_inches=0, dpi=72.26)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 8, 248 | "metadata": { 249 | "collapsed": false, 250 | "scrolled": true 251 | }, 252 | "outputs": [ 253 | { 254 | "name": "stderr", 255 | "output_type": "stream", 256 | "text": [ 257 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 258 | "INFO - binding_dae - Running command 'evaluate'\n", 259 | "WARNING - binding_dae - No observers have been added to this run\n", 260 | "INFO - binding_dae - Started\n", 261 | "INFO - binding_dae - Result: 0.29725815533634\n", 262 | "INFO - binding_dae - Completed after 0:00:07\n", 263 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 264 | " max_open_warning, RuntimeWarning)\n", 265 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 266 | "INFO - binding_dae - Running command 'evaluate'\n", 267 | "WARNING - binding_dae - No observers have been added to this run\n", 268 | "INFO - binding_dae - Started\n" 269 | ] 270 | }, 271 | { 272 | "name": "stdout", 273 | "output_type": "stream", 274 | "text": [ 275 | "Average Score: 0.2973\n", 276 | "Average Confidence: 1.0000\n", 277 | "wrote the results to Results/mnist_shape_30_2_train_multi.pickle\n", 278 | "Average Score: 0.7046\n", 279 | "Average Confidence: 1.0000\n", 280 | "wrote the results to Results/corners_30_5_train_multi.pickle" 281 | ] 282 | }, 283 | { 284 | "name": "stderr", 285 | "output_type": "stream", 286 | "text": [ 287 | "INFO - binding_dae - Result: 0.7045697660668432\n", 288 | "INFO - binding_dae - Completed after 0:00:10\n", 289 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 290 | " max_open_warning, RuntimeWarning)\n", 291 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 292 | "INFO - binding_dae - Running command 'evaluate'\n", 293 | "WARNING - binding_dae - No observers have been added to this run\n", 294 | "INFO - binding_dae - Started\n" 295 | ] 296 | }, 297 | { 298 | "name": "stdout", 299 | "output_type": "stream", 300 | "text": [ 301 | "\n", 302 | "Average Score: 0.8507\n", 303 | "Average Confidence: 1.0000\n", 304 | "wrote the results to Results/bars_30_12_train_multi.pickle" 305 | ] 306 | }, 307 | { 308 | "name": "stderr", 309 | "output_type": "stream", 310 | "text": [ 311 | "INFO - binding_dae - Result: 0.8507084656354675\n", 312 | "INFO - binding_dae - Completed after 0:00:18\n", 313 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 314 | " max_open_warning, RuntimeWarning)\n", 315 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 316 | "INFO - binding_dae - Running command 'evaluate'\n", 317 | "WARNING - binding_dae - No observers have been added to this run\n", 318 | "INFO - binding_dae - Started\n" 319 | ] 320 | }, 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "\n", 326 | "Average Score: 0.6322\n", 327 | "Average Confidence: 1.0000\n", 328 | "wrote the results to Results/multi_mnist_30_3_train_multi.pickle" 329 | ] 330 | }, 331 | { 332 | "name": "stderr", 333 | "output_type": "stream", 334 | "text": [ 335 | "INFO - binding_dae - Result: 0.6322366970257197\n", 336 | "INFO - binding_dae - Completed after 0:00:13\n", 337 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 338 | " max_open_warning, RuntimeWarning)\n", 339 | "WARNING - root - Changed type of config entry \"em.dump_results\" from NoneType to str\n", 340 | "INFO - binding_dae - Running command 'evaluate'\n", 341 | "WARNING - binding_dae - No observers have been added to this run\n", 342 | "INFO - binding_dae - Started\n" 343 | ] 344 | }, 345 | { 346 | "name": "stdout", 347 | "output_type": "stream", 348 | "text": [ 349 | "\n", 350 | "Average Score: 0.7558\n", 351 | "Average Confidence: 1.0000\n", 352 | "wrote the results to Results/shapes_30_3_train_multi.pickle" 353 | ] 354 | }, 355 | { 356 | "name": "stderr", 357 | "output_type": "stream", 358 | "text": [ 359 | "INFO - binding_dae - Result: 0.7558324914261669\n", 360 | "INFO - binding_dae - Completed after 0:00:07\n" 361 | ] 362 | }, 363 | { 364 | "name": "stdout", 365 | "output_type": "stream", 366 | "text": [ 367 | "\n" 368 | ] 369 | }, 370 | { 371 | "name": "stderr", 372 | "output_type": "stream", 373 | "text": [ 374 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib/pyplot.py:516: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).\n", 375 | " max_open_warning, RuntimeWarning)\n" 376 | ] 377 | } 378 | ], 379 | "source": [ 380 | "# run a longer (30 iterations) evaluation for visualization\n", 381 | "datasets = {\n", 382 | " 'bars': 12, \n", 383 | " 'corners': 5,\n", 384 | " 'shapes': 3,\n", 385 | " 'multi_mnist': 3,\n", 386 | " 'mnist_shape': 2,\n", 387 | "# 'simple_superpos':2\n", 388 | "}\n", 389 | "nr_iters = 30\n", 390 | "nrows = 10\n", 391 | "ncols = 12\n", 392 | "\n", 393 | "\n", 394 | "\n", 395 | "for ds, k in datasets.items():\n", 396 | " results_filename = 'Results/{}_{}_{}_train_multi.pickle'.format(ds, nr_iters, k)\n", 397 | " animation_dir = 'animations/{}_train_multi'.format(ds)\n", 398 | " if not os.path.exists(animation_dir):\n", 399 | " os.makedirs(animation_dir)\n", 400 | " \n", 401 | " ex.run_command('evaluate', config_updates={\n", 402 | " 'dataset.name': ds,\n", 403 | " 'net_filename': 'Networks/best_{}_dae_train_multi.h5'.format(ds),\n", 404 | " 'em.k': k,\n", 405 | " 'em.e_step': 'max',\n", 406 | " 'em.nr_iters': nr_iters,\n", 407 | " 'em.dump_results': results_filename,\n", 408 | " 'em.nr_samples': nrows * ncols,\n", 409 | " 'seed': 42})\n", 410 | " \n", 411 | " with h5py.File('/home/greff/Datasets/{}.h5'.format(ds)) as f:\n", 412 | " input_image = f['test']['default'][:]\n", 413 | " true_groups = f['test']['groups'][:]\n", 414 | " with open(results_filename, 'rb') as f:\n", 415 | " scores, likelihoods, results = pickle.load(f)\n", 416 | " \n", 417 | " if results.shape[-1] != 3:\n", 418 | " nr_colors = results.shape[-1]\n", 419 | " hsv_colors = np.ones((nr_colors, 3))\n", 420 | " hsv_colors[:, 0] = (np.linspace(0, 1, nr_colors, endpoint=False) + 2/3) % 1.0\n", 421 | " color_conv = hsv_to_rgb(hsv_colors)\n", 422 | " results = results.reshape(-1, nr_colors).dot(color_conv).reshape(results.shape[:-1] + (3,))\n", 423 | " \n", 424 | " for it in range(nr_iters+1):\n", 425 | " fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols, nrows))\n", 426 | " for r in range(nrows):\n", 427 | " for c in range(ncols):\n", 428 | " groups = results[ncols*r + c, it, 0, :, :, 0, :]\n", 429 | " in_img = input_image[0, ncols*r + c] * 0.7 + 0.3\n", 430 | " \n", 431 | " axes[r, c].imshow(groups * in_img, interpolation='nearest')\n", 432 | " axes[r, c].set_xticks([])\n", 433 | " axes[r, c].set_yticks([])\n", 434 | " plt.subplots_adjust(wspace=0, hspace=0)\n", 435 | " fig.savefig(animation_dir + '/img_{:02d}.png'.format(it), bbox_inches='tight', pad_inches=0, dpi=72.26)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 1, 441 | "metadata": { 442 | "collapsed": true 443 | }, 444 | "outputs": [], 445 | "source": [ 446 | "import os.path" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 2, 452 | "metadata": { 453 | "collapsed": true 454 | }, 455 | "outputs": [], 456 | "source": [ 457 | "subdirs = [f for f in os.listdir('animations') if os.path.isdir(os.path.join('animations', f))]" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 3, 463 | "metadata": { 464 | "collapsed": false 465 | }, 466 | "outputs": [], 467 | "source": [ 468 | "from subprocess import call\n", 469 | "\n", 470 | "for d in subdirs:\n", 471 | " call(['convert', '-delay', '20', '-loop', '0', 'animations/{}/*.png'.format(d), 'animations/{}.gif'.format(d)])\n" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": null, 477 | "metadata": { 478 | "collapsed": true 479 | }, 480 | "outputs": [], 481 | "source": [] 482 | } 483 | ], 484 | "metadata": { 485 | "kernelspec": { 486 | "display_name": "Python 3", 487 | "language": "python", 488 | "name": "python3" 489 | }, 490 | "language_info": { 491 | "codemirror_mode": { 492 | "name": "ipython", 493 | "version": 3 494 | }, 495 | "file_extension": ".py", 496 | "mimetype": "text/x-python", 497 | "name": "python", 498 | "nbconvert_exporter": "python", 499 | "pygments_lexer": "ipython3", 500 | "version": "3.4.3" 501 | } 502 | }, 503 | "nbformat": 4, 504 | "nbformat_minor": 0 505 | } 506 | -------------------------------------------------------------------------------- /Datasets/Bars.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stderr", 12 | "output_type": "stream", 13 | "text": [ 14 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib-1.5.0+783.g23bc09d-py3.4-linux-x86_64.egg/matplotlib/__init__.py:877: UserWarning: axes.color_cycle is deprecated and replaced with axes.prop_cycle; please use the latter.\n", 15 | " warnings.warn(self.msg_depr % (key, alt_key))\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "from plot_tools import plot_groups, plot_input_image\n", 23 | "%matplotlib inline\n", 24 | "np.random.seed(516371)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "# Adapted Bars Problem\n", 32 | "\n", 33 | "Binary images with a fixed number of randomly placed horizontal and vertical bars.\n", 34 | " \n", 35 | "With width=height=20 and nr_horizontal_bars=nr_vertical_bars=6 this mimics the setup from [2]." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": { 42 | "collapsed": true 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "def generate_bars(width, height, nr_horizontal_bars, nr_vertical_bars):\n", 47 | " img = np.zeros((height, width), dtype=np.float)\n", 48 | " grp = np.zeros_like(img)\n", 49 | " \n", 50 | " idx_vert = np.random.choice(np.arange(width), replace=False, size=nr_vertical_bars)\n", 51 | " img[:, idx_vert] = 1.\n", 52 | " k = 1\n", 53 | " for i in idx_vert:\n", 54 | " grp[:, i] = k\n", 55 | " k += 1\n", 56 | " \n", 57 | " idx_horiz = np.random.choice(np.arange(height), replace=False, size=nr_horizontal_bars)\n", 58 | " img[idx_horiz, :] += 1.\n", 59 | " for i in idx_horiz:\n", 60 | " grp[i, :] = k\n", 61 | " k += 1\n", 62 | " \n", 63 | " grp[img > 1] = 0\n", 64 | " img = img != 0\n", 65 | " \n", 66 | " return img, grp\n", 67 | " " 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": { 74 | "collapsed": false 75 | }, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA4oAAAElCAYAAACiWBzqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAHOpJREFUeJzt3V2sHdWZJuDlQIKx5CbENlhEGRlDyxZWuIj5EU0E3LTj\nxpYARQg4VgPiPzLYMWkJSJCFGMckF2CBiIBAEJOoD0JRGpBwB2ghQWSa+ACRhigMrTE2mp9WADOE\nWOOmZ3py5mJC9yx5rTKrzjpVtfd5nsuqXeurVa69qz6fvd+aNz09HQAAAOATn+l7BwAAABgWjSIA\nAAARjSIAAAARjSIAAAARjSIAAAARjSIAAACRI5tWzps3r/jZGe+99/vk8uOO+5PibYaoaR4hpOcy\n1Lk37df09PS8Lvbh/fcPFJ1jbY7/KDrcPFNK596mRs36S5Ys7OQca/M5xnjo6nPMOTZ3DeEcq3Xd\na3tN6OLaMy7X9pxRvB+rrYt75b5rNOmzftM55i+KAAAARDSKAAAARDSKAAAARDSKAAAARDSKAAAA\nRBpTT2G21U7erD3eKOl77qX1p6cFRQLM1Kh99g+1BnAojSIw1mo+smcI0doldbp6nEzN4zLUxwk1\naXO+tBlvXAzhfVRSX5MCzFW+egoAAEBEowgAAEBEowgAAEBEowgAAEBEowgAAEBE6im9qpmI2Ga8\noWqTslf7WM52fQBmbohJsbXrjPv1RbIuQ+UvigAAAEQ0igAAAEQ0igAAAEQ0igAAAESE2dCr2j/g\nnss/CO977qX1p6enZ2lPYm2OS1fbtFGzThf73Of+dnWOAYy6vu8hmozatWoU6+f4iyIAAAARjSIA\nAAARjSIAAAARjSIAAAARjSIAAAARjSIAAAARjSIAAAARz1GkV++99/ui1x/uOTOl4w1Vm+fp1D6W\ns12/K7n9app/zW1K1TzHu3q/1DwuXRxjAGJ9f772fX3t6ho+1Po5/qIIAABARKMIAABARKMIAABA\nRKMIAABARKMIAABARKMIAABARKMIAABARKMIAABARKMIAABARKMIAABA5MjaAx533J90ss1Qlc5l\nnOYOAHSr7/uILur3Pce5YMjHeC6cY33Xz/EXRQAAACIaRQAAACIaRQAAACIaRQAAACIaRQAAACIa\nRQAAACIaRQAAACLVn6P43nu/Ty5vej5IbpshOtxzTlJzGerch/rMFgDg06l1H9H2nqC0fps6o3Sf\n2MYQ7sf6PsZd3Cv3XaNJ3/Vz/EURAACAiEYRAACAiEYRAACAiEYRAACASPUwGwAAutF3EEoX9fue\nI8xV/qIIAABARKMIAABARKMIAABARKMIAABARKMIAABApHrqaZtkqnFKsyqdyzjNHWAuqv05Ppev\nC33Pve/6AEPi8RgAACPqvfd+X2Wctk1yaf02dWrNcaj8BwVD5aunAAAARDSKAAAARDSKAAAARDSK\nAAAARDSKAAAARKSe0iux8vX0PffS+tPT07O0J7GuHtnT1fGvWaeLfe5zf7s6xwBgHFVvFHMRxk0X\n+FGKPT7cjUpqLkOde9+NBQAAMEy+egoAAEBEowgAAEBEowgAAEBEowgAAEBE6ikAwIjqO5hu1NKT\ngU9PowgAM9Am7bvNeOOiiyTwmsdekwLMVb56CgAAQESjCAAAQESjCAAAQESjCAAAQKR6mE2bH32P\n0w/FS+cyTnMHmItqf47P5etC33Pvuz7AkEg9pVelCXeHu4iPS1pgm5uV2sdytut3pU0iZc1tStU8\nx7t6v9Q8Ll0cYwDg8Hz1FAAAgIhGEQAAgIhGEQAAgIhGEQAAgIhGEQAAgIjUUwAAmMOG/GiYLvat\n7/n3XT+neqPYZ6x8F9rEzQ917kM9KQFGSZvrXpvxxsUQHjNTUn/o18ohHrPadebyewL65KunAAAA\nRDSKAAAARDSKAAAARDSKAAAARKqH2bT5Qe44/Yi3dC7jNPc2as9/Lh/PvudeWn96enqW9gQAgJny\neAwAAJjD+k6W7TsNedTSg2vXz/HVUwAAACIaRQAAACIaRQAAACIaRQAAACLCbABgBqQ319P33Puu\n30bf+9xF/b7nOBcM+RjPhXOs7/o5GkV6VZrydLg3Ut+pXbW0+cCofSxnuz4AAMPlq6cAAABENIoA\nAABENIoAAABENIoAAABENIoAAABEpJ4CY61NumtX27RRs86oRY6XjjU9PV2tNgDMNRpFAJiB3KNh\n2jbJ4/6omabjUmvuNY/9UJ9v9okhHrPadebye6IrfR/jvj8XRu19VLt+jq+eAgAAENEoAgAAENEo\nAgAAENEoAgAAEBFmAwAAc9gQAnVyRi2hexTr51RvFNukv/WdtFTicP+QpYlpfc59qCclAADQL189\nBQAAIKJRBAAAIKJRBAAAIKJRBAAAIKJRBAAAIOLxGMBYq5nE3EWCcZtk5S7GaluntMZQU6IBYK7x\nF0UAAAAiGkUAAAAiGkUAAAAiGkUAAAAiGkUAAAAiUk/p1eFSGfseb5T0PffS+tPT07O0J9Atn2P1\n9D33vusDDIlGEQAA5rC+Hz/U9+OnunrE1VDr5/jqKQAAABGNIgAAABGNIgAAABGNIgAAABGNIgAA\nAJF5IuoBAAD4//mLIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAA\nABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGN\nIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAABGNIgAAAJEjm1b+5e5rplPLb1z6Qnab\n5QveSC4/4/n7s9tMrbmxaTcGZfKDixrXTyx68pBl616/K/v6natvm/E+tXXm5MPZdXs33zyvi314\n//0DyXMs5+a9WxrX37N8x4z2Zyhy59m9z52f3Wb3xLVFNfYePDW5/NJd12W3WbLoQHZd6bm8ZMnC\nTs6xv1i2JXuO/fjVOw9ZdvnpWxvHa7NNqVSNtnVqjlVap22NWmP9/J0dnZxjf/6Zi7Pn2OS7jxaN\nNXH8VTPen1GQOi61517z2OfG6upzLHc/FkK9617umhBCCJc9dVN2Xem1J3dPtGHZVHab1L3VOGm6\nR9634bZB3o/V1nT+5XqLUk338LXOsab71F17Tsquq9UL5e7tF6/Yn6+9dnv2HPMXRQAAACIaRQAA\nACIaRQAAACIaRQAAACIaRQAAACKNqacw286575Gi15+2vu54Q3XDhvJtSuf+2DXlNWrW/0//vjnB\ntgvplLULG7dJp6Z9ucr+fCKXCrioxVj5BLaFLUbLSyWtrWg5VioBcGXLsfqWO/4v//KU5PJ8Jt54\nSb/3TqtaI5ck+c3VueT2Y7Nj5d6TU2u3l+5WdV/bWue698Ct3dQ/8oLZrzFyvtr3DkCaRpGR8toz\nq8IvNqU7nHFpEkPIxzRPTKRf32buVz6yKXksp9YUDzXoY597PEQuivuOl57KRnHnorXXPvvratHa\nTY/T+eDBo4seQ9IU0734iQPVYvVzcdz/sOVLxdH5uZv7t7573GAfpZR7dMLhHueT8vaOs4qP2ajJ\nvfe27XutWgx+0+MGcta/+WHxI666knuv1mygvvG9TeG5O9PX192Za0+b+v/y9JJsnVo1Rs2xu+YX\nHRPoiq+eAgAAENEoAgAAENEoAgAAENEoAgAAEBFmw8jJB0Ss6nQ/ZlMuNOXe585PLj8qzGtVJxUq\ncemu67KvX7LoQG5Nq/oAHF5X171c0NBlT92UXP75lteeVEDQhmVTuVe3qjFqcgFM+zZ8+gAzqE2j\nSK9yCaY5TSmCp63/TbUUx75NfrCz6PX/vHi6OCkxd0PQxsKz3y9K4wRg5lZdWe+61+aa8LuV5dee\n0hTZKzbvrJYqPVRtUnqhC756CgAAQESjCAAAQESjCAAAQESjCAAAQESjCAAAQETqKb2auODeotcv\nPUy4W+l4Q7X+0fJtSue+7fHyGjXr/93f3153BzIuuv6h5PK7G86l3DYXby/fplj+6STFdU68pd5Y\njc7N/59jcY2v51eVjrXrb/6qrDbAHNX3/VPTPUmtfWu6t6pV43D3qbNdP1yyoM44f6RRZKT8dsuy\nMPn05uS6vj/kaspFgU9MZF7fYu63X3Z18lhOrSkeaiSP/be2bAxPPnT9IcubmpGffntt8TalPvfD\nLyRrtKmz7/srq43V5ISX/lDtuHzxZ5+d9WMMwLDk7klq3l88c9V5s16j6T41LE8vrln/pCcO5uu3\n4KunAAAARDSKAAAARDSKAAAARDSKAAAARITZMHJu3rsls2ZZl7sxqyY/uCi5/N7nzk8uP6llnb0H\nTz1k2aW78rGbSxYdSC4/JtRN2epKav4hbGzcJv1vs7bK/nxi3et3JZd/LnyheKz8+2Vl8VhNzpx8\n+JBlJ7Qc64zn7z9k2RfDZ1uOBsAoSF+Tr65aI30NP69qjdx1d9ee9N3ayVWrp6/Hi1fsz75+am0+\n0l2jSK9Kk5nyN70hLN3xTrhnectc4oGZ/ODFote/fcmCsHvi2qJt0h/I7Xy09WDYufq2auPVlEv9\nzM3/7h0/CMsXvJFcl2vgL97+bDaptlSuSQwhhP913f8oOs5N75cTb3mr2vsldVEKIYR/PPczxedl\nqkkMIYT//vX/HabW3Fi8bwAcXs2kzDZy1+Rtj/8oe00ulbuGr3/0xWrX8Kbrbs6ejUdUu77lrsdt\n+eopAAAAEY0iAAAAEY0iAAAAEY0iAAAAEY0iAAAAEY0iAAAAEY0iAAAAEY0iAAAAEY0iAAAAEY0i\nAAAAEY0iAAAAkSNrD3jOfY+kV6xssc0A3bCheX1qLgvPLnt9Zxb3VxoAmLnLT99aZZw7Xuqo/oNH\nz36NUfPd4/regzBx/FW91t+2L7+u1r6tf3P2ayx95ZhW21U7/jvOqjPOHzU2ivcs31E0WFPTM/+t\n+eEXm64p2maIHvzrdcl5hJCfy4GXlwxy7kftn5edCwAwDLn7sZoN1B3nXhh+/OqdyXW7J9LbtKm/\n6IZ/ytapVWPUrPzOe0XHZBzdfuJpYfLdRw9ZXrOBfeaUY2e9xm/P+ihZI4QQwvL04pr1T9rySr5+\nC756CgAAQESjCAAAQESjCAAAQESjCAAAQKQxzObmvVuSy29c+kJmi02Nxc54/v5Dls0P8xu3GaLJ\nDy7KrFmX3Wbd63clli6psj8zcebkw8nlezff3PGeAAApufuxEBZWrbP34KnJ5Zc9dVNy+YqWdVL3\nRBuWTWVe/eWWVUZL6h45hBD2bbit4z3pT/r8O61qjfQ9/LFVa+Ter7v2nJRcfnLV6ul7+8Ur9mdf\nP7V2e3Zd1cdjPHbNfWH5gjeS63JvgI9Xfhym1txYczdmVb5JDOGGDTvDxKInD1mebhJDWHj2+2Hn\n6v4+AHJNIgAwfIufOFCcUJ+TaxKb/MOWL4XdE9cWbZO7J8pZ++yvk/dW4yR3j9ylmkmZbeTOv237\nXsv2FqVy9/Dr3/yw2jmW/0+dvD0/+Uq1Xqj2vb2vngIAABDRKAIAABDRKAIAABDRKAIAABCpGmYD\nAEB3Ji64t8o42x7vqP7WDmqMmo1H9L0HkKRRBAAYsFyyac0G6vbLrg6TT29Orts9kd6mTf1j7lyQ\nrVOrxqg5+Qf/p+iYQFd89RQAAICIRhEAAICIRhEAAICIRhEAAICIRhEAAIBI9dTTieOvSq/4yVfK\ntxmg9W82r0/O5W//tOz1XdlxVn+1AcbE5aen8/4XP1F3vHFxx0v5ddXm/t3jWm2WrP/g0TPcmZm7\nee+WzJplVevsPXhqcvllT92UXH5SyzrrXr/rkGUblk1lXn1eyyqj5Yzn708u37fhto73BP5NY6OY\ni2POaWp6Tv7LX4XJdx8t2maInjnl2OQ8QsjP5Zjz//Mg537SlleycwFgZvZfsjD8+NU70yuXpxeP\ne5MYQgh3nHth8rjUnPvK77yXP/bhxuTSXP1FN/xTw1j9WrrjneJ7tZxck9jk7UsWhN0T1xZtk2oS\nm6x/9MUwsejJom1GTa5JhL756ikAAAARjSIAAAARjSIAAAARjSIAAACRxjCbXMrWjUtfyGxxWmOx\n1I91T27cYpgmP7gos+bY7DapH28fU2l/ZuLMyYeTy/duvrmT+hdd/1DR60+8pe54Q3Xx9vJtSud+\nd538g9b1d/3NX9XdgYxcaNS2feXbNKUeVwunakhJLq2z9JXmT5lq+9yQoFxco2JC9t/94adltQGA\nf1X18Rjb9r0Wli94I7kul+i05ydfCVNr0glkQ5RvEkNY/+aHyWSuXMLXR3/7p2Hn6v5ij3NN4pDt\n+/7K8ORD1yfXjUuTGELIJrxNTKRf32bu39qyMXksp9YUDzWSx/72E08rTiPOpR7XTDDOpSS3qfPb\nsz6qNlaTXIJymxqjmJD97l/8u0GONWr6nnvf9QGGxldPAQAAiGgUAQAAiGgUAQAAiGgUAQAAiFQN\nswEA/k0uPfzlX56SXP75MG82d2cw9h48NbF0U9UauRC9b67OJbevy46VC6WbWtsiorqyWmFibZOw\ni+tf10GNUfP1vncA0jSK9Orop3aXb5RJPW011lBl5pjTeu6FdarX78B/+86fzerr227TRs06Xezz\nqO1vW8/deU1yea5JbPK7ldNh98S1M92lQUs3iSE8cOt92eT0UrkmsckVm3cWJZd36Z7l6S6uZgOV\nS8IOIYTdFRO3P/fDL2Tr1Koxar74s88WHRPoiq+eAgAAENEoAgAAENEoAgAAENEoAgAAENEoAgAA\nEKmeejpxwb3pFRuPKN9mgNY/2rw+OZetha/vyiUL+qs9A22SBEfN5AcXJZff+9z5yeUntayTSh+8\ndFc+u3zJogPJ5ce0rN+3NhH96X+bfKx+G/mUxSXFY+XfL6uKx2py5uTDhyw7quWjHlKJlfPD/FZj\ndSH3Ob605eMGRuma2Ma2x/Prqs294Z6jSek1HGCczZuens6ufP/9A/mVCYf7gJ98enPxNkOUmkcI\nzXMZ6txzc1myZGEnD/MqPccO1yTmIsRHTWmjGEIojtTPRdS3aRRDCGHn6tuK6vd9juXmH0LIRvTn\n/l1CCMlY/TYOF8Vfcpy7er+kmsRPlJ6XTY81mFpzY9FYXZ1jf/5n27KfY7nP2JwhXBe60MU1seax\nH+q1cuL4q6rWmXz3MP8bXql+SZ3acxyq3DHp+xzrSptrcqkuruFN191de/L/rV96fcvJXY8Xr9if\nr712e/Yc89VTAAAAIhpFAAAAIhpFAAAAIhpFAAAAIo2pp7kfZN649IXMFlc3FkuFFJzcuMUw5X8M\ne152m1RAxTGh/9TR3I9e926+ueM9ARg/uevoy788Jbm8bYLxqEkHVzTfQ5TKBSN9c3XuHua87Fi5\nkKmptdsL96qdrtK+c4Eilz11U3J52/M1dTw3LJvKvPrYllVGS+583behLCiurXPue6STOjmPXZNf\nV2vfbtgw+zVOW99uu2rHf3GdYT5R9fEY2x7/UTaZKPcG2LPxiGpJP11oSkxa/+iLydSk3AXmo60H\ni5Mia2pKKgTg08klYra5uX/7kgXFSbGjJteMNN1DlGpKz80pvYYPwdJXjqmWXtyUOpnz9o6zis/X\n0uO5/s0PqyVSDlWb83XcXPnIpvCLTYd2izUb2Af/et2s13jtmVXJGiGEEJanF9esf9T+efn6Lfjq\nKQAAABGNIgAAABGNIgAAABGNIgAAAJGqYTYhhHDR9Q+lV3y9xTYDdPFhAs6Sc7mu8PVdObf//yf4\n2tayH/CuurLueEN1RTobo1Hp3B+4tbxGzfq/eqCbFD8AYPjSgUqbqtZIh1Kuq1ojF2S2a086J3h+\nmF+1fiqscvGK/dnXN6U3NzaKpUlaTU3PF3/22fDkQ9cXbTNEP/322uQ8QsjP5XM//MIg537CS3/I\nzmWofvPYqvDcnek0p3FpEkMI2YS3iYn069vM/Rvf25Q8llNriocaq2MPAHNNzaTMNnKpu49dc1+1\nNOTckwtu2LCzWrJum7Trj1d+XO0JELWfaND/n5QAAAAYFI0iAAAAEY0iAAAAEY0iAAAAEY0iAAAA\nkeqPx4DZlk+UWtXpfsymXDLXvc+dn1z++TCvVZ1Uytilu/LPc1my6EBuTav6MA5yCdYn3lJ3vHFx\nd0OgerW5NzySq0npI66G4Jz76qROP9Yy9LK0/sKzZ7/GyFnZ9w6EcPnpW3utf8dL+XW19m3ts7Nf\nY/ET7bardvy3fKnOOH9UtVE8ePwRnWwzVKVzGae5t5V71EVOU+zwqit/U/xIl6Ga/GBn0et/t3I6\n7J64tmibXBR1G0de8H7Yufq2auPBONj3/ZX5RxAtTy8e9yYxhBC+tWXjrD8yKvdIrv8nHUNf+oir\nLuWubTUbqCsf2ZR9RMLuzKOZ2tQ/8PKSokcxjH2TGEKY/9b83h9P0bc7zr0w/PjVOw9ZXrOBfXbt\nl2e9xv5LFiZrhBCyn/s166/Y8V/z9Vvw1VMAAAAiGkUAAAAiGkUAAAAiGkUAAAAiUk+Bsfa1rekg\nhAduLd/mis3l25Q68oLm9SV1Vl1Zb6xGDYl9xTW+Wm+sXz2QD7/qSi6A6+VfnpJcfsJs7syApMO0\nNlatccbz9yeXf3P1C5kt1mbHWvf6XcnlU2u3l+5WK12lfedCzi576qbk8qNaJm6njueGZVO5V7eq\nMWpy5+u+DXMnKC59/l1YtUY6Vf7LVWvk3q+79pyUXF479PbMyYcPWbZ4xf7s65s+x6o2isf//L80\nvyARwnPYbYYoEybUOJehzr1eMBKMlG98b1MydbepGfkP964r3qbUvzy9JJsGXFrnN4+tqjZWk8+/\nNa/acTl21/xZP8a15RIxm1Kac/7x3M8UJxiPmlwzcveOH4TlC96oUiN3093k4u3PholFTx6yPNck\nDsFp6+ulfbdJwv7nxeWJ26XH84YNO5P/LuOkzflaW82kzDZy598dLz1V7XMh9+ixtc/+uto51uZz\n/63vHhem1qTTmEulmsSZ8NVTAAAAIhpFAAAAIhpFAAAAIhpFAAAAIo1hNrkfZN64NJcY1pxMlPqx\nbu2kny7kfgzblJqU+vH2okr7MxO5H73u3Xxzx3sCAEAfLrr+oV7r392QyVRr3y5uCCmuVePEW9pt\nV+34n1v3b4BVU0+bkolyiU41k366kG8S86lJuYSvDx48Ouxc3V/sce1kJBiiRf/xf87q69tu00bN\nOl3s86jtLwD9+NaWjckE6ZoN7E+/vXbWa+z7/spsEnZYnl5cs/4JL/0hX78FXz0FAAAgolEEAAAg\nolEEAAAgolEEAAAgolEEAAAgUjX1FGCU7D14amLp1Y3bpJOPz6uxO/8ql5R8TFhQPFbuMUchLCse\nq0kqRfmklmOlUrJPbjkWANCORhEYa5NPb04uTzeJIWx7/EfZx/zkHo+z/tEXk4/GaSPXJIYQwkdb\nDxY9UiffJIawdMc74Z7lDQ+uKpB71M7blywIuyeuLRor9yilPRuPGKlHKQGMkoPHH9H3LiR1sV99\nz73v+k189RQAAICIRhEAAICIRhEAAICIRhEAAICIMBt6dfnpW4tev/iJuuMN1dpny7cpnfsdL5XX\nqFn/5+/UCVKBvp1z3yPJ5aetrzveuHjsmvy6anNf2W6zVP2FZ89wX4BPJR0yt6lqjXQo3bqqNXJB\ncrv2pLPAjw3zq9ZPBcwtXrE/+/qptduz6zSKjJT9lywMP371zuS6cWkSQwjZBM2JifTr28z9jnMv\nTB7LqTXFQ43VsYdaXntmVfjFpkxXtDy9eNybxBBCuPKRTcnjUnPu89+anz/2IZ2em6t/4OUlDWPB\neHjuzn7P8VwS+QO33pdNIi+VSy6/YvPOasnlTWnjOR9+9eNqqd65FPK2fPUUAACAiEYRAACAiEYR\nAACAiEYRAACAiEYRAACAyLzp6em+9wEAAIAB8RdFAAAAIhpFAAAAIhpFAAAAIhpFAAAAIhpFAAAA\nIhpFAAAAIv8X0PMpBMHheyQAAAAASUVORK5CYII=\n", 80 | "text/plain": [ 81 | "" 82 | ] 83 | }, 84 | "metadata": {}, 85 | "output_type": "display_data" 86 | } 87 | ], 88 | "source": [ 89 | "fig, axes = plt.subplots(ncols=6, nrows=2, figsize=(16,5))\n", 90 | "for ax in axes.T:\n", 91 | " img, grp = generate_bars(20, 20, 6, 6)\n", 92 | " plot_input_image(img, ax[0])\n", 93 | " plot_groups(grp, ax[1])" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "# Save as HDF5 Dataset\n" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "metadata": { 107 | "collapsed": true 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "import h5py\n", 112 | "import os\n", 113 | "import os.path\n", 114 | "\n", 115 | "data_dir = os.environ.get('BRAINSTORM_DATA_DIR', '.')" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 5, 121 | "metadata": { 122 | "collapsed": true 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "np.random.seed(471958)\n", 127 | "nr_train_examples = 60000\n", 128 | "nr_test_examples = 10000\n", 129 | "nr_single_examples = 200\n", 130 | "width = 20\n", 131 | "height = 20\n", 132 | "nr_vert = 6\n", 133 | "nr_horiz= 6\n", 134 | "\n", 135 | "data = np.zeros((1, nr_train_examples, height, width, 1), dtype=np.float32)\n", 136 | "grps = np.zeros_like(data)\n", 137 | "for i in range(nr_train_examples):\n", 138 | " data[0, i, :, :, 0], grps[0, i, :, :, 0] = generate_bars(width, height, nr_horiz, nr_vert)\n", 139 | "\n", 140 | "test_data = np.zeros((1, nr_train_examples, height, width, 1), dtype=np.float32)\n", 141 | "test_grps = np.zeros_like(test_data)\n", 142 | "for i in range(nr_train_examples):\n", 143 | " test_data[0, i, :, :, 0], test_grps[0, i, :, :, 0] = generate_bars(width, height, nr_horiz, nr_vert)\n", 144 | "\n", 145 | "single_data = np.zeros((1, nr_single_examples, height, width, 1), dtype=np.float32)\n", 146 | "single_grps = np.zeros_like(single_data)\n", 147 | "for i in range(nr_single_examples // 2):\n", 148 | " single_data[0, i, :, :, 0], single_grps[0, i, :, :, 0] = generate_bars(width, height, 1, 0)\n", 149 | "for i in range(nr_single_examples // 2, nr_single_examples):\n", 150 | " single_data[0, i, :, :, 0], single_grps[0, i, :, :, 0] = generate_bars(width, height, 0, 1)\n", 151 | "\n", 152 | "shuffel_idx = np.arange(nr_single_examples)\n", 153 | "np.random.shuffle(shuffel_idx)\n", 154 | "single_data = single_data[:, shuffel_idx]\n", 155 | "single_grps = single_grps[:, shuffel_idx]" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 6, 161 | "metadata": { 162 | "collapsed": false 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "with h5py.File(os.path.join(data_dir, 'bars.h5'), 'w') as f:\n", 167 | " single = f.create_group('train_single')\n", 168 | " single.create_dataset('default', data=single_data, compression='gzip', chunks=(1, 100, height, width, 1))\n", 169 | " single.create_dataset('groups', data=single_grps, compression='gzip', chunks=(1, 100, height, width, 1))\n", 170 | " train = f.create_group('train_multi')\n", 171 | " train.create_dataset('default', data=data, compression='gzip', chunks=(1, 100, height, width, 1))\n", 172 | " train.create_dataset('groups', data=grps, compression='gzip', chunks=(1, 100, height, width, 1))\n", 173 | " test = f.create_group('test')\n", 174 | " test.create_dataset('default', data=test_data, compression='gzip', chunks=(1, 100, height, width, 1))\n", 175 | " test.create_dataset('groups', data=test_grps, compression='gzip', chunks=(1, 100, height, width, 1))" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "# References\n", 183 | "[1] P. Földiák, [Forming sparse representations by local anti-Hebbian learning](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.17.1244&rep=rep1&type=pdf), Biological Cybernetics 1990\n", 184 | "\n", 185 | "[2] David P. Reichert and Thomas Serre, [Neuronal Synchrony in Complex-Valued Deep Networks](http://arxiv.org/abs/1312.6115), ICLR 2014\n", 186 | "\n", 187 | "[3] David E. Rumelhart and David Zipser, [Feature discovery by competitive learning](http://www.sciencedirect.com/science/article/pii/S0364021385800100), Cognitive Science 1985\n" 188 | ] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.4.3" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 0 212 | } 213 | -------------------------------------------------------------------------------- /Datasets/Corners.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stderr", 12 | "output_type": "stream", 13 | "text": [ 14 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib-1.5.0+783.g23bc09d-py3.4-linux-x86_64.egg/matplotlib/__init__.py:877: UserWarning: axes.color_cycle is deprecated and replaced with axes.prop_cycle; please use the latter.\n", 15 | " warnings.warn(self.msg_depr % (key, alt_key))\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "from plot_tools import plot_groups, plot_input_image\n", 23 | "%matplotlib inline\n", 24 | "np.random.seed(746519283)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "# Corners Problem\n", 32 | "\n", 33 | "Binary images containing 8 corner-pieces each. Four of them are arranged in a square, while the other 4 are randomly distributed. Introduced in [1] to investigate binding in deep networks.\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": { 40 | "collapsed": false 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "width = 28\n", 45 | "height = 28\n", 46 | "\n", 47 | "corner = np.zeros((5, 5))\n", 48 | "corner[:2, :] = 1.0\n", 49 | "corner[:, :2] = 1.0\n", 50 | "\n", 51 | "corners = [\n", 52 | " corner.copy(),\n", 53 | " corner[::-1, :].copy(),\n", 54 | " corner[:, ::-1].copy(),\n", 55 | " corner[::-1, ::-1].copy()\n", 56 | "]\n", 57 | "\n", 58 | "square = np.zeros((20, 20))\n", 59 | "square[:5, :5] = corners[0]\n", 60 | "square[-5:, :5] = corners[1]\n", 61 | "square[:5, -5:] = corners[2]\n", 62 | "square[-5:, -5:] = corners[3]\n" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": { 69 | "collapsed": true 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "def generate_corners_image(width, height, nr_squares=1, nr_corners=4):\n", 74 | " img = np.zeros((height, width))\n", 75 | " grp = np.zeros_like(img)\n", 76 | " k = 1\n", 77 | " \n", 78 | " for i in range(nr_squares):\n", 79 | " x = np.random.randint(0, width-19)\n", 80 | " y = np.random.randint(0, height-19)\n", 81 | " region = (slice(y,y+20), slice(x,x+20))\n", 82 | " img[region][square != 0] += 1\n", 83 | " grp[region][square != 0] = k \n", 84 | " k += 1\n", 85 | " \n", 86 | " for i in range(nr_corners):\n", 87 | " x = np.random.randint(0, width-4)\n", 88 | " y = np.random.randint(0, height-4)\n", 89 | " corner = corners[np.random.randint(0, 4)]\n", 90 | " region = (slice(y,y+5), slice(x,x+5))\n", 91 | " img[region][corner != 0] += 1\n", 92 | " grp[region][corner != 0] = k\n", 93 | " k += 1\n", 94 | " \n", 95 | " grp[img > 1] = 0\n", 96 | " img = img != 0\n", 97 | " return img, grp\n", 98 | " " 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "metadata": { 105 | "collapsed": false 106 | }, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA4oAAAElCAYAAACiWBzqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADWtJREFUeJzt3bFrZMcdB/BVbEhld7rDlWvXDsEEAkf+CRk1hkOdIJhL\nkfwDbg9jUHccuFFQY9IHzHVO4z/ClTlfF9dmUxjDfRWtNPv2zb6ZN59PeVrtPc3Oruar9/vNnGy3\n2w0AAAD85ndLXwAAAABtERQBAAAIgiIAAABBUAQAACAIigAAAARBEQAAgPDufV988+ZnZ2cM6vT0\nvZNj/D/m2LjMMWpb4xx79Oj9qs//00//rfr8a7PGOcbDar4Pb78HzbFljPRZe98cc0cRAACAICgC\nAAAQBEUAAADCvT2KAMD9jtmv1IPWxqP0enoca4Ca3FEEAAAgCIoAAAAEQREAAIAgKAIAABBsZgMA\nnZi64Urtw6NhJDU3VaINXuNfuaMIAABAEBQBAAAIgiIAAAChmR5FB+ICS6ndV9Db51ZrB6YDHMMa\ne8x4mAyyWzNBEQDWbJRFhk0gANZB6SkAAABBUAQAACAoPQUAhjel/HW73Va4EoA2uKMIAABAEBQB\nAAAIgiIAAABBUAQAACDYzAYAgJ1qnnM5yvmi0CNBEaDQKAsaB6avzyhzFw5R+j7xecdd1vg5q/QU\nAACAICgCAAAQuis9nXq7v7fbwaP0A4zyc5YyHgDA6GqX9x6zxWKO9VfN8dhutzu/1l1QBICW+CPM\n4WqOodcHYBqlpwAAAARBEQAAgCAoAgAAEARFAAAAgs1sAAA4mI2DYF3cUQQAACAIigAAAARBEQAA\ngDBkj+KjR+8XPU6tPQAAMKLuguIo4W3Kz1kagHvkdd9tza/7sYwyv2oyhsAo/K5uxxy/e3r4/TX1\nGg+dd0pPAQAACIIiAAAAobvSUwCAu+xTZtVDuRnAktxRBAAAIAiKAAAABEERAACAICgCAAAQBEUA\nAACCXU8BANjJDrEwJncUAQAACIIiAAAAoevS030O1gUAAKBMM0FR/fvhehzDHq+5NcYQANrmd/V+\njNc8Dh1HpacAAAAEQREAAIAgKAIAABCa6VEE6F3pBlt6L1iLmnPehnVsNj5XYUnuKAIAABAERQAA\nAIKgCAAAQNCjyCrV7m3RC0Gvar43vC8AuM3vnX51HRRNDgAAgPkpPQUAACAIigAAAARBEQAAgCAo\nAgAAELrezAYAWCcb1gEsyx1FAAAAgqAIAABAUHoKcGRTDx9WikevzHk2m7oHrwPzExQZmkUIo5ky\n5y3uAKjNmqw9Sk8BAAAIgiIAAABBUAQAACDoUQSAA5T2cOq/2d/tsTWGAMfjjiIAAABBUAQAACAI\nigAAAARBEQAAgGAzG4CZ2GiD0cw150s3BKJvPiOhL+4oAgAAEARFAAAAgqAIAABA0KMIDavdt6Nf\n5FcOTJ9HyTgaQwDog6DIKlmMwnxK3k82IwHgLtZk/VJ6CgAAQBAUAQAACEpPAeAIppTn9liypQwZ\noFzL+1G4owgAAEAQFAEAAAiCIgAAAEFQBAAAIDSzmU3LjZwAAKMqXaNZa+1mDOlRM0ERmJdfNgAA\nTKX0FAAAgCAoAgAAEJSerkiP9e89XnMvbo+tMTzc1F7q3sbegekALMn6cD4PjeV2u935NUERAA5g\noZKMB8D8pny2HvqHX6WnAAAABEERAACAICgCAAAQ9Cg2ShPv+tkwhBH5bAOAPnQXFN9ePFhoAwAA\nzE/pKQAAAEFQBAAAIHRXegrUoXcMgENMaQnq8XeK1idGISgCw+txodKCfcfN4goA+qH0FAAAgCAo\nAgAAEARFAAAAQnc9inpc0ijjMUqDPAAAZaaug3tcIy6x5u8uKMJa3PUhNUrwBwCgLLQutT5UegoA\nAEAQFAEAAAiCIgAAAEGPIgDNsYEVACxLUAQAYCd/hEnGg1EoPQUAACAIigAAAASlpzAIZzQyN3MK\nANarmaCo3ntMXvc013hYwAMAvbI+TEuNh9JTAAAAgqAIAABAEBQBAAAIgiIAAAChmc1sSJp4gTXy\n2QYAfXBHEQAAgCAoAgAAEARFAAAAwsl2u136GgAAAGiIO4oAAAAEQREAAIAgKAIAABAERQAAAIKg\nCAAAQBAUAQAACIIiAAAAQVAEAAAgCIoAAAAEQREAAIAgKAIAABAERQAAAIKgCAAAQBAUAQAACIIi\nAAAAQVAEAAAgCIoAAAAEQREAAIAgKAIAABAERQAAAIKgCAAAQBAUAQAACIIiAAAAQVAEAAAgvHvf\nF9+8+Xl7rAuhLaen750c4/8xx8ZljlHbGufYk+cvqj7/q2cXVZ9/bdY4x3hYzffh7fegObaM88dP\nqz7/9euXVZ9/H/fNMXcUAQAACIIiAAAAQVAEAAAg3NujCADc78sfP6v23J9/8HW1566ltfE4O78q\netzN9eXezw2wZu4oAgAAEARFAAAAgqAIAABAEBQBAAAIgiIAAADBrqcA0IlXzy4mfd+T5y9mvhIY\n15T3ofdgX65fv9z7e84fP61wJctyRxEAAIAgKAIAABCaKT0tvSU/tewGYJeaB4RvNv0dmt7agekA\nx1C1dPDvf6r33Bzk7Pyq6HE315eVr6Q9zQRFAFizUULylJ+z9h9rANif0lMAAACCoAgAAEAQFAEA\nAAh6FAGA4ZVuaPG2b//9jwpXAtAGdxQBAAAIgiIAAABBUAQAACDoUQQAYKea51yOcr4o9EhQBCg0\nyoLGgenr8+rZxdKXAM27fv2y6HF3fd499LlZ/hnpvdqr0vnTE6WnAAAABEERAACAICgCAAAQ9Cg2\napTG8VF+zlLGAwAYXe2+9ylrorPzq0n/18315aTve1vN8fji9JudX2smKGq0B6BH/ghzuJpjOMci\nDd42Zb76nKBHSk8BAAAIgiIAAAChmdJTAABo3SfffvV///afv/x1gSuBugRFAAAOpg8P1kXpKQAA\nAEFQBAAAIAiKAAAAhGZ6FEsPkpyj/v388dOix12/fnnw/wUAANCbZoIiaUogLg3bPRqlQd7rvoxR\n5ldNxhAYhR1O2zHH756b68sZrqSuqT/noWtEpacAAAAEQREAAICg9BQAWIW7DkLfRfkgwP3cUQQA\nACAIigAAAARBEQAAgCAoAgAAEGxmAwAAFZVutGSTJVoiKAIAsNMch5oD/VF6CgAAQBAUAQAACF2X\nnp4/frr0JQAAAKzOyXa73fnFN29+3v3FBtQOitevX1Z9/padnr53coz/p/U5Rj3mGLWZY+Mp3TBk\ns5ln0xBzjNrMMWq7b44pPQUAACAIigAAAARBEQAAgND1ZjYALXGgMqMp3StgSs//Pv2GrNfZ+VXR\n426uLytfCYzHHUUAAACCoAgAAEAQFAEAAAh6FFklZ2zC3Wq+N7wvALitZr+xnv+6ug6KFiUAAADz\nU3oKAABAEBQBAAAIgiIAAABBUAQAACB0vZkNALBOdjMEWJY7igAAAARBEQAAgKD0FODIph4+rBSP\nXp0/fjrp+5yXvC5n51dLXwKwB0GRoVmEMJopc37qIh8ASvljaHuUngIAABAERQAAAIKgCAAAQNCj\nCAAH+PLHz4oe9/kHX1e+kvW5PbbGEOB4BEUAYHG3N1oqDeAA1KH0FAAAgCAoAgAAEJSeAgCLU2oK\n0BZBEWAmDgtmNLf7CqcSEsdwc3259CUAe1B6CgAAQBAUAQAACIIiAAAAQY8iNOzs/Krq8+sX+ZUD\n0+dRMo7GEAD6ICiySnNtsACUvZ9sRgLAXWz01i+lpwAAAARBEQAAgCAoAgAAEPQoAsARTOnj7HHz\nH/2qAOWePH9R9flfPbuY/L2CInTood1Ka++WCgDAuik9BQAAIAiKAAAAhGZKTz/59quqz+8MF9bq\nrvfOh5t3FrgSANaotO+0x57aYzGG9KiZoAjM54eLX/xxBACAyZSeAgAAEARFAAAAgqAIAABA0KO4\nIj02Svd4zb24PbbG8HBTDxLvbewdmA7AkqwP5/PQWH5x+s3OrzUTFEs23qi9Myr06K73joU+HI+F\nSjIeAPN79exi7+85dD2o9BQAAIAgKAIAABCaKT0tuzX6cfXraMXZ+VXR426uLytfCbWUlFJ/uHnn\nCFcCx+OzDQD60ExQLPHpR99H74OeRdbuh4tfivp3oQff/e2PS18CAFBI6SkAAABBUAQAACAIigAA\nAISuehRhRCW9uHP0MT55/qLocVPO8QFg/aac2dbjuZvOKmYUgiIwvB4XKjX9659/Lnrcvn80+G5T\ntuMpALA8pacAAAAEQREAAIDQXelp1oV/vNh1tGKUOvlR+h6mGmUeAAD8Zur6p8c14hJrve6C4ts+\n/ej7Ll/oEjfXlw8+Rjjo210b0HhNAQDGUdLvv9T6UOkpAAAAQVAEAAAgCIoAAACErnsUN5vymt21\n9jICrNHZ+f5nLpb0dgMAZboPiiMTfgGA2qw3kvFgFEpPAQAACIIiAAAAQelpo6b055TSxzOmJ89f\nLH0JrMy+c+pxpesAAObXTFBU7z0mr3uaazyEQgCgV9aHaanxUHoKAABAEBQBAAAIgiIAAABBUAQA\nACA0s5kNyc6kwNq8/sPvN6+eXSx9GQBAAXcUAQAACIIiAAAAQVAEAAAgnGy326WvAQAAgIa4owgA\nAEAQFAEAAAiCIgAAAEFQBAAAIAiKAAAABEERAACA8D8FjwRSVQxxagAAAABJRU5ErkJggg==\n", 111 | "text/plain": [ 112 | "" 113 | ] 114 | }, 115 | "metadata": {}, 116 | "output_type": "display_data" 117 | } 118 | ], 119 | "source": [ 120 | "fig, axes = plt.subplots(ncols=6, nrows=2, figsize=(16, 5))\n", 121 | "for ax in axes.T:\n", 122 | " img, grp = generate_corners_image(28, 28, 1, 4)\n", 123 | " plot_input_image(img, ax[0])\n", 124 | " plot_groups(grp, ax[1])" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "# Save as dataset" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 5, 137 | "metadata": { 138 | "collapsed": true 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "import h5py\n", 143 | "import os\n", 144 | "import os.path\n", 145 | "\n", 146 | "data_dir = os.environ.get('BRAINSTORM_DATA_DIR', '.')" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 6, 152 | "metadata": { 153 | "collapsed": false 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "nr_train_examples = 60000\n", 158 | "nr_test_examples = 10000\n", 159 | "nr_single_examples = 5000\n", 160 | "\n", 161 | "width = 28\n", 162 | "height = 28\n", 163 | "nr_squares = 1\n", 164 | "nr_corners = 4\n", 165 | "\n", 166 | "data = np.zeros((1, nr_train_examples, height, width, 1), dtype=np.float32)\n", 167 | "grps = np.zeros_like(data)\n", 168 | "for i in range(nr_train_examples):\n", 169 | " data[0, i, :, :, 0], grps[0, i, :, :, 0] = generate_corners_image(width, height, nr_squares, nr_corners)\n", 170 | " \n", 171 | "data_test = np.zeros((1, nr_test_examples, height, width, 1), dtype=np.float32)\n", 172 | "grps_test = np.zeros_like(data_test)\n", 173 | "for i in range(nr_test_examples):\n", 174 | " data_test[0, i, :, :, 0], grps_test[0, i, :, :, 0] = generate_corners_image(width, height, nr_squares, \n", 175 | " nr_corners)\n", 176 | "\n", 177 | "data_single = np.zeros((1, nr_single_examples, height, width, 1), dtype=np.float32)\n", 178 | "grps_single = np.zeros_like(data_single)\n", 179 | "for i in range(nr_single_examples // 2):\n", 180 | " data_single[0, i, :, :, 0], grps_single[0, i, :, :, 0] = generate_corners_image(width, height, 1, 0)\n", 181 | "for i in range(nr_single_examples // 2, nr_single_examples):\n", 182 | " data_single[0, i, :, :, 0], grps_single[0, i, :, :, 0] = generate_corners_image(width, height, 0, 1)\n", 183 | "\n", 184 | "shuffel_idx = np.arange(nr_single_examples)\n", 185 | "np.random.shuffle(shuffel_idx)\n", 186 | "data_single = data_single[:, shuffel_idx]\n", 187 | "grps_single = grps_single[:, shuffel_idx]" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 7, 193 | "metadata": { 194 | "collapsed": true 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "with h5py.File(os.path.join(data_dir, 'corners.h5'), 'w') as f:\n", 199 | " single = f.create_group('train_single')\n", 200 | " single.create_dataset('default', data=data_single, compression='gzip', chunks=(1, 100, height, width, 1))\n", 201 | " single.create_dataset('groups', data=grps_single, compression='gzip', chunks=(1, 100, height, width, 1))\n", 202 | " \n", 203 | " train = f.create_group('train_multi')\n", 204 | " train.create_dataset('default', data=data, compression='gzip', chunks=(1, 100, height, width, 1))\n", 205 | " train.create_dataset('groups', data=grps, compression='gzip', chunks=(1, 100, height, width, 1))\n", 206 | " \n", 207 | " test = f.create_group('test')\n", 208 | " test.create_dataset('default', data=data_test, compression='gzip', chunks=(1, 100, height, width, 1))\n", 209 | " test.create_dataset('groups', data=grps_test, compression='gzip', chunks=(1, 100, height, width, 1))" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "# References\n", 217 | "[1] David P. Reichert and Thomas Serre, [Neuronal Synchrony in Complex-Valued Deep Networks](http://arxiv.org/abs/1312.6115), ICLR 2014\n", 218 | "\n" 219 | ] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "Python 3", 225 | "language": "python", 226 | "name": "python3" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.4.3" 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 0 243 | } 244 | -------------------------------------------------------------------------------- /Datasets/MNIST+Shape.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stderr", 12 | "output_type": "stream", 13 | "text": [ 14 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib-1.5.0+783.g23bc09d-py3.4-linux-x86_64.egg/matplotlib/__init__.py:877: UserWarning: axes.color_cycle is deprecated and replaced with axes.prop_cycle; please use the latter.\n", 15 | " warnings.warn(self.msg_depr % (key, alt_key))\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "from plot_tools import plot_groups, plot_input_image\n", 23 | "%matplotlib inline\n", 24 | "np.random.seed(985619)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "import h5py\n", 36 | "import os\n", 37 | "import os.path\n", 38 | "\n", 39 | "data_dir = os.environ.get('BRAINSTORM_DATA_DIR', '.')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "# MNIST + Shape\n", 47 | "\n", 48 | "Binary images containing a thresholded MNIST digit and one random shape each. Introduced in [1] to investigate binding in deep networks." 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "metadata": { 55 | "collapsed": true 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "square = np.array(\n", 60 | " [[1, 1, 1, 1, 1, 1, 1],\n", 61 | " [1, 1, 1, 1, 1, 1, 1],\n", 62 | " [1, 1, 0, 0, 0, 1, 1],\n", 63 | " [1, 1, 0, 0, 0, 1, 1],\n", 64 | " [1, 1, 0, 0, 0, 1, 1],\n", 65 | " [1, 1, 1, 1, 1, 1, 1],\n", 66 | " [1, 1, 1, 1, 1, 1, 1]])\n", 67 | "\n", 68 | "triangle = np.array(\n", 69 | " [[0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0],\n", 70 | " [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],\n", 71 | " [0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n", 72 | " [0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0],\n", 73 | " [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0],\n", 74 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", 75 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n", 76 | "\n", 77 | "shapes = [square, triangle, triangle[::-1, :].copy()]" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "metadata": { 84 | "collapsed": true 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "# Load the MNIST Dataset as prepared by the brainstorm data script\n", 89 | "# You will need to run brainstorm/data/create_mnist.py first\n", 90 | "with h5py.File(os.path.join(data_dir, 'MNIST.hdf5'), 'r') as f:\n", 91 | " mnist_digits = f['normalized_full/training/default'][0, :]\n", 92 | " targets = f['normalized_full/training/targets'][:]\n", 93 | " mnist_digits_test = f['normalized_full/test/default'][0, :]\n", 94 | " targets_test = f['normalized_full/test/targets'][:]" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 5, 100 | "metadata": { 101 | "collapsed": true 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "def generate_mnist_shape_img(digit_nr, nr_shapes=1, test=False):\n", 106 | " if digit_nr is None:\n", 107 | " img = np.zeros((28, 28), dtype=np.float)\n", 108 | " elif not test:\n", 109 | " img = (mnist_digits[digit_nr].reshape(28, 28) > 0.5).astype(np.float)\n", 110 | " else:\n", 111 | " img = (mnist_digits_test[digit_nr].reshape(28, 28) > 0.5).astype(np.float)\n", 112 | " grp = (img > 0.5).astype(np.float)\n", 113 | " mask = grp.copy()\n", 114 | " k = 2\n", 115 | " \n", 116 | " for i in range(nr_shapes):\n", 117 | " shape = shapes[np.random.randint(0, len(shapes))]\n", 118 | " sy, sx = shape.shape\n", 119 | " x = np.random.randint(0, 28-sx+1)\n", 120 | " y = np.random.randint(0, 28-sy+1)\n", 121 | " region = (slice(y,y+sy), slice(x,x+sx))\n", 122 | " img[region][shape != 0] = 1\n", 123 | " mask[region][shape != 0] += 1\n", 124 | " grp[region][shape != 0] = k \n", 125 | " k += 1\n", 126 | " \n", 127 | " grp[mask > 1] = 0\n", 128 | " return img, grp\n", 129 | " " 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 6, 135 | "metadata": { 136 | "collapsed": false 137 | }, 138 | "outputs": [ 139 | { 140 | "data": { 141 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA4oAAAElCAYAAACiWBzqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAD/FJREFUeJzt3cGR3MYVBuChi1fpRjqADULFVHh0lQJgJAxAVT7y4hAc\nAEtBKACTN+nu9WFL1j4sZgcDdKNfd3/fTaS4gx286cY/8x7mzePj4wUAAAD+9LfWBwAAAEAugiIA\nAACBoAgAAEAgKAIAABAIigAAAASCIgAAAMHb1/7y+/c/fHfGpN69++HNGY+jxualxqhNjVGbGqM2\nNUZtr9WYTxQBAAAIBEUAAAACQREAAIDg1RlFAOp7//7HF3/27dvvDY4EAOCJTxQBAAAIBEUAAAAC\nQREAAIBAUAQAACAQFAEAAAgERQAAAAJBEQAAgEBQBAAAIHjb+gAARvb+/Y/F/t23b78fPRyobm/N\nr1HzAO34RBEAAIBAUAQAACAQFAEAAAjMKF5xZMbCTAXMqeRsFmR0do0vH8/+Crym9T482hrlE0UA\nAAACQREAAIBAUAQAACAQFAEAAAimvJlN60FXoH/WEUaXscbXjmm0m0cAOdefLUZbo6YMigBAPfde\nGPV6UQgwMq2nAAAABIIiAAAAwRStp1pa4KXR+uhry7CO+PJxaspQ43t4XdDa1teO2nzS61qzV8/X\nW1MExZJ6ObFAHrfWjdk2TfpWYx/c8jO9TgDOpfUUAACAQFAEAAAgEBQBAAAIzChCIxnnbXoeuC4p\n47lZ43yxVy81zlhGuemL1899PF/9EhSfyb4wAflZRwCAEWg9BQAAIBAUAQAACARFAAAAguFmFA3M\nkoE67EvJ89X63C8f38wkl0v7umQ+o9TcKL8HuZSqq9p7/HBBEdhPqAAAenHrukXQP0brKQAAAIGg\nCAAAQKD1FArQ2gDAzEbYB0f4HaCkaYOiWSyyKlGbNjtgD3sjlOP19GT5PBy5RvGcnkvrKQAAAIGg\nCAAAQCAoAgAAEHQ9o2gOixay1V2242EOJevOzAn0ZZR9Z5TfA2rpOihC71pfILd+/CzOfB62Xpg4\nN5Q0cj2N/LvRD3W43ZnP1dpj9foGQYsa03oKAABAICgCAAAQaD2FG7K1KGQ7HsajxshAHeaR/Vxk\nOL5ax/D4+Fjl58IW0wRFveNk0LIOvQYAyMbexNmWNZfhjYY1GV4bWk8BAAAIBEUAAAACQREAAIBg\nmhlFAF7KOpsBtGFNAP4kKEIlGYaQASATeyPZrNXk2W+YZH1daD0FAAAgEBQBAAAItJ4ecORj6awf\nMZOLWZH+ZTuH2Y4HLhd12ZLnntl5DVwnKMIOPQT9Ho6Rv2Q/X3uOz+ZLadlfJ4zpzPVPjc+hl/Os\n9RQAAIBAUAQAACAQFAEAAAimmVHMNiuzPJ5eepWpK1ud0rcz60ntjufaOS2xX6mXeZW4/slWP9mO\nh9c5X9t1HRQzfEEm4+shxPdwjPRFTQHUZ60dX8/nWOspAAAAgaAIAABA0HXr6UjWWmZ7/qia27RJ\nU5qZRLY6cv62/Nvn+5da4YgM9ZPhGNjHuTtGUKzABgn0xJtSZKU2GYl6Hs/o51TrKQAAAIGgCAAA\nQCAoAgAAEJhRhJOYV6UXanU8Z59TNcRWaoWS1FNZwwXF0YdKmY+aZiTqmWvURhu1nveMF+y3ftfS\nx6ymxzLj+dR6CgAAQCAoAgAAEAzXegoZZGy5IbdlzZzZIqVex+J8AjOw1tUnKCYxY98z69QCvVK7\nfSlxvlyocc2W+qpdP/fW+Nr/v/UYrX99c/7WaT0FAAAgEBQBAAAIBEUAAAACM4qFmdeYk/NOaWqK\nrc6slZqPteVnL+eIjhyPmaTzWdc4Qv2cT1BsxAYF9M46BmRRaz2yzjEzracAAAAEgiIAAACB1tMD\njvRK1+qz1iLRv3u/eB1a2bOOqed+ZJwHynhMQHle6zkIigAwgTND+q3HOvsi0BsUwHPWhG20ngIA\nABAIigAAAASCIgAAAIEZxSsM0XLN2bWx9nh662FOpdYfexy1qTHukeHaagSlrw8FRYCKtizaJTes\nM28i4g0LLpe6daDGuEZtQH1aTwEAAAgERQAAAIKuW0+vtVBl+/4m+paxXpbHpAUHuKXVWnbkca11\n/ci4VwLHdB0UM1huWlsXyns3OwswjCvTF6HDc+qFDNQhtKH1FAAAgEBQBAAAIBAUAQAACIacUTTP\nx1691s7acZvpAFqouY5a6wDOM2RQLMXmQ23u0AvcY23NKLlO2PfIQB3OZ+/NIalL6ykAAACBoAgA\nAECg9ZRpjd7W4IuqgdoyrKPWOoA6ug6KR2Y1bCS0cG/d1Z5HAvrX635mRhu4Zu+6tnfd6HUdrU3r\nKQAAAIGgCAAAQCAoAgAAEHQ9o9gzsxdzannefVE1cFQPe5e1DqCM4YKizQAActmzN/sC7r64/iID\ndViW1lMAAAACQREAAIBAUAQAACAYbkYRADiu5Exg6/lCN7gBuJ9PFAEAAAgERQAAAAJBEQAAgEBQ\nBAAAIHAzm8IMx/fDuQIAgHU+UQQAACAQFAEAAAgERQAAAAJBEQAAgMDNbACAF9zwC2BuPlEEAAAg\nEBQBAAAIBEUAAACCN4+Pj62PAQAAgER8oggAAEAgKAIAABAIigAAAASCIgAAAIGgCAAAQCAoAgAA\nEAiKAAAABIIiAAAAgaAIAABAICgCAAAQCIoAAAAEgiIAAACBoAgAAEAgKAIAABAIigAAAASCIgAA\nAIGgCAAAQCAoAgAAEAiKAAAABIIiAAAAgaAIAABAICgCAAAQCIoAAAAEgiIAAADB29f+8vv3Px7P\nOhByeffuhzdnPI4am5caozY1Rm1qjNrUGLW9VmM+UQQAACAQFAEAAAgERQAAAIJXZxQBqO+nL7+8\n+LNfP/7c4EgAAJ74RBEAAIBAUAQAACAQFAEAAAgERQAAAAJBEQAAgEBQBAAAIBAUAQAACARFAAAA\ngretDwBgZB///o/b/9PnDy/+6Kcvv7z4s18//lzikKCqtdrdS80DtOMTRQAAAAJBEQAAgEBQBAAA\nIBAUAQAACNzM5oojw/iG72FOm25cAx0reaOaPY9nfwVec/YatTTaGiUoAjT28Onr5ct//vn//269\n0QEAaD0FAAAgEBQBAAAIpmw91dYFHGUekdFl3CvXjmm0mSAg5/qzxWhr1JRBESCznjcVuFzur+Fe\nLwoBRqb1FAAAgEBQBAAAIBAUAQAACKaYUTT7AC+NNnBdW+ub16w9/vPvXoSjet0rl8dtHeNsW187\navNJr2vNXj1fb00RFEvq5cQCeQh0jKTGPrjlZ852cQnQmtZTAAAAAkERAACAQOspNJKxjarnPvqS\nzp5H3Pu8m1tkr4zrD+MbZZbP6+c+nq9+CYrPZF+YgPwENQBgBFpPAQAACARFAAAAAkERAACAYLgZ\nRQOzZKAO+1Ly5jVbflbNOcbl45uZ5HKxJnG+UWpulN+DvI5cg9Te44cLisB+bujUhucdAO53a/88\n+y7mo705q/UUAACAQFAEAAAg0HoKBZhhoKSzW2UAjhphHxzhd4CSpg2KZoLIqkRt2uyAPeyNUI7X\n05Pl83DkGiXbczraTOKS1lMAAAACQREAAIBAUAQAACDoekbRHBYtZKu7bMfDbbfO2cNJx3FEybrL\nNnMCvG6UfWeU34N29tZQL/te10ERetd6oWj9+FmcOYy+ZVP57fMH54aiRq6nkX83+qEOtzvzuVrb\n3/fcWfy3zx9KHM4hLW6co/UUAACAQFAEAAAg0HoKN2SbYch2PIxHjZGBOswj+7m41kp4ZqvennbG\nLf79339V+bm0VateSpsmKOodJ4OWdeg1AEA2I+yLvVz082T5BsKW8/fw6euLf7f1DZRlnW2tlxYz\niUtaTwEAAAgERQAAAAJBEQAAgGCaGUUAXsp+kwrgXNaEyPwhR/VcQ4IiVOLmMQAQ1dgbM9z0g37t\nrZ8ttdzTjWvWaD0FAAAgEBQBAAAItJ4ecKSPX1siW5gV6V+2c5jteOByUZctjfDcX/sd7r3WWvs5\nD7uOiJ7s/T7EpZ5nEa8RFGGHHoJ+D8fIX7Kfrz3HN8IFKLlkf50wnt8+f9hVd3tDQ9ZZNcrq5Txr\nPQUAACAQFAEAAAgERQAAAIJpZhSzzcosj8fcBZdLvjqlb2fWk9odT6kbhNzzsxlfieufrfVz1rXW\niDcxGdne9Wft341+/d51UFw7OTYfSuthEejhGOmLmgJm8/Dp6/pfVFwPe7mpCfv1fI61ngIAABAI\nigAAAARdt56OZMa+59lpk6Y0M4lsdeT8bfm3z/cvtcIRJetn+bOutpoumEHsl/XnGEGxAhsk0BNv\nSpGV2mQkPc+qcduI51frKQAAAIGgCAAAQCAoAgAAEJhRhJOYV6UXanU8Z59TNcRWaoWS1FNZwwVF\ng++MRk0zEvXMNWqjjVrPe8YL9he/6+K/S9/ddMSbm8xsxjVK6ykAAACBoAgAAEAwXOspZJCx5Ybc\nljVzq8Wl5pdQ0zfnE5iBta4+QTGJGfueWacW6JXa7UuJ8+VCjWu21Fft+rm3xtdmCrfOLZpH7Jv9\na53WUwAAAAJBEQAAgEBQBAAAIDCjWJh5jTk575SmptjqzFqp+VhbfvZyjujI8ZhJOp91jSPUz/kE\nxUZsUEDvrGNAFrXWIzepYWZaTwEAAAgERQAAAAJBEQAAgMCM4gFHhmq3foHrvfTS929ZG84pWe1Z\nA8019iPjjSMyHhNQntd6DoJiBc8vhGoFQvr38Olr60MAJnJmSL/1WGdfBHqDAnjOmrCN1lMAAAAC\nQREAAIBA6+kVeqO55uzaWGtfNrcIcyq1/tjjqE2NcY+z62XU+izdUisowknMJM5py6JdcsM6czbM\njAeXy7462LoeelOMa6w/UJ/WUwAAAAJBEQAAgEBQBAAAIOh6RvHarE3r72/y3YljyTjwvKwxczzA\nLa3WsiN7orWuHxn3SuCYroNiBstQunVDvHezEz774sY13CPTF6HDc+qFDNQhtKH1FAAAgEBQBAAA\nIBiy9VSfPHv1WjtrrclmeYAWao5KWOsAzjNkUCxFTzxb7Z1JvHWBYzYVeG5tXyr5Bpd9jwzU4XyW\n57zXN+5Ho/UUAACAQFAEAAAgEBQBAAAIzCgyrdH7331RNVBbhjlqax1AHV0HxSND/Qal2WvvjWsu\nl/svYNb+/wwXZkAeve5nbuYFXLN3Xdv7IUCv62htWk8BAAAIBEUAAACCrltPe6alZk4tz7svqgaO\n6mHvstYBlDFcUNRjTA3P6+rjgRlFgBntCWrLf9NDSJ2Z6y8yUIdlaT0FAAAgEBQBAAAIBEUAAACC\n4WYUoQazMcBsSq57rddQN7gBuJ9PFOGGBzevAQBgMoIiAAAAgaAIAABAICgCAAAQuJlNYYbj+7H5\nS1l9eSsAAJPxiSIAAACBoAgAAEAgKAIAABAIigAAAARuZgMAvODmbABz84kiAAAAgaAIAABAICgC\nAAAQvHl8fGx9DAAAACTiE0UAAAACQREAAIBAUAQAACAQFAEAAAgERQAAAAJBEQAAgOB/nUthqfNZ\nP4IAAAAASUVORK5CYII=\n", 142 | "text/plain": [ 143 | "" 144 | ] 145 | }, 146 | "metadata": {}, 147 | "output_type": "display_data" 148 | } 149 | ], 150 | "source": [ 151 | "fig, axes = plt.subplots(ncols=6, nrows=2, figsize=(16, 5))\n", 152 | "for ax in axes.T:\n", 153 | " digit_nr = np.random.randint(0, 60000)\n", 154 | " img, grp = generate_mnist_shape_img(digit_nr, 1)\n", 155 | " plot_input_image(img, ax[0])\n", 156 | " plot_groups(grp, ax[1])" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "# Save as HDF5 Dataset" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 7, 169 | "metadata": { 170 | "collapsed": false 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "np.random.seed(985619)\n", 175 | "nr_shapes = 1\n", 176 | "nr_training_examples = 60000\n", 177 | "nr_test_examples = 10000\n", 178 | "nr_single_examples = 10000\n", 179 | "\n", 180 | "\n", 181 | "data = np.zeros((1, nr_training_examples, 28, 28, 1), dtype=np.float32)\n", 182 | "grps = np.zeros_like(data)\n", 183 | "for i in range(nr_training_examples):\n", 184 | " data[0, i, :, :, 0], grps[0, i, :, :, 0] = generate_mnist_shape_img(i, nr_shapes)\n", 185 | " \n", 186 | "data_test = np.zeros((1, nr_test_examples, 28, 28, 1), dtype=np.float32)\n", 187 | "grps_test = np.zeros_like(data_test)\n", 188 | "for i in range(nr_test_examples):\n", 189 | " data_test[0, i, :, :, 0], grps_test[0, i, :, :, 0] = generate_mnist_shape_img(i, nr_shapes, test=True)\n", 190 | " \n", 191 | "\n", 192 | "data_single = np.zeros((1, nr_single_examples, 28, 28, 1), dtype=np.float32)\n", 193 | "grps_single = np.zeros_like(data_single)\n", 194 | "for i in range(nr_single_examples // 2):\n", 195 | " digit_nr = np.random.randint(0, 60000)\n", 196 | " data_single[0, i, :, :, 0], grps_single[0, i, :, :, 0] = generate_mnist_shape_img(digit_nr, 0)\n", 197 | "for i in range(nr_single_examples // 2, nr_single_examples):\n", 198 | " data_single[0, i, :, :, 0], grps_single[0, i, :, :, 0] = generate_mnist_shape_img(None, 1)\n", 199 | "\n", 200 | "shuffel_idx = np.arange(nr_single_examples)\n", 201 | "np.random.shuffle(shuffel_idx)\n", 202 | "data_single = data_single[:, shuffel_idx]\n", 203 | "grps_single = grps_single[:, shuffel_idx]" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 8, 209 | "metadata": { 210 | "collapsed": false 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "with h5py.File(os.path.join(data_dir, 'mnist_shape.h5'), 'w') as f:\n", 215 | " single = f.create_group('train_single')\n", 216 | " single.create_dataset('default', data=data_single, compression='gzip', chunks=(1, 100, 28, 28, 1))\n", 217 | " single.create_dataset('groups', data=grps_single, compression='gzip', chunks=(1, 100, 28, 28, 1))\n", 218 | " \n", 219 | " train = f.create_group('train_multi')\n", 220 | " train.create_dataset('default', data=data, compression='gzip', chunks=(1, 100, 28, 28, 1))\n", 221 | " train.create_dataset('groups', data=grps, compression='gzip', chunks=(1, 100, 28, 28, 1))\n", 222 | " train.create_dataset('targets', data=targets, compression='gzip', chunks=(1, 100, 1))\n", 223 | " \n", 224 | " test = f.create_group('test')\n", 225 | " test.create_dataset('default', data=data_test, compression='gzip', chunks=(1, 100, 28, 28, 1))\n", 226 | " test.create_dataset('groups', data=grps_test, compression='gzip', chunks=(1, 100, 28, 28, 1))\n", 227 | " test.create_dataset('targets', data=targets_test, compression='gzip', chunks=(1, 100, 1))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": {}, 233 | "source": [ 234 | "# References\n", 235 | "[1] David P. Reichert and Thomas Serre, [Neuronal Synchrony in Complex-Valued Deep Networks](http://arxiv.org/abs/1312.6115), ICLR 2014\n", 236 | "\n" 237 | ] 238 | } 239 | ], 240 | "metadata": { 241 | "kernelspec": { 242 | "display_name": "Python 3", 243 | "language": "python", 244 | "name": "python3" 245 | }, 246 | "language_info": { 247 | "codemirror_mode": { 248 | "name": "ipython", 249 | "version": 3 250 | }, 251 | "file_extension": ".py", 252 | "mimetype": "text/x-python", 253 | "name": "python", 254 | "nbconvert_exporter": "python", 255 | "pygments_lexer": "ipython3", 256 | "version": "3.4.3" 257 | } 258 | }, 259 | "nbformat": 4, 260 | "nbformat_minor": 0 261 | } 262 | -------------------------------------------------------------------------------- /Datasets/Multi-MNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stderr", 12 | "output_type": "stream", 13 | "text": [ 14 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib-1.5.0+783.g23bc09d-py3.4-linux-x86_64.egg/matplotlib/__init__.py:877: UserWarning: axes.color_cycle is deprecated and replaced with axes.prop_cycle; please use the latter.\n", 15 | " warnings.warn(self.msg_depr % (key, alt_key))\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "from plot_tools import plot_groups, plot_input_image\n", 23 | "import h5py\n", 24 | "import os\n", 25 | "import os.path\n", 26 | "%matplotlib inline\n", 27 | "np.random.seed(9825619)\n", 28 | "\n", 29 | "data_dir = os.environ.get('BRAINSTORM_DATA_DIR', '.')" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": { 36 | "collapsed": true 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "# Load the MNIST Dataset as prepared by the brainstorm data script\n", 41 | "# You will need to run brainstorm/data/create_mnist.py first\n", 42 | "with h5py.File(os.path.join(data_dir, 'MNIST.hdf5'), 'r') as f:\n", 43 | " mnist_digits = f['normalized_full/training/default'][0, :]\n", 44 | " mnist_targets = f['normalized_full/training/targets'][:]\n", 45 | " mnist_digits_test = f['normalized_full/test/default'][0, :]\n", 46 | " mnist_targets_test = f['normalized_full/test/targets'][:]" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "# Multi MNIST\n", 54 | "\n", 55 | "Binary images containing three thresholded MNIST digits. (Same thresholding as in [1]).\n", 56 | "\n", 57 | "We chose the image size to be $48 \\times 48 = 2304$ because that is roughly $28 \\times 28 \\times 3 = 2352$" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 3, 63 | "metadata": { 64 | "collapsed": false 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "def crop(d):\n", 69 | " return d[np.sum(d, 1) != 0][:, np.sum(d, 0) != 0]" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": { 76 | "collapsed": false 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "def generate_multi_mnist_img(digit_nrs, size=(48, 48), test=False, binarize_threshold=0.5):\n", 81 | " if not test:\n", 82 | " digits = [crop(mnist_digits[nr].reshape(28, 28)) for nr in digit_nrs]\n", 83 | " else:\n", 84 | " digits = [crop(mnist_digits_test[nr].reshape(28, 28)) for nr in digit_nrs]\n", 85 | " \n", 86 | " img = np.zeros(size)\n", 87 | " grp = np.zeros(size)\n", 88 | " mask = np.zeros(size)\n", 89 | " k = 1\n", 90 | " \n", 91 | " for i, digit in enumerate(digits):\n", 92 | " h, w = size\n", 93 | " sy, sx = digit.shape\n", 94 | " x = np.random.randint(0, w-sx+1)\n", 95 | " y = np.random.randint(0, h-sy+1)\n", 96 | " region = (slice(y,y+sy), slice(x,x+sx))\n", 97 | " m = digit >= binarize_threshold\n", 98 | " img[region][m] = 1 \n", 99 | " mask[region][m] += 1 \n", 100 | " grp[region][m] = k \n", 101 | " k += 1\n", 102 | " \n", 103 | " grp[mask > 1] = 0 # ignore overlap regions\n", 104 | " return img, grp\n", 105 | " " 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "metadata": { 112 | "collapsed": false 113 | }, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA4oAAAElCAYAAACiWBzqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAE6lJREFUeJzt3T+uLEfVAPDrT06BAD17BazAGRv4JAjIHVpIrMEJAQlr\n+KRPDp0TGIkNkLECVuD35MSkSJfAGtM1nunbM11/Tp36/UL73bl9p6tP16lzuvqj19fXFwAAALj4\nn9EHAAAAQCwSRQAAAAoSRQAAAAoSRQAAAAoSRQAAAAoSRQAAAAof7/3PDx/+5d0Zi3r37mcf9fg9\nxti6jDFaM8ba++STn5/+jPfvv69wJGMYY7RmjNHa3hhTUQQAAKCwW1EEALhWo5K491kzVxkBslBR\nBAAAoCBRBAAAoKD1FAAIZduOqg0VYAyJIgDwtKOJXM3nGgFoT+spAAAABRVFAOAhz7SD3voZVUaA\nuFQUAQAAKKgoAgBh2dgGYAyJIgAwhHZUgLi0ngIAAFCQKAIAAFCQKAIAAFDwjCIAQGJHnvu0URBw\nTUURAACAgkQRAACAgtZTAIDFvdWeqjUV1qOiCAAAQEFFESCo3i8eVzFgpN7jPbva3+fRz6sdR/Z+\nr5gFbUkUAYCwJAMAY2g9BQAAoKCiCBBAhLa7yzGo4NDC6DG+/f1Zx/jo77imI3+LmAVtSRQBABbx\nTFKVKQEFjtN6CgAAQEFFEYDCvepBjfau68/WMpabShTAvCSKAIG1TKRM4mEdZ2LJ5WfFDFiLRLGC\n1oHTintOR8aNc08WJpg868wmNMbdXJ49XytsVAQjSBQBAhgxudn+ThNqWnhrXBt3AHHZzAYAAICC\niuLL7ZaFSKucRzeWsElEXDXaaS6cV2ppHecixVFy8N68fMQJWms5xlrFor1j7hn/JIoA3GVCTku3\nxpfEoZ6e12+v3xVxQR+y0noKAABAYdmK4r2VqEwrVFp0xmo1luzuxhmZYhzcYozDunpf/yPm2j1/\n57KJIgBjWOBgj914n5f12sr6d0F0EsWJqRjG1HNio7rIUb3GpYk9ZxlD8bQ4J84zNY0eT0c3npyN\nRLGS2gNh9ICnj71xYwyQzew3TKAvMQPGspkNAAAABRVFqCBC9U8rMiNFuAaYmzFEDR7JyE2c6GvZ\nRPHew/KCCq08OrZs6MAo4iAwytH4432K0J7WUwAAAArLVhQBsju70n7989cr/VbyGeVWJ5DxWJfv\nk0hmHY+zPxYkUXyZ9+RdH7dWxXhqjS0TIWY1a3wlJ+OxLt8ns7P7/D6JIgCFezdHN032vFWBPvIz\nNX4vcTlXMBeJIkBSugyA2lpUEVUmae2ZMXbrZ1a7l9rMBgAAgIKKIkxgtRUsYH634pbKEazjyNxl\nlWcEZ30Vn0QRYAGPbIh0tt1mppsg87PZ19q02HOt5T1otfGm9RQAAICCimIgZ0v05LLCShWxnI0v\n4hNvaRHXxMp6fJecNXIMuQfVp6IIAABAYdmKouodo1m5JTpjFIAjItwvLsdg/l7PsokiZCMwAiMd\njUG9J5RiI8TgWpyP1lMAAAAK6SqKNVcqe73zJEK5njacW6IwFoF7Zo4PMx/7rHzn60iXKEJ2Wjeo\nyU6n9FbjPWTendifa51nGDdzS5EoulEA1CGe0lqNMWacQi6zXNOzHGctKRLF7KzG8PJiHBCHsQg8\no0Y1mfhG3SMi35siH9sem9kAAABQmKKiaNXpnF6b8qym97g8+vucY54hzsJaZr/mZz/+2dT6vmc6\nbzMdaytTJIqteBie0a6TOmORyCxCMNr7998fipPG6rqc+3F89z81+3ei9RQAAIBCuIriKhWVVf7O\nmc16jt467tlXt6hr1nHOfIw17vGITB6zX+ezH39t4RLFVmYMPDMeM/cdOZ92hCMacQjW5frnLcZI\nbsskinBU9gTt3t8n2APRZY/PxDjHe8ew0r3yyLmIcL6eNfOx95ImUVzpwj3CRj3ze2tMO7cAJXOB\n83p9h+YpXJt5g7+sscdmNgAAABTSVBRnMNPKyGqcm/9+B1lXxSLR/gsQj7nAeGfOQYbz98jf0GMT\nqHCJ4pFWhMyTqcx/WyZ756nX7m3PfHaGIMpzHmnzEocgv0zXeeu/JdN39YjaG+xl/B4z/k1bWk8B\nAAAohKso9rJqZcW7imBNq8Y8xjLuWJ1HDZhZ2ETRBVSH9/LVk2FMGg/syTDGWY9xO7cz9yXnvp8j\n52mV87HK3/nyEjhRpA7JQD2rfJeqzm2sMn7g5cV4Zx7GKtwnUQQAQvGOvTVFWpyMdCzRjPpu3r//\nvntM6P23Rot9NrMBAACgoKLYQZRVAaA/1z8AMCOJ4o6e5WYtDkRhLPY1+vse/fthj/EJbHm3Y19a\nTwEAACgsV1HUBgYA9bm/MgtjNR7nJKblEsXVKbHfppWB3kbtpAaRGafAUavGi55/9xKJolUK+IFr\nAQDacq8liyUSxUe0yNJXXfHIzDkFgHzc32NxPsaymQ0AAAAFFUVYgDaYvnzfAMDsJIrJKdlDHK5H\nMqu9KRhk5D7AEVHGSdpE0U0KAADgOWkTxUdFydxhFNcAMILYAxCTzWwAAAAoqChCUtqvAQB41tKJ\nonYXAKjP/ZWZGK9wm9ZTAAAACukqitrtgBHEHoDYWsfpI5+/929UNokmXaIIHOem1IfvGQCYTYpE\n0Uo+/MC1AABtuMf2t/3OLbr2lyJRBACAWUh6SpfvQzIei81sAAAAKExbUbTiAADAWeaUcNu0iSLw\nHO0uANCP+y6z0noKAABAQUUREtA2M86Z7/7Rn7UqDQD0skyiaILF6lwDAKysxs6ae/dSr3I4z/cW\nyzKJIkAUz05STEIAgF6mTRS3kyRtdwAAHGGhDY6xmQ0AAACFaSuKADOp3flw+Twr4wDj6GojsxSJ\nookSMKujGyMAAPSk9RQAAICCRBEAAIBCitZT4DZt2XGdPTfOLQDQkkQRoLFHnzX0bCIAMJpEERJQ\nXRrn8t33SO6cZwCgF88oAgAAUFBRBBhEiylATvfiu84QZiJRhCs9WwkBgHltEz/zBrLRegoAAEBB\noghP+OSTn1s55BTjBwCITOspQAW12488xwKQg3jOrFQUAQAAKKgoAlR2pLpohRkgF3GdbFQUAQAA\nKEgUAQAAKGg9hTv22ge1l3CUsQIAzEhFEQAAgIKKIhygKgQAwEpUFAEAAChIFAEAACh89Pr6OvoY\nAAAACERFEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAA\ngIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJE\nEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgIJEEQAAgMLHe//zw4d/vfY6EGJ59+5nH/X4PcbY\nuowxWjPG2vvtP/58+jO++ezLCkcyhjFGa8YYre2NMRVFAAAACrsVRQCAazUqiXufNXOVESALFUUA\nAAAKEkUAAAAKWk8BgFC27ajaUAHGkCgCAE87msjVfK4RgPa0ngIAAFBQUQQAHvJMO+itn1FlBIhL\nRREAAICCiiIAEJaNbQDGkCgCAENoRwWIS+spAAAABYkiAAAABYkiAAAABc8oAgAk9r9//P83/83f\n/vT7DkcCzERFEQAAgIJEEQAAgILWUwCAxb3Vnqo1FdajoggAAEBBoggAAEBB6ylAUEd2KqxJaxkj\n/fYffx59CKnUjh9HP692HPn80y/u/r+vv/2q6u8CShJFACCsbz77cvQhACxJoggQQO/q4d4xqCzS\nwuiK4fYayzrGI8SRWvYqidf/RmUR2pAoAgAs4pkkOVMCChxnMxsAAAAKKooA3HTd+lWjvavFZxLX\n6HZTAJ4nUQQIrOWzVNrJYB1nYsnlZ8UMWIvWUwAAAAoqihW0XmHLujvb6o60ZNkWniyO7GAIt2xj\n5aMxUQVsLs/Gie3PaWeHeiSKAAGMWBCyCEVrbyV2nmEEiEui+HL73UqRViHvHcv1JO/6hqsaFcez\nk6FbP+e8UkvrjWVUEantEhPFwTzECVpruSDVKhbtXRc9q+YSRQBu0sJFa7cmWaqM9fTsGuj1uy5x\nSYIJ7dnMBgAAgMKyFcV77ZyRWk7P0qIzVqtV8TMbO8Bbq/DXMdBzjMwm030ceEzvjoQRc+3LfbxH\n18+yiSIAY2hpZc92wqUN9TFZF3bEDBhDojixy6pp1hvDrHpObFQXOerI8zw1KjGeG+KsYgz99Vfj\nDoQftajSihXUNHpR6d7vn31uJlGspHaypnVmDXsBZHTQYy3f/eHXzX+HqgCP+sVv/mncLMy5h7Fs\nZgMAAEBBRREqiFD9s3kRI2kj4yxjiBq240hFMp8I862VLJsobltFt22envejlUcTOBs60MNlIiUO\nAhEcTe68TxHa03oKAABAYdmKIkB2Z1far3/+eqXfSj4j3BuXNoGry/dJJLN2Vs3+WJBE8WXeNqvr\n49aqGE+twHD5HOeVVlrFQc8IEcms9/uofJ/Mzu7z+ySKAPxor0qogsie60nVkYWyGmPqx8/o8IoX\nzhFDYC4SRYCkttU8OwECNbSoIopJtPZMh9etn1mtymgzGwAAAAoqijCB1VawaGtvk4pfdjwOcrsV\nt848t61tEeZyZEOkvQp1prnP9m+ZaWMbiSLAAo7uDPndH35988b9yCRdGxlnPTKG7Hq6tnst9qyr\nZSK22saRWk8BAAAoqCgGYjWUrRVWqoC1tHinmHtnPb5LzjKGcpEoTsY7i4BWzsYXLac86uyY+c4r\nMWBpMz3vN6NlE8UjfewmPbSkYkh0nvehFWMLchldSdz+fkWVepZNFCEbq2rASEdjUO8k0aQRYqh1\nLZrv9GMzGwAAAArpKoo1Vyq3n9WyDXV0uZ52tJcShTgD3DNzfNDG3N/M44XHpEsUITstF9RkAxt6\nq/HeO+9O7E8LL8+oOW7Mf/rTegoAAEAhRUVR2wFAHeIprdUYYyqJkMuta/rj3z3+M62t9khRikQx\nO+0ea/jmsy93g55xQBTaTRnJuxPnVaPtmPhqzVf+/Zd3D31W5HnSrG2zUySKgsk529WPWQdqRL1X\nlY6unEUOlMQlzsJaZq/Kill91RovM4271aqHt0yRKLZyWdkSbBhF4s5MVBIZ7etvvzo00bRgti5x\nahzX3U/NPs+zmQ0AAACFcBXFVap7M5XeVzVry8FbY8uKH1urxFzGM9a4p9d7q2lv9vntrHO/VsIl\niq3MGHhM6HM50n6wPeezB1vi++X//f3NfzNj7ATqMA/hLSPHyMe/+zDsd69C6ykAAACFZSqKcFT2\ntoN7lUorx0B0Oi3yi9CivDcPmH1zkkccud5mviazz/dqSJMoao8qXQKZi2BebyVuMwdngEcdeX+i\nBa/zen2Hdp7n2szXb9YFhDSJ4gxM7OOSUP93fM4cqGdxb7xlvdEAzMBcYLy9ufLf/vT73XOU4fw9\nsnDSYxOocInikRWmzNVDk/Q57E3ot4Gq5cT/mbFisWJdN8fL5r/ZdRDWkmm+0TpmrbqIV3uDvUxj\n7iL72LCZDQAAAIVwFcVeVu2J71XtAmJRMWSEVe+1cOFRg3oytJbOJmyiaCJTxzYQucAely2Qe08j\n94i5RGQDm/y2sefRhQVxq58j84eW12KkdyZmmxvuCZsoUofksL7s3+n2BmACVk+kcXNvMmbSRS0q\nicwiUmyGaCSKcMNKq0UA0VwWLXQ+rCXSYpV5wH0jFpH//Zd3b+562kLvcRDttTE2swEAAKCgotiB\nFVFYl7YmAGBGEsUdPVsgPAtGFMZiX6PbmyK1esE18QjYqr1J4+h7cHRaTwEAACgsV1GM8nAosV3G\niWoLwDHur8zCIwHxeEwrpuUSxdUpsd+2/V5qTHZ8z7xl1E5qEJlxChy16lyrZ5xcIlG0ysmzbo6d\nv/6q/4FUYsUOANpSsSSLJRLFR7TI0j2Mn8svfvNPq94AkNCqVaqozKHHspkNAAAABRVFWICW0760\nHQEAs5MoJqeFAuJwPZLZtiXf3gBwm/sAR0R5xEnrKQAAAIW0FUWrmQAAAM9Jmyg+KkqJF0axsxgw\ngvsvQEwSRUjKBjYAADxr6UTRKiYA1Of+ykxsMAO32cwGAACAQrqKok1sgBG8OxEgttZx+sjn7/0b\nlU2iSZcoAsfZwKYPN38AYDZaTwEAACikqChqN4Uf2OkUANrwiEF/23mNLqj+UiSKAAAwC48klC5J\noAXvWKZNFFURAQA4S6UQbps2UQSeo3UDAPpRPWRWNrMBAACgoKIICejpH+dMy9KjP2tVGgDoZZlE\n8etvvxp9CDCUllMAVnZZbDuzwLe3YLf9XAt7zzFXiUXrKQAAAIVlKooAUTy7mm21GgDoZdpEcdtK\n6lUZAAAcYaENjpk2UQSYSe33dF0+z4QHYBzvYCSzFImijWqAWR3dGAEAoCeb2QAAAFCQKAIAAFBI\n0XoK3OZ9RHGdfbbQs4kAQEsSRYDGHn3W0LOJAMBoEkVIQOVwnEtlr0dyp4oIAPTiGUUAAAAKKooA\ng2gxBcjpXnzXGcJMJIpw5fJezs8//WLwkQAAkW0TP4t/ZKP1FAAAgIJEEZ7w+adfqDhyipVnACAy\nracAFdRuP/IcC0AO4jmzUlEEAACgoKIIUNmR6qIVZoBcxHWyUVEEAACgIFEEAACgoPUU7ri8T/Hl\n5afvVNz+P9ijFQkAmJGKIgAAAAWJIgAAAAWtp3CAVlMAAFaioggAAEDho9fX19HHAAAAQCAqigAA\nABQkigAAABQkigAAABQkigAAABQkigAAABQkigAAABT+A7ahjmU84PXUAAAAAElFTkSuQmCC\n", 118 | "text/plain": [ 119 | "" 120 | ] 121 | }, 122 | "metadata": {}, 123 | "output_type": "display_data" 124 | } 125 | ], 126 | "source": [ 127 | "fig, axes = plt.subplots(ncols=6, nrows=2, figsize=(16, 5))\n", 128 | "for ax in axes.T:\n", 129 | " digit_nrs = np.random.randint(0, 60000, 3)\n", 130 | " img, grp = generate_multi_mnist_img(digit_nrs)\n", 131 | " plot_input_image(img, ax[0])\n", 132 | " plot_groups(grp, ax[1])" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "## Save as HDF5" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 6, 145 | "metadata": { 146 | "collapsed": false 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "np.random.seed(36520)\n", 151 | "nr_digits = 3\n", 152 | "nr_training_examples = 60000\n", 153 | "nr_test_examples = 10000\n", 154 | "nr_single_examples = 60000\n", 155 | "size = (48, 48)\n", 156 | "\n", 157 | "data = np.zeros((1, 60000) + size + (1,), dtype=np.float32)\n", 158 | "grps = np.zeros_like(data)\n", 159 | "targets = np.zeros((1, 60000, nr_digits), dtype=np.int)\n", 160 | "for i in range(60000):\n", 161 | " digit_nrs = np.random.randint(0, 60000, nr_digits)\n", 162 | " data[0, i, :, :, 0], grps[0, i, :, :, 0] = generate_multi_mnist_img(digit_nrs, size=size)\n", 163 | " targets[0, i, :] = mnist_targets[0, digit_nrs, 0]\n", 164 | " \n", 165 | "data_test = np.zeros((1, 10000) + size + (1,), dtype=np.float32)\n", 166 | "grps_test = np.zeros_like(data_test)\n", 167 | "targets_test = np.zeros((1, 10000, nr_digits), dtype=np.int)\n", 168 | "for i in range(10000):\n", 169 | " digit_nrs = np.random.randint(0, 10000, nr_digits)\n", 170 | " data_test[0, i, :, :, 0], grps_test[0, i, :, :, 0] = generate_multi_mnist_img(digit_nrs, size=size, test=True)\n", 171 | " targets_test[0, i, :] = mnist_targets_test[0, digit_nrs, 0]\n", 172 | " \n", 173 | "data_single = np.zeros((1, nr_single_examples) + size + (1,), dtype=np.float32)\n", 174 | "grps_single = np.zeros_like(data_single )\n", 175 | "targets_single = np.zeros((1, nr_single_examples, 1), dtype=np.int)\n", 176 | "for i in range(nr_single_examples):\n", 177 | " data_single [0, i, :, :, 0], grps_single[0, i, :, :, 0] = generate_multi_mnist_img([i], size=size)\n", 178 | " targets_single[0, i, :] = mnist_targets[0, i, 0]" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 7, 184 | "metadata": { 185 | "collapsed": false 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "with h5py.File(os.path.join(data_dir, 'multi_mnist.h5'), 'w') as f:\n", 190 | " single = f.create_group('train_single')\n", 191 | " single.create_dataset('default', data=data_single, compression='gzip', chunks=(1, 100) + size + (1,))\n", 192 | " single.create_dataset('groups', data=grps_single, compression='gzip', chunks=(1, 100) + size + (1,))\n", 193 | " single.create_dataset('targets', data=targets_single, compression='gzip', chunks=(1, 100, 1))\n", 194 | " \n", 195 | " train = f.create_group('train_multi')\n", 196 | " train.create_dataset('default', data=data, compression='gzip', chunks=(1, 100) + size + (1,))\n", 197 | " train.create_dataset('groups', data=grps, compression='gzip', chunks=(1, 100) + size + (1,))\n", 198 | " train.create_dataset('targets', data=targets, compression='gzip', chunks=(1, 100, nr_digits))\n", 199 | " \n", 200 | " test = f.create_group('test')\n", 201 | " test.create_dataset('default', data=data_test, compression='gzip', chunks=(1, 100) + size + (1,))\n", 202 | " test.create_dataset('groups', data=grps_test, compression='gzip', chunks=(1, 100) + size + (1,))\n", 203 | " test.create_dataset('targets', data=targets_test, compression='gzip', chunks=(1, 100, nr_digits))" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "# References\n", 211 | "\n", 212 | "[1] Bishop, Christopher M. (2006) \"Pattern Recognition and Machine Learning (Information Science and Statistics)\"\n", 213 | "Springer-Verlag New York, Inc., Secaucus, NJ, USA.\n", 214 | "Section 9.3.3 p. 447" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "kernelspec": { 220 | "display_name": "Python 3", 221 | "language": "python", 222 | "name": "python3" 223 | }, 224 | "language_info": { 225 | "codemirror_mode": { 226 | "name": "ipython", 227 | "version": 3 228 | }, 229 | "file_extension": ".py", 230 | "mimetype": "text/x-python", 231 | "name": "python", 232 | "nbconvert_exporter": "python", 233 | "pygments_lexer": "ipython3", 234 | "version": "3.4.3" 235 | } 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 0 239 | } 240 | -------------------------------------------------------------------------------- /Datasets/Shapes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stderr", 12 | "output_type": "stream", 13 | "text": [ 14 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib-1.5.0+783.g23bc09d-py3.4-linux-x86_64.egg/matplotlib/__init__.py:877: UserWarning: axes.color_cycle is deprecated and replaced with axes.prop_cycle; please use the latter.\n", 15 | " warnings.warn(self.msg_depr % (key, alt_key))\n" 16 | ] 17 | } 18 | ], 19 | "source": [ 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "from plot_tools import plot_groups, plot_input_image\n", 23 | "%matplotlib inline\n", 24 | "np.random.seed(104174)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "# Shapes Problem\n", 32 | "\n", 33 | "Binary images containing 3 random shapes each. Introduced in [1] to investigate binding in deep networks.\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": { 40 | "collapsed": true 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "square = np.array(\n", 45 | " [[1, 1, 1, 1, 1, 1, 1],\n", 46 | " [1, 1, 1, 1, 1, 1, 1],\n", 47 | " [1, 1, 0, 0, 0, 1, 1],\n", 48 | " [1, 1, 0, 0, 0, 1, 1],\n", 49 | " [1, 1, 0, 0, 0, 1, 1],\n", 50 | " [1, 1, 1, 1, 1, 1, 1],\n", 51 | " [1, 1, 1, 1, 1, 1, 1]])\n", 52 | "\n", 53 | "triangle = np.array(\n", 54 | " [[0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0],\n", 55 | " [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],\n", 56 | " [0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n", 57 | " [0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0],\n", 58 | " [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0],\n", 59 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", 60 | " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n", 61 | "\n", 62 | "shapes = [square, triangle, triangle[::-1, :].copy()]" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": { 69 | "collapsed": true 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "def generate_shapes_image(width, height, nr_shapes=3):\n", 74 | " img = np.zeros((height, width))\n", 75 | " grp = np.zeros_like(img)\n", 76 | " k = 1\n", 77 | " \n", 78 | " for i in range(nr_shapes):\n", 79 | " shape = shapes[np.random.randint(0, len(shapes))]\n", 80 | " sy, sx = shape.shape\n", 81 | " x = np.random.randint(0, width-sx+1)\n", 82 | " y = np.random.randint(0, height-sy+1)\n", 83 | " region = (slice(y,y+sy), slice(x,x+sx))\n", 84 | " img[region][shape != 0] += 1\n", 85 | " grp[region][shape != 0] = k \n", 86 | " k += 1\n", 87 | " \n", 88 | " grp[img > 1] = 0\n", 89 | " img = img != 0\n", 90 | " return img, grp" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 4, 96 | "metadata": { 97 | "collapsed": false 98 | }, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAA4oAAAElCAYAAACiWBzqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADjFJREFUeJzt3T2OHNcVBtChoVRWYIy4Aq1AmTdAQAqcMyQIeA1KGDDR\nGgwQCrUDEeAGnGkFXoE0cCBrAeOIEL9S9XT1q5/3U+dkxEx3FbvfVNdXfe+tZ4+Pj3cAAADw0V9q\n7wAAAABtERQBAAAIgiIAAABBUAQAACAIigAAAARBEQAAgPDZUz98ePjdvTNO6v7+82dHbMcaOy9r\n7A9ffvnX2ruwm19//V+1bVtjY2nh72S6nq0x9maNsben1phvFAEAAAiCIgAAAEFQBAAAIDzZo0ha\n2h9RsycHAABgLd8oAgAAEARFAAAAgqAIAABA0KN4wZr7NU0fq2fxeC3cb2tP1hQAAHsSFAEqmwb/\nNRc6triIULp9FzBoQck6HP3iIkAJpacAAAAEQREAAIAgKAIAABC66lHstYdgbr/18gB7uHWYVq/H\nVQBgX10FRYAzmAt3LQY6F7wAYFxKTwEAAAiCIgAAAEHpaSW39hEBlGixZBUAPjXyZ1XP5/iC4o0M\nhujD1j1ebmJObdN1cPSxxjoEgHNRegoAAEAQFAEAAAiCIgAAAEGPYiPm+o30BLWjt5uY197+3Z31\nC9TRwvEPYARDBsWaQx9qD5zgsl7eGwGLJbYe2LTk+QGA81B6CgAAQBAUAQAACEOWnu5ZUnhkueKt\nfXEcp9WyVQAA2MKQQRGW2LvHq3QfoFRpH651CEBNW86R2OIzrXT7o32eKj0FAAAgCIoAAAAEQREA\nAICgRxEAgMONPhhutH61Xtw6DHL0dbiGoAifONtNzG/dHwfTvrS23mAN6xnOo4WBg0uMflxSegoA\nAEAQFAEAAAiCIgAAAEGPIgAAMKQWext7ISjeaKvJSaM3v45k+l55jwEAjlV6PrbX9s9A6SkAAABB\nUAQAACAoPQWYaKGf4YwlLgBAOwRFuJETeABYb8ubqm/x2bzmIqFzg+NtuX6WPv/ZKD0FAAAgCIoA\nAAAEQREAAICgRxFOrIWhLQCw1tzn2Vb3voaz6ioobtVU6obpAADtqX1T9aWcI7apdP14P+cpPQUA\nACAIigAAAISuSk+P1mq5AzC+rY4/ymmAs3IeB+sIigCFloQwJyoA5fa+qXrpPtAH7906Sk8BAAAI\ngiIAAABBUAQAACAIigAAAATDbGBAmrcBGFXpTdVLnhvOzDeKAAAABEERAACAoPSUIS0tQ1FiAgAA\nf3bKoCgcAK27dpw6+obTAK2aO166YAzrKT0FAAAgCIoAAAAEQREAAIBwyh5FxlParzX3OP0KLKVP\nkC21sJ4c/wD4SFAEAGAoLnrAekpPAQAACIIiAAAAQekp3dm7j0ffIgAAZ+cbRQAAAIKgCAAAQBAU\nAQAACIIiAAAAwTAbmtfCTain+2C4DXtrYd0DAOclKDKsa2HOiTiXbHkhwDpjayXr0zoE4FZKTwEA\nAAiCIgAAAEHpKU3ppTxqbj/1LQKUaeHY7xgOkARFhnHrh/zc77dwsgIAALUpPQUAACAIigAAAARB\nEQAAgKBHkaq27Ams3V843b7BCAAA9EpQBNiRCwYAQI+UngIAABAERQAAAILSUwDoSO1+bADOQVAE\nALpQ0vMrWAOUUXoKAABAEBQBAAAIgiIAAABBjyIAAKfTQv9qT/fabeH12ste78PS16zVdSAoUlWr\nfxgAAHBmSk8BAAAIgiIAAABB6SkAQMP0hkH71vydzj22hb8NQREAGtDCSQEAfKT0FAAAgCAoAgAA\nEARFAAAAgh5FAKALIw91Afqy9/Fo+vw1+tgFRQAAuKDkBH3Eixpzr0Pp/3OL0LPmNTY8bBmlpwAA\nAARBEQAAgKD0FAAAOEzJDeZrlvO2UEpc8pqtJSgCADRsejJYuzer9vZpx5Zrcy+111zLAfgapacA\nAAAEQREAAIAgKAIAABD0KAIAnERvQ0Q4j5bWWUv78pTpfm7djykoAgBV1R420Zstb3y+J+/r+bSw\nNmuuu5Jtt/CaXaL0FAAAgCAoAgAAEARFAAAAgh5FAIATa6UfqlVen7Ft+f7WXislw6qeIigCAHRu\nejJY44TV8Brm7Lk2rbl9KT0FAAAgCIoAAAAEQREAAICgRxEAYDB738Rbbxil1qxN6+5YvlEEAAAg\nCIoAAAAEQREAAIAgKAIAABAMswEAOAFDRNKI/6de1XwvrIPLfKMIAABAEBQBAAAIgiIAAADh2ePj\nY+19AAAAoCG+UQQAACAIigAAAARBEQAAgCAoAgAAEARFAAAAgqAIAABAEBQBAAAIgiIAAABBUAQA\nACAIigAAAARBEQAAgCAoAgAAEARFAAAAgqAIAABAEBQBAAAIgiIAAABBUAQAACAIigAAAARBEQAA\ngCAoAgAAEARFAAAAgqAIAABAEBQBAAAIgiIAAADhs6d++PDw++NRO0Jb7u8/f3bEdqyx87LG/vDi\nzbvau7CbD29fV9u2NTaWl89f1d6Fux9/+SH+bY2xN2uMvT21xnyjCAAAQBAUAQAACIIiAAAA4cke\nRdLS/ohpDwMAAEBPfKMIAABAEBQBAAAIgiIAAABBUAQAACAYZnPBmhv7Th9ruM3xWrgx856sKQAA\n9iQoAlT24e3r+PeLN+82e64SpdvfYtuwVsmFtNEvLgKUUHoKAABAEBQBAAAIXZWerinHWmKvsqm5\nkhY9ZsAepsfJa8e1vY+rAECfugqKR5oLcnoYgCPMhbsWA52eRAAYl9JTAAAAgqAIAABAEBQBAAAI\nehQrmfY7Gm4D7KHF3kYA+NTIn1U99/OfNiiWvGlLwpyBN23YehjRFkG+dPsuInB39+dj1tEfqj1/\n0AEAt1N6CgAAQBAUAQAACKctPZ0rA/y0xO9SWdde5VfX9oe6bu0prV2C3EKtv1JFoIbax1+AUQwZ\nFKcnqEd+aEwDhA+sdvTy3rhAwBJzQXzLCwSCPgCcm9JTAAAAgqAIAABAEBQBAAAIQ/Yottp7dqtb\nB6hwnFHWGAAAzBkuKP7tX/9e9HvXQtdWgxy2vvE722nhvRH+2dL0uLV0uI3BNQDUVPr5teS5SpRu\nf7TPU6WnAAAABEERAACAMFzpKQAA7fv25+9r78Kufvr6u9q7cErTstFr5aBb3oN4NIIifGLvvsXW\nehJvraV3MO3LaL0SnFtrx09gP3OfXy2eg4z+Oav0FAAAgCAoAgAAEARFAAAAgh7FCy7VQV+rRV7S\nmP1F0R4BAAC3aLG3sRfDBcX//vPvtXfhqt/ef2USVkemAxSWDrcxeAEAYBvTL2uODoCjD66Zo/QU\nAACAICgCAAAQhis9BVirhZtAK08HAGoSFOFGeg8BYL25C2KlF+q2uLi25iKhi3vHm+sZ3LJv8Yw9\niVNKTwEAAAiCIgAAAEFQBAAAIJy2R/F6Hfr9IfsBNbkJLQAjmDuvu9Y32MLgMmhZV0Fxq6bSJQeG\nz/7xoDEZAOBA03OvVsOcc8Q2TbPC0gviBtfMU3oKAABAEBQBAAAIXZWeHq3VcgdgfC+fv9rkedz3\nEzgr53GwjqAIUGhJj8pc4LsW3rYKiQC9mzvOHh0A9SP2S+/hOkpPAQAACIIiAAAAQVAEAAAgCIoA\nAAAEw2xgQJq3ARjVdLjMlsNtDK6BP/hGEQAAgCAoAgAAEJSeMqSlZShKTGjBdL1+UWk/AAA+OmVQ\nFA6Alv32/qurx6mXz18dtDcAbZs7XrpgDOspPQUAACAIigAAAARBEQAAgHDKHkXGU3oPpbnH6Vdg\nqSXrzmAaltryXnClHP8A+EhQBNjRksE0AGzLcRfWU3oKAABAEBQBAAAISk/pzt59PPoWAQA4O98o\nAgAAEARFAAAAgqAIAABAEBQBAAAIhtnQvBZuQj3dB8NtuMXL56+e/oX3X93+GACAHQmKDOtamGsh\ngNKmLS8ELAl8X3zzn822x/hK1qfjHQC3UnoKAABAEBQBAAAISk9pSi/lUXP7qW8RoMyLN+9q78Ld\nh7eva+8CQFMERYZxa1Cb+/1egioAAOxJ6SkAAABBUAQAACAIigAAAAQ9ilS1ZU9g7f7C6fYNtwEA\noFeCIsCOfvzlh9q7AABwM6WnAAAABEERAACAICgCAAAQ9CgCQEdqD+4C4BwERQCgCx/evr75MS/e\nvNthTwDGp/QUAACAICgCAAAQlJ4CAHA6L5+/qr0LXd1rd+T+6J++/m6X5126xlpdB4IiVe31hwkA\nAJRTegoAAEAQFAEAAAiCIgAAAEGPIgBAw0a+F2TJvTGhRWuGI809toUBN4IiADTAcC8AWqL0FAAA\ngCAoAgAAEJSeAgBdGLlXD+jLmp7Ekuev0bMoKAIAwAUlJ+h7h4ga5vqov/35+82e61al295q+2eg\n9BQAAIAgKAIAABAERQAAAIIeRQAA4DBz/YXX+gbX9CSu1ULP6dw+7D3gRlAEAGjYh7ev499rpr9O\nn6tE7e3Tjmm4qxnmLqk9uOZamGshhF6i9BQAAIAgKAIAABCUngIAnMRc2ei1ctA1paawVEtlqy2X\ng35qup9b9ywKigBAVfrWbjP3erUY5ryv5zPXD3h0AKzZk1gS1OYe00pQVXoKAABAEBQBAAAIgiIA\nAABBjyIAwIm12N/Yklb6xdjHlu9v7bUyt/01A24ERQCAzk0Hx9QIf4bXMGc6XGbL4TY1B9ecgdJT\nAAAAgqAIAABAEBQBAAAIehQBAAYz1y+4Zd+ifkRKzfUVLu1b1JN4LN8oAgAAEARFAAAAgqAIAABA\nEBQBAAAIzx4fHy/+8OHh98s/ZGj3958/O2I71th5WWPszRpjbyOssaUDbgyvqWOENUbbnlpjvlEE\nAAAgCIoAAAAEQREAAIDwZI8iAAAA5+MbRQAAAIKgCAAAQBAUAQAACIIiAAAAQVAEAAAgCIoAAACE\n/wMRaps31beRmQAAAABJRU5ErkJggg==\n", 103 | "text/plain": [ 104 | "" 105 | ] 106 | }, 107 | "metadata": {}, 108 | "output_type": "display_data" 109 | } 110 | ], 111 | "source": [ 112 | "fig, axes = plt.subplots(ncols=6, nrows=2, figsize=(16, 5))\n", 113 | "for ax in axes.T:\n", 114 | " img, grp = generate_shapes_image(28, 28, 3)\n", 115 | " plot_input_image(img, ax[0])\n", 116 | " plot_groups(grp, ax[1])" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "# Save as HDF5 Dataset" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 5, 129 | "metadata": { 130 | "collapsed": true 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "import h5py\n", 135 | "import os\n", 136 | "import os.path\n", 137 | "\n", 138 | "data_dir = os.environ.get('BRAINSTORM_DATA_DIR', '.')" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 6, 144 | "metadata": { 145 | "collapsed": false 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "np.random.seed(265076)\n", 150 | "nr_train_examples = 60000\n", 151 | "nr_test_examples = 10000\n", 152 | "nr_single_examples = 10000\n", 153 | "width = 28\n", 154 | "height = 28\n", 155 | "nr_shapes = 3\n", 156 | "\n", 157 | "data = np.zeros((1, nr_train_examples, height, width, 1), dtype=np.float32)\n", 158 | "grps = np.zeros_like(data)\n", 159 | "for i in range(nr_train_examples):\n", 160 | " data[0, i, :, :, 0], grps[0, i, :, :, 0] = generate_shapes_image(width, height, nr_shapes)\n", 161 | " \n", 162 | "data_test = np.zeros((1, nr_test_examples, height, width, 1), dtype=np.float32)\n", 163 | "grps_test = np.zeros_like(data_test)\n", 164 | "for i in range(nr_test_examples):\n", 165 | " data_test[0, i, :, :, 0], grps_test[0, i, :, :, 0] = generate_shapes_image(width, height, nr_shapes)\n", 166 | "\n", 167 | "data_single = np.zeros((1, nr_single_examples, height, width, 1), dtype=np.float32)\n", 168 | "grps_single = np.zeros_like(data_single)\n", 169 | "for i in range(nr_single_examples):\n", 170 | " data_single[0, i, :, :, 0], grps_single[0, i, :, :, 0] = generate_shapes_image(width, height, 1)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 7, 176 | "metadata": { 177 | "collapsed": false 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "import h5py\n", 182 | "\n", 183 | "with h5py.File(os.path.join(data_dir, 'shapes.h5'), 'w') as f:\n", 184 | " single = f.create_group('train_single')\n", 185 | " single.create_dataset('default', data=data_single, compression='gzip', chunks=(1, 100, height, width, 1))\n", 186 | " single.create_dataset('groups', data=grps_single, compression='gzip', chunks=(1, 100, height, width, 1))\n", 187 | " \n", 188 | " train = f.create_group('train_multi')\n", 189 | " train.create_dataset('default', data=data, compression='gzip', chunks=(1, 100, height, width, 1))\n", 190 | " train.create_dataset('groups', data=grps, compression='gzip', chunks=(1, 100, height, width, 1))\n", 191 | " \n", 192 | " test = f.create_group('test')\n", 193 | " test.create_dataset('default', data=data_test, compression='gzip', chunks=(1, 100, height, width, 1))\n", 194 | " test.create_dataset('groups', data=grps_test, compression='gzip', chunks=(1, 100, height, width, 1))" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "# References\n", 202 | "[1] David P. Reichert and Thomas Serre, [Neuronal Synchrony in Complex-Valued Deep Networks](http://arxiv.org/abs/1312.6115), ICLR 2014\n", 203 | "\n" 204 | ] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3", 210 | "language": "python", 211 | "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.4.3" 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 0 228 | } 229 | -------------------------------------------------------------------------------- /Datasets/Simple Superposition.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Simple Superposition\n", 8 | "A very simple problem with moderate superposition introduced by Rao et al. [1].\n" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": { 15 | "collapsed": false 16 | }, 17 | "outputs": [ 18 | { 19 | "name": "stderr", 20 | "output_type": "stream", 21 | "text": [ 22 | "/home/greff/venv/py3/lib/python3.4/site-packages/matplotlib-1.5.0+783.g23bc09d-py3.4-linux-x86_64.egg/matplotlib/__init__.py:877: UserWarning: axes.color_cycle is deprecated and replaced with axes.prop_cycle; please use the latter.\n", 23 | " warnings.warn(self.msg_depr % (key, alt_key))\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import numpy as np\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "from plot_tools import plot_groups, plot_input_image\n", 31 | "%matplotlib inline\n", 32 | "np.random.seed(5163871)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": { 39 | "collapsed": true 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "objects = []\n", 44 | "# Object 1\n", 45 | "objects.append(np.array(\n", 46 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 47 | " [0, 1, 1, 1, 1, 1, 1, 0],\n", 48 | " [0, 1, 0, 0, 0, 0, 1, 0],\n", 49 | " [0, 1, 0, 0, 0, 0, 1, 0],\n", 50 | " [0, 1, 0, 0, 0, 0, 1, 0],\n", 51 | " [0, 1, 1, 1, 1, 1, 1, 0],\n", 52 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 53 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 54 | "# Object 2\n", 55 | "objects.append(np.array(\n", 56 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 57 | " [0, 0, 0, 0, 0, 1, 1, 0],\n", 58 | " [0, 0, 0, 0, 1, 1, 1, 0],\n", 59 | " [0, 0, 0, 1, 1, 0, 1, 0],\n", 60 | " [0, 0, 1, 1, 0, 0, 1, 0],\n", 61 | " [0, 1, 1, 0, 0, 0, 1, 0],\n", 62 | " [0, 1, 1, 1, 1, 1, 1, 0],\n", 63 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 64 | "# Object 3\n", 65 | "objects.append(np.array(\n", 66 | " [[0, 0, 0, 1, 0, 0, 0, 0],\n", 67 | " [0, 0, 0, 1, 0, 0, 0, 0],\n", 68 | " [0, 0, 0, 1, 0, 0, 0, 0],\n", 69 | " [1, 1, 1, 1, 1, 1, 1, 1],\n", 70 | " [0, 0, 0, 1, 0, 0, 0, 0],\n", 71 | " [0, 0, 0, 1, 0, 0, 0, 0],\n", 72 | " [0, 0, 0, 1, 0, 0, 0, 0],\n", 73 | " [0, 0, 0, 1, 0, 0, 0, 0]]))\n", 74 | "# Object 4\n", 75 | "objects.append(np.array(\n", 76 | " [[0, 0, 0, 1, 1, 0, 0, 0],\n", 77 | " [0, 0, 1, 0, 0, 1, 0, 0],\n", 78 | " [0, 1, 0, 0, 0, 0, 1, 0],\n", 79 | " [1, 0, 0, 0, 0, 0, 0, 1],\n", 80 | " [1, 0, 0, 0, 0, 0, 0, 1],\n", 81 | " [0, 1, 0, 0, 0, 0, 1, 0],\n", 82 | " [0, 0, 1, 0, 0, 1, 0, 0],\n", 83 | " [0, 0, 0, 1, 1, 0, 0, 0]]))\n", 84 | "# Object 5\n", 85 | "objects.append(np.array(\n", 86 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 87 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 88 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 89 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 90 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 91 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 92 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 93 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 94 | "# Object 6\n", 95 | "objects.append(np.array(\n", 96 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 97 | " [0, 1, 1, 1, 1, 1, 1, 0],\n", 98 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 99 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 100 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 101 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 102 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 103 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 104 | "# Object 7\n", 105 | "objects.append(np.array(\n", 106 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 107 | " [0, 0, 0, 0, 0, 0, 1, 0],\n", 108 | " [0, 0, 0, 0, 0, 0, 1, 0],\n", 109 | " [0, 0, 0, 0, 0, 0, 1, 0],\n", 110 | " [0, 0, 0, 0, 0, 0, 1, 0],\n", 111 | " [0, 0, 0, 0, 0, 0, 1, 0],\n", 112 | " [0, 0, 0, 0, 0, 0, 1, 0],\n", 113 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 114 | "# Object 8\n", 115 | "objects.append(np.array(\n", 116 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 117 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 118 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 119 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 120 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 121 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 122 | " [0, 1, 1, 1, 1, 1, 1, 0],\n", 123 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 124 | "# Object 9\n", 125 | "objects.append(np.array(\n", 126 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 127 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 128 | " [0, 1, 1, 0, 0, 0, 0, 0],\n", 129 | " [0, 0, 1, 1, 1, 0, 0, 0],\n", 130 | " [0, 0, 0, 0, 1, 1, 0, 0],\n", 131 | " [0, 0, 0, 0, 0, 1, 1, 0],\n", 132 | " [0, 0, 0, 0, 0, 0, 1, 0],\n", 133 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 134 | "# Object 10\n", 135 | "objects.append(np.array(\n", 136 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 137 | " [0, 0, 0, 0, 0, 0, 1, 0],\n", 138 | " [0, 0, 0, 0, 0, 1, 0, 0],\n", 139 | " [0, 0, 0, 0, 1, 0, 0, 0],\n", 140 | " [0, 0, 0, 1, 0, 0, 0, 0],\n", 141 | " [0, 0, 1, 0, 0, 0, 0, 0],\n", 142 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 143 | " [1, 0, 0, 0, 0, 0, 0, 0]]))\n", 144 | "# Object 11\n", 145 | "objects.append(np.array(\n", 146 | " [[0, 0, 0, 1, 0, 0, 0, 0],\n", 147 | " [0, 0, 0, 1, 0, 0, 0, 0],\n", 148 | " [0, 0, 0, 1, 0, 0, 0, 0],\n", 149 | " [0, 0, 0, 1, 0, 0, 0, 0],\n", 150 | " [1, 1, 1, 1, 0, 0, 0, 0],\n", 151 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 152 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 153 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 154 | "# Object 12\n", 155 | "objects.append(np.array(\n", 156 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 157 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 158 | " [0, 0, 0, 0, 0, 1, 1, 1],\n", 159 | " [0, 0, 0, 0, 0, 1, 0, 0],\n", 160 | " [0, 0, 0, 0, 0, 1, 0, 0],\n", 161 | " [0, 0, 0, 0, 0, 1, 1, 0],\n", 162 | " [0, 0, 0, 0, 0, 0, 1, 1],\n", 163 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 164 | "# Object 13\n", 165 | "objects.append(np.array(\n", 166 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 167 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 168 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 169 | " [0, 0, 0, 1, 1, 0, 0, 0],\n", 170 | " [0, 0, 1, 0, 0, 1, 0, 0],\n", 171 | " [0, 1, 0, 0, 0, 0, 1, 0],\n", 172 | " [0, 1, 0, 0, 0, 0, 1, 0],\n", 173 | " [0, 0, 1, 0, 0, 1, 0, 0]]))\n", 174 | "# Object 14\n", 175 | "objects.append(np.array(\n", 176 | " [[0, 0, 0, 0, 1, 0, 0, 0],\n", 177 | " [0, 0, 1, 1, 1, 1, 0, 0],\n", 178 | " [1, 1, 1, 0, 0, 1, 0, 0],\n", 179 | " [1, 0, 0, 0, 0, 1, 1, 1],\n", 180 | " [1, 0, 0, 0, 0, 0, 0, 1],\n", 181 | " [1, 0, 0, 0, 0, 0, 0, 0],\n", 182 | " [1, 0, 0, 0, 0, 0, 0, 0],\n", 183 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 184 | "# Object 15\n", 185 | "objects.append(np.array(\n", 186 | " [[0, 0, 0, 0, 0, 0, 0, 0],\n", 187 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 188 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 189 | " [0, 0, 1, 0, 0, 0, 0, 0],\n", 190 | " [0, 0, 1, 0, 0, 0, 0, 0],\n", 191 | " [0, 1, 1, 1, 1, 1, 1, 0],\n", 192 | " [0, 0, 0, 0, 0, 1, 0, 0],\n", 193 | " [0, 0, 0, 0, 0, 0, 0, 0]]))\n", 194 | "# Object 16\n", 195 | "objects.append(np.array(\n", 196 | " [[0, 1, 0, 0, 0, 0, 0, 0],\n", 197 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 198 | " [1, 1, 1, 1, 0, 0, 0, 0],\n", 199 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 200 | " [0, 1, 0, 0, 0, 0, 0, 0],\n", 201 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 202 | " [0, 0, 0, 0, 0, 0, 0, 0],\n", 203 | " [0, 0, 0, 0, 0, 0, 0, 0]]))" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 3, 209 | "metadata": { 210 | "collapsed": false 211 | }, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAJSCAYAAADH8R+wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAH99JREFUeJzt3X+w7Hdd3/HXhRtiQaI45pQEcLA4fIQ62gtDEEiiFUL5\nERtFQVSqFCjT2gFxtIQfg7ECJYJDodMf0gaa+gOjNvwQIkJoDQk/YtBSIdPmTZFaQdRzyYxzIUA0\n5vSP3RtPLvdecs6evfvd93k8Zhju2T373c937/vsPs/3uzd7YGtrKwAAnd1t1QsAAFg2wQMAtCd4\nAID2BA8A0J7gAQDaEzwAQHsHT3bl4cOf82/W94Ezz7z3gWVu3xz1Z4a+3MbGGQvdfnPzyB6tZH2Y\noy+36BwtYh1n8GQz5AgPANCe4AEA2hM8AEB7ggcAaE/wAADtCR4AoD3BAwC0J3gAgPYEDwDQnuAB\nANoTPABAe4IHAGhP8AAA7Z3009IXscpPeN2P1vFTbflyPhkZ+lnX18NF1j3F55OlBQ8AsLhVxMO6\nRtrJOKUFALQneACA9gQPANCe4AEA2hM8AEB7ggcAaE/wAADtCR4AoD3BAwC0J3gAgPYEDwDQnuAB\nANoTPABAe4IHAGjv4KoXAN1sbJyx6iUAe8zP9c4s+nhtbh7Zo5X8jckGzzJ2dsr8MAH0tW6vaYus\nd6qvZ05pAQDtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4\nAID2BA8A0N5kPy0dVmmqn/a7X+3Hv4/9uM9bW1tL3f5+fEzX1W7/rk42Q4IHJmZz88iubufJHJZr\ntz+b+82ij9Oynsuc0gIA2hM8AEB7ggcAaE/wAADtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABA\ne4IHAGhP8AAA7QkeAKA9n5ZOWz49HICjBA8swebmkVUvoZV1fDwXDe513GeYMqe0AID2BA8A0J7g\nAQDaEzwAQHuCBwBoT/AAAO0JHgCgPcEDALQneACA9gQPANCe4AEA2hM8AEB7ggcAaG+yn5a+6CcN\ns/7MADAlizwnbW4e2cOVTNtUn7snGzywavvpCQr2g0V+pqf6It7VMp5/ndICANoTPABAe4IHAGhP\n8AAA7QkeAKA9wQMAtCd4AID2BA8A0J7gAQDaEzwAQHuCBwBoT/AAAO0JHgCgPZ+WDifg05GBvbLI\n88kyPjn8K+n4/Le04FnFXxAALMMir2kd4+ErmWIDOKUFALQneACA9gQPANCe4AEA2hM8AEB7ggcA\naE/wAADtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe0v7tHQAYHH78dPWl0HwMFmbm0dWvQSA\nhS3yXLbK2On2HOyUFgDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2BA8A0J7gAQDa\nEzwAQHuCBwBoT/AAAO0JHgCgvQNbW1urXgMAwFI5wgMAtCd4AID2BA8A0J7gAQDaEzwAQHuCBwBo\nT/AAAO0JHgCgPcEDALQneACA9gQPANCe4AEA2ju46gUsaoyxkeQ1SQ4l+WKSA0leW1VXzK+/Jskr\nquq9x9zuRUk+VlVX7fD+7pnkCVX1luNcN5L8WpIbq+oZu9gdVmQqczTGOC3Jf0jykCR/K8mvVtVr\ndrVTnFITmqGvSfKmJGcmOT3Ju6vqp3e1U5xyU5mjY77nLUmOVNUzd7LtqelwhOdtST5SVd9aVY9M\n8v1JXjbGuOBkN6qqS3c6GHOHkjzl2AvHGPdKclmSd+xim6zeJOYoyXOTnF5Vj0nymCTPH2M8cBfb\n59Sbygz9UJIbqur8zGboh8cYf28X22c1pjJHSZIxxjOSbOxiu5NzYGtra9Vr2LUxxuOTvHw+FNsv\n/4dJXlhV585r+GNJvjnJ2fPvv2KMcXmS91fVZWOMpyV5XmYlfTjJc6rq5jHGhUkuSfKlJB9P8oIk\nH05ynyT/papeuO0+DyY5LckPJHmcIzzrY2JzdHqSg1V1y/zrm5L8cFX9/vIeARY1pRk65v43kvxu\nknOr6k/2er/ZW1ObozHG2Ul+I8nLkzzdEZ7VOpTkhuNc/qEkD9v29d2q6oIkFyV5/Rjjjv0eYzwg\nyUszi5Rzk1yT5CXzw3yXJXlSVZ2X5LPzbV6a5OpjB6OqbquqL+7ZnnEqTWmObt0WO09J8oUkH9mT\nvWSZJjND27Z3dZIbk/yM2FkbU5ujX0jyE5kF0tpb9+C5JSfeh9u3/fnqJKmqT8y/PnPbdY9KclaS\nd8/L+enzrx+a5FNVdXh+24ur6n17t3QmZHJzNMb4/iSvSvKUqrr9K30/Kze5GZq/IP7dJD81xnjU\nXd8VVmgyczTGeE5m70c9XoCtpXV/0/JHkzzrOJc/Ineu5O2DciDJ9vN4t2Z2vvvC7RsYYzw86x+E\n3DWTmqMxxg8m+akk31lVf7qT27Iyk5mhMcZ3JPnDqvp0VR0eY7w3ybmZHSVg2iYzR0m+L8nXjzG+\nK8kZSc4cY/xcVV28g21Mylq/oFfVtUmOjDHu+AsYY5yV2W/GL9v2rY+dX/fgJLdldk7zqA8nOWeM\ncd/59zx1jHFRkpuS3G+Mcf/55a+bX357Zu/VoYkpzdF82y/J7HC02FkTU5qhJE9K8pPz7z2Y5JGZ\nndpi4qY0R1X1xKp6RFV9e5IfS3LVOsdOsubBM3dhkgeOMT46xrg+yZVJLqmqD2z7ntvGGG9P8tYk\nz6+qO2q4qj6T5MeTvHOMcW2SZye5fv4+imcnuXKMcV1mb+q6KrPKPn+M8abtixhjPGaMcWOSVyS5\ncIxx4xjjucvaafbcJOZovo17J3nrGOOa+f+evJQ9Zq9NZYZemeRvz7dxfZL3VdW7lrHDLMVU5qid\ntf5XWosYY7w5yXuq6vJVr4X1ZY5YlBliL5ijr6zDEZ4dG7P/QNM5Sa5b9VpYX+aIRZkh9oI5umv2\n7REeAGD/OOm/0jp8+HNqaB8488x7H1jm9s1Rf2aIvWCOWNTJZmhfntICAPYXwQMAtCd4AID2BA8A\n0J7gAQDaEzwAQHuCBwBoT/AAAO0JHgCgPcEDALQneACA9gQPANCe4AEA2hM8AEB7B5e14Y2NMxa6\n/ebmkT1aCetq0RliZ/zMkXjupi9HeACA9gQPANCe4AEA2hM8AEB7ggcAaE/wAADtCR4AoD3BAwC0\nJ3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKC9g6teAJzI5uaRVS8BgCYc4QEA2hM8AEB7\nggcAaE/wAADtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4\nAID2BA8A0J7gAQDaEzwAQHuCBwBoT/AAAO0JHgCgPcEDALQneACA9gQPANCe4AEA2hM8AEB7ggcA\naE/wAADtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2\nBA8A0J7gAQDaEzwAQHuCBwBoT/AAAO0JHgCgPcEDALQneACA9gQPANCe4AEA2hM8AEB7ggcAaE/w\nAADtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGjv4KoXAABTt7FxxqqXsK9sbh7Z8206\nwgMAtCd4AID2BA8A0J7gAQDaEzwAQHuCBwBoT/AAAO0JHgCgPcEDALQneACA9gQPANCe4AEA2hM8\nAEB7ggcAaE/wAADtHVzWhjc3jyxr0wAsiefu4/O4rD9HeACA9gQPANCe4AEA2hM8AEB7ggcAaE/w\nAADtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7R3Y2tpa9RoAAJbKER4AoD3B\nAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2Dq56AYsaY2wkeU2S\nQ0m+mORAktdW1RXz669J8oqqeu8xt3tRko9V1VU7vL97JnlCVb3lmMufmeRnkvzRtosvrKrP72T7\nrMZU5mh+3TOT/GSSv0zyW1X1sh3vEKfcVGZojPHsJP9o20UPTPKGqnrVjnaIlZjQHN0jyRuSPCjJ\naUk+UFU/taudmoi1D54kb0vy61X1o0kyxviGJO8aY9xcVVef6EZVdeku7+9Qkqck+bIXqiSXV9XP\n7HK7rNYk5miM8aAk/zLJw5L8RZJfHWPct6r+bJf3w6kziRmqqjcmeeN8Dacn+WCS/7zL++DUm8Qc\nJXlaktOq6vwxxoEkvzvGOK+qrtvl/azcWgfPGOPxSe5eVa87ellV/fEY48VJLklydDguGmNcnOTs\nJC+vqivGGJcneX9VXTbGeFqS52VW0oeTPKeqbh5jXDjfzpeSfDzJCzJ7IrnPGOPVVfXCU7OnLNPE\n5uh7kry1qm6ef/20Je02e2hiM7TdT2T24imY18DE5ujPk3zdGOPumR3hOT3JzVlj6/4enkNJbjjO\n5R/K7Dfko+5WVRckuSjJ68cYd+z3GOMBSV6a5HFVdW6Sa5K8ZH6Y77IkT6qq85J8dr7NS5NcfYIn\nmAvGGFeNMT44xnj+4rvHKTKlOfqmJH89xrhyjPG7Y4yX7MkesmxTmqGj27tXkmcl+XcL7hunzmTm\naH406dNJ/t/8/99ZVf9rT/ZyRdY9eG7Jiffh9m1/vjpJquoT86/P3Hbdo5KcleTd83OjT59//dAk\nn6qqw/PbXlxV7zvJWq5P8q+q6slJvjvJPxtjPHZnu8OKTGmOkuRbkvxQkscleer8tz6mbWozlCTP\nyOxFyvsI18dk5miM8dQk90/yjUm+Icljxxjn7XSHpmStT2kl+Whmv8Ec6xG5cyVvH5QDSba2fX1r\nkhuq6sLtGxhjPDw7CMKquinJTfM/3zzG+O3Mav2/3dVtsDKTmaMkn0nyp1V1a5Jbxxj/Pcm3JXnP\nDrbBqTelGTrqKUleuYvbsTpTmqO/n+SqqvqrJH81xvidJOcmWdv38Kz1EZ6qujbJkfm5zCTJGOOs\nJK9Ksv1ftjx2ft2Dk9yW2TnNoz6c5Jwxxn3n3/PUMcZFmcXL/cYY959f/rr55bdndj7zTsYYLx5j\nPGf+53skOT/JR/ZqX1meKc1RZm9YfOIY47T5ufNHJrlxb/aUZZnYDB316Bz/9AgTNbE5uimzo0WZ\nnzJ7RJL/vRf7uSprHTxzFyZ54Bjjo2OM65NcmeSSqvrAtu+5bYzx9iRvTfL8qrqjhqvqM0l+PMk7\nxxjXJnl2kuur6pb5n68cY1yX5D5JrsrsCeT8McabjlnH5ZmdfrguyfuTXFlVju6sj0nMUVV9LMkv\nZ3aK9IOZvQnxXcvZZfbYJGYoScYY90ny11X1paXsKcs0lTl6Q2ZHmT+Q2Wva/0zy9iXs7ylzYGtr\n6yt/V0NjjDcneU9VXb7qtbC+zBGLMkPsBXP0lXU4wrNjY/YfaDona3wuktUzRyzKDLEXzNFds2+P\n8AAA+8e+PMIDAOwvJ/1n6YcPf87hn33gzDPvfWCZ2zdH/ZmhL7exccZCt9/cPLJHK1kf5ohFnWyG\nHOEBANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2BA8A0J7gAQDaEzwAQHuCBwBoT/AAAO0J\nHgCgPcEDALR3cFkb3tg4Y6Hbb24e2aOVwP7gZw7gxBzhAQDaEzwAQHuCBwBoT/AAAO0JHgCgPcED\nALQneACA9gQPANCe4AEA2hM8AEB7ggcAaE/wAADtCR4AoD3BAwC0d3DVCziRjY0zdn3bzc0je7gS\nOHUWmXuAvbKuz0VbW1snvG6ywQPsnNgHOD6ntACA9gQPANCe4AEA2hM8AEB7ggcAaE/wAADtCR4A\noD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGiv5aelr/Jj7X1aNaucPwCOb2nBs8gLvxcM9jPR\nDKy7KT6POaUFALQneACA9gQPANCe4AEA2hM8AEB7ggcAaE/wAADtCR4AoD3BAwC0J3gAgPYEDwDQ\nnuABANoTPABAe0v7tPT9apFPep/ip8vuV4v8PbL3/H0Ai5pk8Kzqhd+TKntFvAJMi1NaAEB7ggcA\naE/wAADtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtDfJT0vf\nrxb5tHafzv3lFnk8Adi9RZ9/l/GaJni2WfQB9gLbi4icDn8XwKKc0gIA2hM8AEB7ggcAaE/wAADt\nCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9n5bOZPn0eYD9abfP/1tb\nWye8TvDsoc3NI7u+rRf35Vjk7wRgv+r4euaUFgDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMA\ntCd4AID2BA8A0J7gAQDaEzwAQHuCBwBoT/AAAO0JHgCgvYOrXgD9bWycseolALDPCZ6J2Nw8suol\nTJLHBWC9TPV52yktAKA9wQMAtCd4AID2BA8A0J7gAQDaEzwAQHuCBwBoT/AAAO0JHgCgPcEDALQn\neACA9gQPANCe4AEA2juwtbW16jUAACyVIzwAQHuCBwBoT/AAAO0JHgCgPcEDALQneACA9gQPANCe\n4AEA2hM8AEB7ggcAaE/wAADtCR4AoL2Dq17AosYYG0lek+RQki8mOZDktVV1xfz6a5K8oqree8zt\nXpTkY1V11Q7v755JnlBVbznm8tOT/EKSMV/D64+ugWmayuzMrxtJfi3JjVX1jG2XvyzJk+dru6qq\nfnYn98lyrckMnZXkV5Lco6rO3cn9cWqsyRy9MsnjMjtQ8v6q+omd3OcUdDjC87YkH6mqb62qRyb5\n/iQvG2NccLIbVdWlOx2SuUNJnnKcy5+b5J5V9egkFyS5ZIxx5i62z6kzidkZY9wryWVJ3nHM5Y+c\nf//5Sc5L8t1jjEfv4n5ZnknP0NyvJNnNfXHqTHqOxhhPTnJukkcleWSSc8cY37GL+12ptT7CM8Z4\nfJK7V9Xrjl5WVX88xnhxkkuSXD2/+KIxxsVJzk7y8qq6YoxxeWaVetkY42lJnpdZVR9O8pyqunmM\nceF8O19K8vEkL0jyxiT3GWO8uqpeuG0535zk/fM1fH6M8f4k/yDJLy9r/9m9ic3OrUken+QHknzj\ntsufmOTtVfWX8zW/PcmTknxwLx8LdmdNZihJLkry8CTfu4e7zx5Zkzl6d5Jrq+r2+ZpvTvL1e/pA\nnALrfoTnUJIbjnP5h5I8bNvXd6uqCzL7wX/9GOOO/R5jPCDJS5M8bn6495okL5kf8rssyZOq6rwk\nn51v89IkVx8zJEnykSRPHmOcNsb42sxK+Ow92EeWYzKzU1W3VdUXj7OWs5P82bav/yxmakrWYYZS\nVZ/b5f5xakx+juaXf25+X4/M7K0b797l/q7MWh/hSXJLThxtt2/789VJUlWfmJ2ezPZTTY9KclaS\nd8+vOz3J/03y0CSfqqrD89tenCRjjGN/ezrq8iQPSXJtkj9McmNm52KZpinNzl11IMnWgttg76zj\nDDE9azNHY4zzMnut+76q+vxutrFK6x48H03yrONc/ojcuZi3D82xLxq3Jrmhqi7cvoExxsOzgyNg\nVXVbkp/cdvurkvzRXb09p9xkZuckPpU7H9E5O8mn92C77I11mCGmby3maP6enV9I8uSqumkvtnmq\nrfUPVFVdm+TI/Lxmkjv+RcKrkrxs27c+dn7dg5Pcltn5zaM+nOScMcZ959/z1DHGRUluSnK/Mcb9\n55e/bn757UlOO3YtY4zvGmNctu1+Hp7ZYUUmaEqzcxJXJfmeMcZXjTG+KrM3GR7vTamswJrMEBO3\nDnM0xvi6JG9I8sR1jZ0kObC1td5HyMcYX53ZP+d7TJIvZPYX+eqqetv8+muS/H6Sb5r/75Kq+q/H\nvNnrBzM7OvOF+f9+tKr+fP7O9J9O8pdJPpnk2fNt/E6Sd1XVs7at47Qkv5jkQUnukeSFVfWeJe8+\nC5jQ7DwmsyeTr03y1Zkdxfk3VfUfxxj/IslTM/tt7jeq6ueX+JCwQ1OfocxOg7wjyb2SbGR2muPK\nqrpkWY8JO7cGc/Q1821vj51fqqo37v2jsTxrHzy7NcZ4c5L3VNXlq14L68XssCgzxF4wRzuz1qe0\ndmvM/mNN5yS5btVrYb2YHRZlhtgL5mjn9u0RHgBg/zjpv9I6fPhzamgfOPPMex9Y5vb32xxtbJyx\n0O03N4/s0UpOHTPEXpjyHO3Hn+t1dLIZ2pentACA/UXwAADtCR4AoD3BAwC0J3gAgPYEDwDQnuAB\nANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2BA8A0N7BVS8AlmVj44xVL2FXFln35uaRPVwJ\nQB+O8AAA7QkeAKA9wQMAtCd4AID2BA8A0J7gAQDaEzwAQHuCBwBoT/AAAO0JHgCgPcEDALQneACA\n9gQPANCe4AEA2ju46gXQ38bGGateAsBKn4tWdd+bm0dWcr9TJHjgBFbxRCEOAZbDKS0AoD3BAwC0\nJ3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2BA8A0J7gAQDaa/dp6av+\ntOlVfMI2HLWq+d/a2lrJ/S7boo+n5wOYjnbBQy9eMIC9ssjzifhdf05pAQDtCR4AoD3BAwC0J3gA\ngPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2BA8A0J7gAQDaO7jqBRzPxsYZ\nq17Cri2y9s3NI3u4EgDgqEkGz6JWFQ7rHGpMg9kFWA6ntACA9gQPANCe4AEA2hM8AEB7ggcAaE/w\nAADtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGhvaZ+W7tOXd26Rx2xVn7INU7fK56L9\n+DzouYipWlrwLGodf2gWWfN+fGJkOtbx5w1OJT8j688pLQCgPcEDALQneACA9gQPANCe4AEA2hM8\nAEB7ggcAaE/wAADtCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAewdXvQCAZdrcPLLqJZxyGxtn\nLHT7/fiY0Z8jPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2BA8A0J7gAQDaEzwAQHuCBwBoT/AA\nAO0JHgCgPcEDALQ32U9LX+TTflf1Sb+LfkIxALAcSwueRaJjv4bDqkJt2bruF0yVnzn4ck5pAQDt\nCR4AoD3BAwC0J3gAgPYEDwDQnuABANoTPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2BA8A0J7g\nAQDaO7C1tbXqNQAALJUjPABAe4IHAGhP8AAA7QkeAKA9wQMAtCd4AID2BA8A0J7gAQDaEzwAQHuC\nBwBoT/AAAO0dXPUCdmOMsZHkNUkOJflikgNJXltVV8yvvybJK6rqvcfc7kVJPlZVV+3w/u6Z5AlV\n9ZbjXDeS/FqSG6vqGfPL7pHk3yb5lvnaPprkx6rqr3dyvyzX1OfomOv/dZJDVfWdO7lPlmvqMzTG\n+M4kVyS5adu3PqeqPrGT+2W5pj5H88ufkOTnktya5A+S/NN1e01by+BJ8rYkv15VP5okY4xvSPKu\nMcbNVXX1iW5UVZfu8v4OJXlKkjsNxxjjXkkuS/KOJN+47aonJLm1qh49/773Jfnu+bqZjqnP0dHr\nz0/y8CS37/J+WZ51mKHfrqpn7vL+ODUmPUdjjDOSvCnJuVX1yTHGv8/sF/o/2OX9r8TaBc8Y4/FJ\n7l5Vrzt6WVX98RjjxUkuSXJ0OC4aY1yc5OwkL6+qK8YYlyd5f1VdNsZ4WpLnZVbShzP7refmMcaF\n8+18KcnHk7wgyRuT3GeM8eqqeuG25dya5PFJfiDbhqOqfjPJb87X+9VJvjbJp/f4oWAB6zBH83Xe\nK7Pfqp6f2W+ATMS6zBDTtiZzdEGSG6rqk/P1/djePgqnxjq+h+dQkhuOc/mHkjxs29d3q6oLklyU\n5PVjjDv2dYzxgCQvTfK4qjo3yTVJXjI/zHdZkidV1XlJPjvf5qVJrj5mMFJVt1XVF0+00PkwfjLJ\nFVX1ezvdUZZqXeboNUl+PrMnMKZlXWbo0Bjj7WOMD40xfnb7/TMJ6zBH35TkL8YYvzTGuH6M8dox\nxtodMFnHwb8lJ1739kP+VyfJtnPVZ2677lFJzkry7vm50afPv35okk9V1eH5bS+uqvftdqHzw8h/\nJ8nj5/XNdEx+jsYYFyT5mqq6cqe35ZSY/Awl+T9JXpnke5M8Nsl5Sf7xLrbD8qzDHCWzUPrnSc5P\n8pAkz9rldlZm7QotszcAH++BfkTuXMnbB+VAkq1tX9+a2eG5C7dvYIzx8OxBBM63c0tV3VRVnx9j\nvC3JdyT59UW3zZ6Z/BwleVqSh4wxrk9yepIHjTF+sap+ZA+2zeImP0NV9Sf5m+edL8yfiw4tul32\n1OTnKMlnkvxeVR2Zb/e3knzbHmz3lFq7IzxVdW2SI/NzmUmSMcZZSV6V5GXbvvWx8+senOS23PmU\nwIeTnDPGuO/8e546xrgos3/JcL8xxv3nl79ufvntSU7bwTLPSfKvxhgH5l8/OsmNO7g9S7YOc1RV\n/6SqHlZV357Zb+j/Q+xMxzrM0BjjR8YYl8z/fLck35XkIzvdV5ZnHeYoyW8nefT8PanJmr6mrV3w\nzF2Y5IFjjI/Of/u9MsklVfWBbd9z2xjj7UnemuT5VXVHDVfVZ5L8eJJ3jjGuTfLsJNdX1S3zP185\nxrguyX2SXJVZZZ8/xnjT9kWMMR4zxrgxySuSXDjGuHGM8dwk/ynJnyb5wHx9n8/sPCrTMvU5Yvqm\nPkNvSfItY4wPJvlgkj9KcvmePwosatJzND8l9pIk185n6QuZvfF5rRzY2tr6yt/VxBjjzUneU1WX\nr3otrC9zxKLMEHvBHO3Muh7h2bEx+w80nZPkulWvhfVljliUGWIvmKOd21dHeACA/WnfHOEBAPYv\nwQMAtCd4AID2BA8A0J7gAQDaEzwAQHv/H3HjFRhQaOBIAAAAAElFTkSuQmCC\n", 216 | "text/plain": [ 217 | "" 218 | ] 219 | }, 220 | "metadata": {}, 221 | "output_type": "display_data" 222 | } 223 | ], 224 | "source": [ 225 | "fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(10, 10))\n", 226 | "for r in range(4):\n", 227 | " for c in range(4):\n", 228 | " ax = axes[r, c]\n", 229 | " plot_input_image(objects[4*r + c], ax)\n", 230 | " ax.set_xlabel('Object {}'.format(4*r + c + 1))" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 4, 236 | "metadata": { 237 | "collapsed": false 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "all_pairs = np.zeros((1, len(objects)*(len(objects)-1)//2, 8, 8))\n", 242 | "all_pair_groups = np.zeros_like(all_pairs)\n", 243 | "\n", 244 | "k = 0\n", 245 | "for i in range(len(objects)):\n", 246 | " for j in range(i + 1, len(objects)):\n", 247 | " all_pairs[0, k, :, :] = objects[i] + objects[j]\n", 248 | " all_pair_groups[0, k][objects[i] != 0] = 1\n", 249 | " all_pair_groups[0, k][objects[j] != 0] = 2\n", 250 | " all_pair_groups[0, k][all_pairs[0, k] > 1] = 0\n", 251 | " k += 1\n", 252 | "\n", 253 | "all_pairs = (all_pairs != 0).astype(np.float32)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 5, 259 | "metadata": { 260 | "collapsed": false 261 | }, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABCMAAAQjCAYAAABaT+TIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3TFy21q2NlDI5dShygPQIFwawB3CTRS6ynGXeiIagKsc\nKtEQNACVB6HOXQqdiy/o/3+3323igAR4PpwDrBWaEnm4sQECnwXuq8PhMAAAAACkfFh7AQAAAMC+\nCCMAAACAKGEEAAAAECWMAAAAAKKEEQAAAEDUx9KDb2+/dzNq4/r609U5P1+jNnefv07+zOOvH5d+\n2aIW6jIMp9VmTK2atVKb1qjLOLU5Tl3Gqc1x6jKuhdpMfWanz2WGoY26tEptjuuhLmvta+fUZi/9\nMgyX7ZkWj6NLjNXGX0YAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAooqjPUu+\nPH6/5DpO9vPu2yqve465tbmp+NxTEnWdWvsa27bFNf3dWvvalBZqs1d6YlyrtSlRt+NSdWnxvQ/D\n+n2xpC5T5zM9n8sMQ7s9U9LCeV6LWumZvZ4D65n5WqzdnNrMDiO4vNeH2+JGbLHpUl4fbtdeAgBw\ngtL5zJ7PZeDSHn/9WHsJVFLatls6jrpNAwAAAIgSRgAAAABRwggAAAAgShgBAAAARAkjAAAAgKju\npmncff5afLz3b5Utvj8TJdiZFvf3FtcEbE/p29JbGS1XMnqsdC6zitJn1/P7U3Al27XkHL7GucXU\nc9oX1+ecsmIYMfeDcnLH2YCx2ky995v7l1lN2cr4lyUnT7VOvHo4oWtx9jTt66G312B/Ou7m/qX4\neI0Toh7qMgx1embq876XcdZzzmfmnssMg56ZMlbXls6t08fgVnqm1PctbZ8W7bVnSqY+s6e0dAx2\nmwYAAAAQJYwAAAAAooQRAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgqtpoz5ItjLCpOUv44r97woiw\nsec9d/Z0aX1To8paG7k3NTbHXG6gF1PjuG4mfn/J59Ye5qQfs4VzHfoytp9P7d8tWLq/9HCc6W37\nlLZJ8hy4tI4etjtlq4QRS/TQdK3NEp6a6e2ECQCySp/LPcy5py2l//B5fbhd5T98+EuL26e165Vj\npv5jsHdzr2t7uB4+lds0AAAAgChhBAAAABAljAAAAACihBEAAABAlDACAAAAiKo2TaOVb2FdS2vj\ne4ojLIPrAGBfRs8HThg73botnOss2T5beP9k1eqZLR9namrteoX9qRJGLB3D0sO4krkjuUrje045\nQM+pzdSIMCOfAPZt7ueAi9GyHs5nxhgLzqVNXR/M3V/04nFT9TRCmBa4TQMAAACIEkYAAAAAUcII\nAAAAIEoYAQAAAEQJIwAAAICoaqM9l1jy7a49TIVYMkZny998u+X3toS6jPMN2uPW6Juej7/Mq43x\nb2W991utseC912VKjfe3h31tbt22UJut7xOcby89MTuMKJ50TpyQ7qW455oardZD3eaOZVrzvY3W\nvIMLK4A1vT7cTv5MDyHVXKe8/9aNfW6fMhacnB7GsE+t70ul153qxdbr1qJUv02+xgrbTr9kuU0D\nAAAAiBJGAAAAAFHCCAAAACBKGAEAAABECSMAAACAKGEEAAAAEDV7tOfd56/zX7XiOKjSuuaOnUwq\n1tUYrU2a3Jdsdzox1cs9HIPJ0S9wOUvOy5/fnzJrWOl8pvdrAy7P5087ZocRS9zcv6zxsjFz59NO\n7Rg39y+b3TlO6Yk57/2kD+fG5wlvebtP2ev7ntLiDOxFAfUFdfn50sIs9x2b7JlC7bZc1y2/t6XU\nZp41zmda+mxyTsM59nKccZsGAAAAECWMAAAAAKKEEQAAAECUMAIAAACIEkYAAAAAUatM04A5WvlG\nZAD64vMD2lBjzKb9G/o1O4wwnoZLmuqnJR80rfdq6+uDU7TSx62sg2ktbauW1gJztNDDNc/llkjV\npoVtwGlsq3a4TQMAAACIEkYAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQdXU4HNZeAwAA\nALAj/jICAAAAiBJGAAAAAFHCCAAAACBKGAEAAABECSMAAACAKGEEAAAAECWMAAAAAKKEEQAAAECU\nMAIAAACIEkYAAAAAUcIIAAAAIOrj2gvo1dvb78Oln/Pu89fJn3n89ePSL1t0ff3p6pyfr1GXYTit\nNmNq1ayV2rRGXcapzXHqMk5tjlOXcS3UZuozO30uMwxt1KVVanNcD3VZa187pzZ76ZdhuGzPtHgc\nXWKsNv4yAgAAAIgSRgAAAABRwggAAAAgShgBAAAARAkjAAAAgChhBAAAABBltOdGfHn8vuj3f959\nu9BK5lkyunOJqbol6rLWSNceematvuhtXFLK1sZMXcrSPk3UrcfjzNqfSzWdUpctv/8SPXNcD5/Z\na+m9Z/Z8DryW3numljWOM8IIurHXCx0AAICtcZsGAAAAECWMAAAAAKKEEQAAAECUMAIAAACIEkYA\nAAAAUaZpdGbpyJU11Rpd1HNNalObefY68mmt8WI92ENtSu9xbJqRY8y4pbUp/X7Px5lh0Ddjatal\ntH8/vz9Ve91L6KVf9vA50Ys99Mzc3319uJ39mjUIIzrSWvNc2lqjO3s/qduyWj3Ry4cUsF2lz54t\nH6O2fi5T09zzFRfJ/RMIs1Vu0wAAAACihBEAAABAlDACAAAAiBJGAAAAAFHCCAAAACDKNI2ZtvbN\nxGPv59xxTz2O7xx77pv7l+LvtTAKy7coj1Ob47Z27LokteFcjjOca0nPlI5RpSklN7NfkVbM7Zvi\n59rEZJsWxsHO7Xn6IYxozNxRhlMn0aXnTZ6ArzG+c6/j004xVht10TNj1hrB27pWjrFjprbb3DUu\nGY28h31pr6OjS/12ynb32fTfpv6DpOT14bbrXtzKcWbu5+fc9zDVMzf3L6NrauFzq7YWjjNLzh2W\nHmfHpM+B3aYBAAAARAkjAAAAgChhBAAAABAljAAAAACihBEAAABAlDACAAAAiDLacydKo1hamD/d\n0uilFqnPf1OTcXsYyTWHusw3tr8tGTc4NeO+BXs4zizZL4z6/W976JkapupWGjdY7GHHmVmv3cK1\nAWVbOdYII2ZKzwqesmTWbO/zp4dh27PcXxd+kG65Nkuoy3EuLsZtuTZLjzN75lhyHvUqG6vP3ZLg\nj00b3acm9rUeLmaXXN9s+VizpWsDt2kAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQZZpG\nBUu+nXZqxNaWv80d/q6Hb3pegxGV49TmOPvSOLWZr1S7lr6t/dL0zDx7r9ve33/a7PMB06WihBEN\nWTSvvRE1w5IaJzbF5wydSE3VbM+ji5ZSm+OEmuPU5ri5o+NKJ4M39y9d1Hvrx5GxbeDCqZ4aPdXL\n/lQyty49vO/eRjG+Ptw2f+yrdd3Uyr7UW8/M5TYNAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQJYwA\nAAAAokzT6EzP3269dO1zf7/0bbJbmFTRc0/UpjbH1Rp/2cK3T1PH3H3p5sLrWIPjyDzqNm5ObZbs\nS1PH/Of3pwXPfjm1eqaHc7kpW92flryvmp8vWxhb3EvPCCMaMjVGp5emmmvP7x24jBYDkRbXlLJ0\nNFmCzx560MO+BEm1Rl867me5TQMAAACIEkYAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQ\nZbRnY4qzoI11AoCzFT9bh32PXwWAtVwdDoe119Clt7ffFy/c1MnSMORPmK6vP12d8/M16jIMp9Vm\nTK2atVKb1qjLOLU5Tl3Gqc1x59bljw9/FuuypTBCzxynLuNStakVCtZ6Xj0z7pzaqMs4tXGbBgAA\nABAmjAAAAACihBEAAABAlDACAAAAiBJGAAAAAFHCCAAAACDKaE8AAAAgyl9GAAAAAFHCCAAAACBK\nGAEAAABECSMAAACAKGEEAAAAECWMAAAAAKKEEQAAAECUMAIAAACIEkYAAAAAUcIIAAAAIOrj2gtg\nW97efh/Sr3n3+Wvx8cdfP6q87vX1p6tzfr5Gbb48fp/8mZ933y79skUt1GUYpmuTrsswtFEbdRmn\nNsepyzi16UcrddEzx6nLuFJtbu5fir/bwjmwfWlcqTZLrm96qo2/jAAAAACihBEAAABAlDACAAAA\niBJGAAAAAFHCCAAAACBKGAEAAABEGe1JE6bG17w+3I4+drPwuWuNPTrV1PqGYf01ruWU2owq9Mzi\n5y6wrcbttTYtsH360vrnFnA5k8fnifOZJc/tWMLahBF07/XhdnRebq0LTtp3c/8y+iGrLwCAHhTP\nZ+5fwquBy3KbBgAAABAljAAAAACihBEAAABAlDACAAAAiBJGAAAAAFGmaRDT6gSD0rqMPOrbl8fv\nR/99ahws7NHosXDBWDmOa/XzkHbV7Jnn96dqz70nNbeR85ltWtIzxd/t6HNbGEEzShf+YwfhJc85\nDE4IW7ckDCr1TGkcLEALhOHAMPz7nKX0mPH2/F1pHGxr3KYBAAAARAkjAAAAgChhBAAAABAljAAA\nAACihBEAAABAlDACAAAAiDLak2bMHd+55HeT85mXzIheozZz/esf/zzr59PrW/N1e6hNzX1p7Ll7\nqMtar3up2qw1i75WzXrpmSXmrrmV2oyNG1xrW7RSl7mvu9Y+fExr+1Mrx5m1tN4zS7fP3FHsPXxm\n19TafjpWG2EEMaU5yWtpYU2lGdHD0N7BBI7Rx2075Vg394SP87Xw2UNf9nCM7f0Y9Pjrx+hja2yf\nqZ5hfVOfBa2FujW4TQMAAACIEkYAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAo0zS4qLvPX8cf7PDb\nw0vv5/n9KfI6PdaNfdLH21T61u6b+5fyL9vucDFbP8aWjjU9TIXY0vYpvpfhsufAU0bX0llN59jD\ne786HA5rr4EN+ePDn8WGKo09KpkaYVPjQ+qEA/HVOc/39vb7aG2mXmcYxut2ymif9Af49fWni9Rl\nqTV6ZkoLtalVlyV93EJdhkHPjFkURkyY+5nQQl1adW5tpj6355q7bWtppWfmHmeWHGOnpGpTeg9z\nRxxOPe8w5OriHPi4JftSrZ6p5ZL70pK+7ul8xm0aAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgShgB\nAAAARBntSVTpm2HX+ObtU76dGoA6WvtM+Lupz4gWv829Fb2PaezR2vtTq+dUa9eFcTV7ptZz65nL\nEkZwUaUdtNUPqRIHHIBxrw+3i0busb4an3OnjJ1mn5aMKj/3OYfBMWiuHs5/b+5fRte5he3ewza4\nBLdpAAAAAFHCCAAAACBKGAEAAABECSMAAACAKGEEAAAAECWMAAAAAKKM9qQZxTE8E7PcZz0nsDmL\n9vmJ44yZ5dvkc4I9KPX58/tTcCVwOWMjYW/C62A+YQQxrc6BdiEADMP2Z5avodXj/jlK72HsRBgu\nbepcpZf9Cf7TknPw0vH39eF2+Hn3bfTxqf3FtUGO2zQAAACAKGEEAAAAECWMAAAAAKKEEQAAAECU\nMAIAAACIMk2DTfAt0sAlGBO2TT4joA2j++LMEe7F5wSaJ4ygCUtG6PQ+nqf19cEpWunjtcaEcb5W\nemYY2loLzNFCD7c6fjRVmxa2AaexrdrhNg0AAAAgShgBAAAARAkjAAAAgChhBAAAABAljAAAAACi\nhBEAAABA1NXhcFh7DQAAAMCO+MsIAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAooQRAAAA\nQJQwAgAAAIgSRgAAAABRwggAAAAgShgBAAAARAkjAAAAgKiPay8A9uLt7ffh0s/55fH7ot//efft\nQiv5y/X1p6tzfr5GXYZhWW1q1GUY2qiNnhmnZ47TM+P0zPlOqZmeOW7LPeM4M26qNrX6ouSc2rS4\nLw1D+z1z9/lr8Xcff/0Yfayn44y/jAAAAACihBEAAABAlDACAAAAiBJGAAAAAFHCCAAAACBKGAEA\nAABEXR0O0YlOsFtj43umRve8PtxWWc+UuaN9tj7WaMmaUmPClo67muPm/mXyZ8Zqs+W6/H9r70+O\nM331zNT+1PpxprT+pb2mZ8apzXGt12XKVN1K+1vpWLFEarTnnntmL5/bHxetBljd3IPCmidMCadc\nHI9aYWZ3kp4ZN6c26jJObY7bQ13G3Ny/VAuSe6FnjnOcOW7R+czG7bVnTgkiejnOuE0DAAAAiBJG\nAAAAAFHCCAAAACBKGAEAAABECSMAAACAKNM0gKOmvlF37jcYs1zv3wJdU6u1Ka1rz/tSC3VptWfW\n1mpd9Mz5pkYEPr8/XeR1Wq1Lqmfmjv6e+r0lSs89NZEhsT/pmXW2/ZKRr5eujTACVlZrBvSU0gGj\nlQ+HJTOUa9V1rbnO56pxElF8zk4uqON1GdrYn6b2h8kTopl16+E4M6VGz0yN6uv9OHO3YBShnjnu\npPGOHRyH08fgXnpmypxzmpoXukl77pl0gLXG+YzbNAAAAIAoYQQAAAAQJYwAAAAAooQRAAAAQJQw\nAgAAAIgyTQNCxr6B9pRvyF5r4kbJ2Pv51z/+eZHn6VXp/Wy5NlPf7FyaDrDlupwi1TNLvpF8yTd3\nzz1+9b4v1fwm+95rU4u6ZNSozdL9pYXjDMct7ZdaIyrnSvZMbxNR5tRGGAH8l5ZGFabmxp+qlXFQ\nrdWlJa3VppWeaU1LdVmjZ2pcPF2afWlca7UpSY6Kba0uyZ6Zu0+v9R9Ok/8ZFtqWpW20xsV4smdK\n++bU9lmjb2rUxm0aAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgShgBAAAARAkjAAAAgCijPQGAalqb\nEX9Jvc2AP9fW3x9ZNftpy8cZ2DJhBADs3NyTdRerZT1cBI3NjbdtSXMcuqwejj/gNg0AAAAgShgB\nAAAARAkjAAAAgChhBAAAABAljAAAAACiTNMA/suXx+9rL+F/ldYy9i3wNbVSm6l1rFGbVuiZ4/TM\nuLHa3Ny/hFfyb630TGtaqktrxxnWV5rq8fpwW/zdFj+beujjWpNUTCI5rsYxWBgBIaMH9Q4O9sdc\n6kOq9DwtnXie6pIf3mPP1WJdkh/ceub856pVl6ntPvdE8dInwTV7Zo2T1sRFwtTFU2od50jsS8PQ\n33EmuT+V1NpX5j5va/27VVuqc/Sz6aKvlDGnPm7TAAAAAKKEEQAAAECUMAIAAACIEkYAAAAAUcII\nAAAAIMo0DVjZ1LfN1xoH1du3gR9TYyzVFuoyDHXGzpWe85QRiK2PyloyIqz193aKNUYVlmq+hZqO\ncZyZ95y9WGtk4txj2PP7U+T11zifmfpsauU443yGc22lZ4QRwFEtjGJqdbTa2rWZev09n4BMhiIr\nbLsW1rRWz5QuPk4JsBJK6zhlvGUtjjPHrV2XqTWssS8NQzv7E8fd3L+MBh97Pp+ZUivA6sFezoHd\npgEAAABECSMAAACAKGEEAAAAECWMAAAAAKKEEQAAAECUMAIAAACIMtoTOrfnMY5T1OY4deFcvffM\n3eev1Z6799rUoi7j1OY4dRmnNsctqcuWx4IOQz89c3U4HNZeA+zC29vvi+9sSw80NQ7E19efrs75\n+Rp1GYY2P6BaqI2eGadnjttCz/zx4c9iXR5//Zi1Dj1z3BZ6xnFm3NzalELB14fbOU85DMMw3Ny/\nFB+fu3/rmXHn1KbFugyD48yYdM+4TQMAAACIEkYAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAo0zQA\nAACAKH8ZAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAooQRAAAAQJQwAgAAAIgSRgAAAABRwggAAAAg\nShgBAAAARAkjAAAAgChhBAAAABD1ce0FAPv29vb7sPYaUq6vP12d8/N7qY26jFOb49RlXAu1ufv8\ntfj4468fl37JSS3UpVVqc1wPdVlrXzunNnvpl2G4bM+0eBxdYqw2/jICAAAAiBJGAAAAAFHCCAAA\nACBKGAEAAABECSMAAACAKGEEAAAAEGW0J9ClqZFHa0mMWvry+H3yZ37efau+jv90yvZooTbputC+\nFnumhZFuLR5nWtFiz7RAz4wr1ebm/qXa67ZwLCnRMwgjAAAAoBFrB0UpbtMAAAAAooQRAAAAQJQw\nAgAAAIgSRgAAAABRwggAAAAgyjQNoFnFUVjBdcDWtT7+bWtaHU1Mu2r2zPP7U7Xn3pPiNnq4zS2E\nbsz97F16PGjpM10YAXTp9eHW7Glgs1o6WQSWubl/Gb+wvH+Z/P05xwOhJz1wmwYAAAAQJYwAAAAA\nooQRAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgymhPANi4L4/fi4/fTPz+khFxex1RaawenG7p/tLD\ncWbsODx1/K1pbt2f358uvJJxpTX2sN0pE0YAAIQ5iebS9FS7Xh9ui4/9vPsWXE0/bu5f1l5CVXP3\n2S3t627TAAAAAKKEEQAAAECUMAIAAACIEkYAAAAAUcIIAAAAIMo0DQCgmi2PZdvC+M7R91D49v/J\n34URtXpmy8cZ2DJhBADs3NzRci5Gy3q+CLq5fymu37bnXFNjGufuL3rxuCXHHzUlxW0aAAAAQJQw\nAgAAAIgSRgAAAABRwggAAAAgShgBAAAARJmmAQCsoodxfHv+Vvkvj99HH7sJroPT9LA/rUFdtmuN\n43OqZ0rH32EYZk3AapEwAgCYZeqkrPcL+alRhMPQ/8XM2PqnToTnjoNluyb7oVK/GEF7eanj2lTP\n3J1wDKZvbtMAAAAAooQRAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgShgBAAAARBntCcButDjmrffR\nkED7Sse+5/en4ErgL1MjhG9C62A9wggANqM0s7zFIGLrSkFLK9ujOOe+9BgETYWWrexPcEmvD7fl\nY/RMU/uL/yTIcZsGAAAAECWMAAAAAKKEEQAAAECUMAIAAACIEkYAAAAAUaZpAACb5pvTYX219kOT\nRKBfV4fDYe01ADv29vZ7Nweh6+tPV+f8/F5qoy7j1Oa4c+vyx4c/i3XZUhihZ45Tl3Gp2qwVRsx9\nXj0z7pzaqMs4tXGbBgAAABAmjAAAAACihBEAAABAlDACAAAAiBJGAAAAAFHCCAAAACDKaE8AAAAg\nyl9GAAAAAFHCCAAAACBKGAEAAABECSMAAACAKGEEAAAAECWMAAAAAKKEEQAAAECUMAIAAACIEkYA\nAAAAUcIIAAAAIEoYAQAAAER9XHsBABz39vb7sPYaEq6vP12d8/N7qcswqM0YdRnXQm3uPn8tPv74\n68elX3JSC3Vpldoc10Nd1trXzqnNXvplGC7bMy0eR5cYq42/jAAAAACihBEAAABAlDACAAAAiBJG\nAAAAAFHCCAAAACBKGAEAAABEGe0J0Jkvj9/XXsJRP+++rb2EZmtTom7HperS4nsfhvX7Ykldbio+\nd8nee6YkURt1GTdVmzX29xbWpGfma7F2c2ojjAAA4GJeH25HT0pbPIGGXj3++rH2EqiktG23dBx1\nmwYAAAAQJYwAAAAAooQRAAAAQJQwAgAAAIgSRgAAAABRpmkAcBF3n78WH/et3+uZ2jbDw21mIezG\naM/ptVWUjgHP70/BlWxX8Tg70fc1Pj8d99s3uY3GbGjbXR0Oh7XXAMARb2+/uzlAn/KBOnYydX39\n6eqc1+qpLktdqjazT3j+n9aCpN57Zsn+MqWF2rQYTLZQlyWW7sMlz+9PXdemlnN75o8Pf1arS5Uw\nYubzDsN5tdlLvwxDWz1T0tIx2G0aAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgShgBAAAARBntCQAU\nLfkm/9YmcaTUnH4AW7O1iT+9mFv31DjYFifzcFnCCADYuKkTNhfOeU6iuTQ91Tbbh78r9cReghi3\naQAAAABRwggAAAAgShgBAAAARAkjAAAAgChhBAAAABAljAAAAACijPYEAKopjSfrfTTZFkaibuE9\n0Ieavbbl4wxsmTACAHZu7sm6C9kyF0FwOsehy1py/FFTUtymAQAAAEQJIwAAAIAoYQQAAAAQJYwA\nAAAAooQRAAAAQJRpGgDAKlofx+cb5enJ2vtTq/vL2nVh3NKeWaPn9MxlCSMAgFmmTspavTi5pN5P\nTHtfP32p0W+l59zDMagGx4X17WUbuE0DAAAAiBJGAAAAAFHCCAAAACBKGAEAAABECSMAAACAKNM0\nANiMHr85fS/fmN2qHnsGzlXq8+f3p+BKAP4ijAAAqtjCyD1hES0wRpctWuv4OrW/OO7nuE0DAAAA\niBJGAAAAAFHCCAAAACBKGAEAAABECSMAAACAKGEEAAAAEGW0JwCwacYeQhtK++LccYr2b+iXMAKA\nxVqZyd3KOpjW0rZqaS0wRws9PLWGtUKDVG1a2AacxrZqh9s0AAAAgChhBAAAABAljAAAAACihBEA\nAABAlDACAAAAiLo6HA5rrwEAAADYEX8ZAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAooQRAAAAQJQw\nAgAAAIgSRgAAAABRwggAAAAgShgBAAAARAkjAAAAgChhBAAAABD1ce0FAMA53t5+H9ZeQ8r19aer\nc35+L7VRl3Et1Obu89fi44+/flz6JSe1UJdWqc1xPdRlrX3tnNrspV+G4bI90+JxdImx2vjLCAAA\nACBKGAEAAABECSMAAACAKGEEAAAAECWMAAAAAKKEEQAAAECU0Z4AUNGXx+/Fx3/efau+hqkRYcOQ\nHxPWQl3oi54ZpzbH9VCXU47PNZ779eG2+Lst1GYNPfTMWmrUxl9GAAAAAFHCCAAAACBKGAEAAABE\nCSMAAACAKGEEAAAAECWMAAAAAKKM9gQAiiPg0mM/t25qlJ+Re/zd5PjHiZ4p/f7z+9OcJfE3NUd0\nsk1LembJ77b0mS6MAACAjt3cv4xeYLhI7t+ci0fbnR64TQMAAACIEkYAAAAAUcIIAAAAIEoYAQAA\nAEQJIwAAAIAoYQQAAAAQZbQnAFC0lXnmScbqkfbl8fvRf78Jr2OOpfvLXo8zNZW2yfP70+prGAbb\nfQuEEQBsNz7vAAAcvUlEQVSwcVMnbC6c80rbZOyiEsa8PtwWH/t59y24Go5p7cL55v5ldE0+EzJK\nPbGXIMZtGgAAAECUMAIAAACIEkYAAAAAUcIIAAAAIEoYAQAAAESZpgEAVDP6jeCFb//vxRa+cX4L\n74E+1Oy1LR9naup5HCzbIIwAgJ2bOyLMhWzZVkavQYLj0GVN1dMIYVrgNg0AAAAgShgBAAAARAkj\nAAAAgChhBAAAABAljAAAAACiTNMAADjCt/TD6ewvnGtpz1TpOeNgo4QRAMAsU6PjSieKN/cvmxh9\n2ft76H39/Fsv+1ONNfbwvnvz+nA7/Lz7tvYyVtHKvtTCGhLcpgEAAABECSMAAACAKGEEAAAAECWM\nAAAAAKKEEQAAAECUMAIAAACIMtoTgM1ocs69meVNa7Jn4IKmevz5/Sm0EujDl8fvo4/tdeRpLcII\nAKiolZnla9jC+97Ce6B/+pAtWquvSwHdq/9AiHKbBgAAABAljAAAAACihBEAAABAlDACAAAAiBJG\nAAAAAFGmaQAAm2Z8J7ShtC/OnawwtX+bRALtujocDmuvAQBO9vb2ezcfXNfXn67O+fm91Obcuvzx\n4c9iXbZ0saJnjlOXcana1AoNaj2vnhl3Tm3UZZzauE0DAAAACBNGAAAAAFHCCAAAACBKGAEAAABE\nCSMAAACAKGEEAAAAEGW0JwAAABDlLyMAAACAKGEEAAAAECWMAAAAAKKEEQAAAECUMAIAAACIEkYA\nAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAoj6uvQAA4DLe3n4f1l5DwvX1p6tz\nfn4vdRmGNmpz9/lr8fHHXz8u/ZKTWqhLq9TmuB7qsta+dk5t9tIvw3DZnmnxOLrEWG38ZQQAAAAQ\nJYwAAAAAooQRAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgymhPANi4qRFhw9DfmLAt29pIN2DcKcfn\nWs/tWMLa/GUEAAAAECWMAAAAAKKEEQAAAECUMAIAAACIEkYAAAAAUcIIAAAAIMpoTwCgOALO+LfL\nqjnKj22q2TPP70/VnntP7Neca0nPLPndlj7ThREAAI1p6WQRWNec44FwhB64TQMAAACIEkYAAAAA\nUcIIAAAAIEoYAQAAAEQJIwAAAIAo0zQAgKKtjBBL8k32cLql+8tejzNLza17ahzs1Pps9/4JIwBg\n46ZO2Fw45zmJ5tL0VNtsH/6u1BN7CWLcpgEAAABECSMAAACAKGEEAAAAECWMAAAAAKKEEQAAAECU\nMAIAAACIMtoTAKimNJ6s99FkWxiJuoX3QB9q9tqWjzOwZcIIANi5uSfrLmTLXATB6RyHLmvJ8UdN\nSXGbBgAAABAljAAAAACihBEAAABAlDACAAAAiBJGAAAAAFGmaQAAq2h9HJ9vlKcna+9Pre4va9eF\ncUt7Zo2e0zOXJYwAAGaZOilr9eLkkno/Me19/fSlRr+VnnMPx6AaHBfWt5dt4DYNAAAAIEoYAQAA\nAEQJIwAAAIAoYQQAAAAQJYwAAAAAokzTAICN+/L4fZXXvVnlVbPWqO3Pu2/VX2OtnlkiUZdhWKc2\nS/el0lSJ5/enhc9+unTtWjoG9bZPpfYnjuutX4ZhXs8IIwCAKl4fbouPt36yO7V+SDmlF8f2p9R4\ny6lRhGtcXL0+3DZxnHEsOW7Nuoz1RY8hQM/cpgEAAABECSMAAACAKGEEAAAAECWMAAAAAKKEEQAA\nAECUMAIAAACIMtoTANi0qdGGUyMJAYDLuzocDmuvAQC4gLe337v4UL++/nR1zs//8eHPYl22FEac\nWxs9c9xe6jIMudrUCgVrPa+eGXdObdRlnNq4TQMAAAAIE0YAAAAAUcIIAAAAIEoYAQAAAEQJIwAA\nAIAo0zQAAACAKH8ZAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAooQRAAAAQJQwAgAAAIgSRgAAAABR\nwggAAAAgShgBAAAARAkjAAAAgChhBAAAABD1ce0FAADU9Pb2+7D2GlKurz9dnfPzNWpz9/lr8fHH\nXz8u/ZKTWqhLq9TmuB7qsta+dk5t9tIvw3DZnmnxOLrEWG38ZQQAAAAQJYwAAAAAooQRAAAAQJQw\nAgAAAIgSRgAAAABRwggAAAAgymhPAGC3vjx+Lz7+8+5baCV/aWGk21RdhmGd2rSgxZ5pgZ4ZV6rN\nzf1Ltddt4VhSomcQRgAAAEAj1g6KUtymAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAooQRAAAAQJRp\nGgAAI2qMcZwat9eLsdrchNfRmt565vn9qdpz/12pNr2PcCxuo4fb3EI2Zrc9M8yfqNH6SNf/JIwA\nAGhMSyeLwDI39y/j+3SlC+qthJ5sm9s0AAAAgChhBAAAABAljAAAAACihBEAAABAlDACAAAAiBJG\nAAAAAFFGewIAzPTl8fvoYz8rjezrXalmw7D9uk29/2NuKqyDNszph1Oc0jNzx38+vz/N+r1LW1K7\nHo4zY9vn9eG2+HtT276lzy1hBACwW6UTr1oXCY+/flR53kuaOiEt1eb14Xb092vVNGmNnjnF2n21\npGfYpyUXvlvop9I+OzcoGobpsKIlbtMAAAAAooQRAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgyjQN\nAIAKehg7N/sb2ye+rX30eU/4lveWxs6xvlP2I32xTcXj08KJEY4zx6VHLwsjAACOmDrpKp0o9zRa\nbY6b+5e1l9CkWqMKS+NSezHZM52/v7lq9szc114yVjLl5v5ldDTmFsZ+lpTe+5TWauM2DQAAACBK\nGAEAAABECSMAAACAKGEEAAAAECWMAAAAAKJM0wAAOGLNb5Q3du44dWmX/aU/Y3W7Cb3+0p5ZMkKY\neS69rwkjAABmmjNeraXRanPHw9XSUm3S9nDRPKffpuqiZ45bUpfeR8meMnq4tWPf39VaX62emctt\nGgAAAECUMAIAAACIEkYAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAooz0BgN1aOuceWjfV48/vT6GV\nQDvG9ovXh9vi7/U88rRFwggAgBE1Zr1PncyuMeu9FaXa7LkujHNxeFwPx5kax9dTCKGPW6Nn3KYB\nAAAARAkjAAAAgChhBAAAABAljAAAAACihBEAAABAlGkaAMCm+eZ0aENpX5w7WWFq/15rYgMw7epw\nOKy9BgCAav748GfxZGdLFyvX15+uzvn5t7ffuzgRVJdxqdrUCg1qPa+eGXdObdRlnNq4TQMAAAAI\nE0YAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQZbQnAAAAEOUvIwAAAIAoYQQAAAAQJYwA\nAAAAooQRAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgShgBAAAARAkjAAAAgChhBAAAABAljAAAAACi\nPq69AAAA1vH29vtw6ee8+/y1+Pjjrx+XfslJ19efrs75+Rp1aZXaHNdDXdba186pzV76ZRgu2zMt\nHkeXGKuNv4wAAAAAooQRAAAAQJQwAgAAAIgSRgAAAABRwggAAAAgyjQNAACa9+Xxe/Hxn3ffQitp\nj9oc10NdpqYm1Hru14fb4u+2UJs19NAza6lRG38ZAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAooQR\nAAAAQJQwAgAAAIgSRgAAAABRH9deAAAADMP0HHvmufv8dfSx5/en4Eour5WeKdWYtmyhZ5b87uOv\nH0f/fY26CCMAAOjCz7tvay+hSTf3L2svoVlb6Jmxi0fq2ELP1HLp2rhNAwAAAIgSRgAAAABRwggA\nAAAgShgBAAAARAkjAAAAgCjTNAAAiGllrF5rSnXZ+rSMqTGFpkkct1bPpMbBzu0Lx5iyluojjAAA\nIOL14XbyZ4zVm2evF+y99Mtet0+rWuibUk/UCuimgoh0XdymAQAAAEQJIwAAAIAoYQQAAAAQJYwA\nAAAAooQRAAAAQJQwAgAAAIgy2hMAYKemxsdxWVNj9basVq/VGoHYij33zBS1Oa6nuggjAAC4qN4v\nANfy8+7b0X+/u38JryTr9eF27SV0a6xnhrF/P0EPIeVUz4zWhaZq4zYNAAAAIEoYAQAAAEQJIwAA\nAIAoYQQAAAAQJYwAAAAAokzTAACAC1hzpF5pAkJiukmrExjWrsuUnsYwXlrNnqn13C30zDBsp2+E\nEQAAOzX3xHorJ8JraGmsXlKtUYylHm41IDlXumdaueAehvG1bGXbjiltg1OOv70cZ9ymAQAAAEQJ\nIwAAAIAoYQQAAAAQJYwAAAAAooQRAAAAQJRpGgAA/B+mZYzrsTalyQPP70/BlexXj33DevbSL8II\nAADO1svouLQadZkatdjCmMOWxkG2ptZY095N1aVkSb9N7S899PJWesZtGgAAAECUMAIAAACIEkYA\nAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAooz0BAHZqL7Psz6Uu40q1KY0brDVOsYWxpoyzL41TG2EE\nAABHbGWOfQ2t1WbuhfwepGrz+nBbfLy1nmnFGnXpZX/ZQ8+4TQMAAACIEkYAAAAAUcIIAAAAIEoY\nAQAAAEQJIwAAAICoq8PhsPYaAAAAgB3xlxEAAABAlDACAAAAiBJGAAAAAFHCCAAAACBKGAEAAABE\nCSMAAACAKGEEAAAAECWMAAAAAKKEEQAAAECUMAIAAACIEkYAAAAAUR/XXgAAALTk7e33ocbzfnn8\nXnz85923Gi9bdH396eqcn69RG3UZV6rNzf1L8Xcff/249HKGYTivNvalcaXa3H3+Wvzd0rbtqTb+\nMgIAAACIEkYAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQZbQnAADACqZGOA4Pt9Weu9bo\nTziVMAIAAKBBN/cvo6HB3f1LeDVwWW7TAAAAAKKEEQAAAECUMAIAAACIEkYAAAAAUcIIAAAAIMo0\nDQAAaNjk+McFnt+fqj33ntTcRl8evx/995tqr0jCkp4p/u6CcbBpwggAAIAGvRYuLF8fboefd9+O\nPlYzHKFtpXGwrXGbBgAAABAljAAAAACihBEAAABAlDACAAAAiBJGAAAAAFHCCAAAACDKaE8AAKBZ\nU2Mqexlj2Jq54z+f358uvJLjiusrjDxlGL48fh99bGwc7BqEEQAA0LmeL8hLF0dzL5hbM3f7lC4q\n6VupJ5YEcD31jNs0AAAAgChhBAAAABAljAAAAACihBEAAABAlDACAAAAiDJNAwAA/kNP30aftqQ2\nY1MzSs95c8LzliYPtDJlZK89teR9n7Lte7ak77fST8IIAACgSa8Pt12M/mxt1OKSEKaVmpa2/VYu\nxsfs5b27TQMAAACIEkYAAAAAUcIIAAAAIEoYAQAAAEQJIwAAAIAo0zQAAKBzPYy3XEOqLsUJFA+3\nF3udvRmt60RNa00E2fO+VMPV4XBYew0AANCMt7ffTZ0gL72wKl1AXV9/ujrnucZqc8oa0xdyU2u6\nZF3++PBnsWe2dBF7Tm1K+1IrI0TPkdiXtmisNm7TAAAAAKKEEQAAAECUMAIAAACIEkYAAAAAUcII\nAAAAIEoYAQAAAER9XHsBAABAPaURis/vT8GVAPxFGAEAAA17/PWj+HgpbIBWTfV1LVP7y1rr2iO3\naQAAAABRwggAAAAgShgBAAAARAkjAAAAgChhBAAAABBlmgYAAFBdaYrB3AkGJolAv4QRAADQsRZG\nEbY6fjRVmxa2AaexrdrhNg0AAAAgShgBAAAARAkjAAAAgChhBAAAABAljAAAAACihBEAAABA1NXh\ncFh7DQAAAMCO+MsIAAAAIEoYAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAooQRAAAAQJQwAgAAAIgS\nRgAAAABRwggAAAAgShgBAAAARAkjAAAAgKiPay8AAADow9vb78Paa0i4vv50dc7Pr1GXu89fi48/\n/vpR5XXPqc1e+mUYLtsza23bWsZq4y8jAAAAgChhBAAAABAljAAAAACihBEAAABAlDACAAAAiDJN\nAwAAYAVTUxNqPndvExnYHn8ZAQAAAEQJIwAAAIAoYQQAAAAQJYwAAAAAooQRAAAAQJQwAgAAAIgS\nRgAAAABRV4fDYe01AAAAHXh7+93UxcOXx+9Vnvdf//jn1Tk//8eHP5uqyyleH25n/d45tWmtX4ZB\nzzz++nH032vVZRjGa/Ox2isCAACwyNjF4xI1LzzhVG7TAAAAAKKEEQAAAECUMAIAAACIEkYAAAAA\nUcIIAAAAIMo0DQAAYDV3n7/O/+WZ4ylp31Rf1JgyQpYwAgAA2KSfd9/WXsIwDO1dON/cv8z/5X9c\nbh0tSvVMqSdaDWIuXRu3aQAAAABRwggAAAAgShgBAAAARAkjAAAAgChhBAAA8D/t3TGO6lYYhmEc\nTUtpTR8WgVhAlpCGcqSpr24WEpR6JEo3WUIWcDWLID2inD5OlyIa7MHG3zHwPOWA7ZMTp+CVnR8g\nyjQNAABgUqPGd3KxdfPW+fkqtA7oIkYAAABFzW305b077DZnxzQKR6R4TQMAAACIEiMAAACAKDEC\nAAAAiBIjAAAAgCgxAgAAAIgSIwAAAIAooz0BAIBRSo2DXDdvnZ+fG1/J+b1bha4/9p4ZfPxuM+q6\nj6zrv7ch/62JEQAAwOSa4770EviCw25z1xFn9f3H7O/Fua/vWrymAQAAAESJEQAAAECUGAEAAABE\niREAAABAlBgBAAAARJmmAQAAPJS+kaB/f/sttJJ56tqfe9ibc2NBDyPGft7Dvowx5J4RIwAAgJt0\nzyMox+jbl74Yk1BqfOW5EPHoStwzXtMAAAAAosQIAAAAIEqMAAAAAKLECAAAACBKjAAAAACixAgA\nAAAgymhPAABgcl0jFYeOeewb01hqfCTQr2rbtvQaAACAG3A6fQz68TBVNJjqvHW9rC75/tB9uUWX\n7I19Oc/eeE0DAAAACBMjAAAAgCgxAgAAAIgSIwAAAIAoMQIAAACIMk0DAAAAiPJkBAAAABAlRgAA\nAABRYgQAAAAQJUYAAAAAUWIEAAAAECVGAAAAAFFiBAAAABAlRgAAAABRYgQAAAAQJUYAAAAAUWIE\nAAAAEPVUegEAAAC37HT6aNPX3D6/dH7eHPeTXLeul9VXv1tiX0q5ZF8Wi+69KfXvdirn9saTEQAA\nAECUGAEAAABEiREAAABAlBgBAAAARIkRAAAAQJQYAQAAAEQZ7QkAAFBA3wjHKc99a+MhuT+ejAAA\nAACixAgAAAAgSowAAAAAosQIAAAAIEqMAAAAAKLECAAAACDKaE8AAIARphzRyX0ac8+MOXZOI13F\nCAAAgJka8uNRHOEWeE0DAAAAiBIjAAAAgCgxAgAAAIgSIwAAAIAoMQIAAACIEiMAAACAKKM9AQAA\nZmrdvF18zOoL3xk6/vOvf/4cdNylutZ32G1Gnft9+zrq+FvWdT+l90WMAAAAmFBz3A86bkiI4DZ0\n3RN9oajr2Fu6Z7ymAQAAAESJEQAAAECUGAEAAABEiREAAABAlBgBAAAARJmmAQAAUMAtTT64J3Ma\nbznEVPdN33mvvTdiBAAAwEwN+QG4HnHevrGSKUPHofa59QB02G16vzPonimwL17TAAAAAKLECAAA\nACBKjAAAAACixAgAAAAgSowAAAAAokzTAAAAmNAcJzicW9MqvI5zJpvq8YVpFHzu2iNRxQgAAICJ\n9I1iHPIjrk/fObt+VB52m0nWNBer7z+6vzCDf/Zbu2eG8poGAAAAECVGAAAAAFFiBAAAABAlRgAA\nAABRYgQAAAAQJUYAAAAAUUZ7AgAAENU3KnIVWsf/da3rnkeeliBGAAAAMCuH3WaSH//b55fOa5Lj\nNQ0AAAAgSowAAAAAosQIAAAAIEqMAAAAAKLECAAAACDKNA0AAIBC+kZcDpko0XfOW2HM5ufu5Z4R\nIwAAAEZojvtBx5WMBqV/zPddf6q96RvfWXpf+tZwT/eM1zQAAACAKDECAAAAiBIjAAAAgCgxAgAA\nAIgSIwAAAIAoMQIAAACIqtq2Lb0GAAAA4IF4MgIAAACIEiMAAACAKDECAAAAiBIjAAAAgCgxAgAA\nAIgSIwAAAIAoMQIAAACIEiMAAACAKDECAAAAiBIjAAAAgCgxAgAAAIh6Kr0AAAAA7s/p9NGmr7lu\n3nq/8759vfp163pZXfL9rr3ZPr90Htsc95dc6j99ezPFviwW5/fGkxEAAABAlBgBAAAARIkRAAAA\nQJQYAQAAAESJEQAAAEBU1bbx/8EpAAAAd27MNI2vTMWYwtCJEslpGofd5pJLXc2198aTEQAAAHAD\nSoWIKYgRAAAAQJQYAQAAAESJEQAAAECUGAEAAABEiREAAABAlBgBAAAARD2VXgAAAAD3Z928lV4C\nMyZGAAAAcFPet6+DjruFQNIc94OP3T6/DDp3iX3xmgYAAAAQJUYAAAAAUWIEAAAAECVGAAAAAFFi\nBAAAABBlmgYAAAAshk+V+Pvbb1e7zur7j85jx0zbGOPaeyNGAAAAENc1nrNvROVi4GjPe1cqVAzh\nNQ0AAAAgSowAAAAAosQIAAAAIEqMAAAAAKLECAAAACBKjAAAAACijPYEAADgIXSNE+2ybt6uvJJ5\nGbovi8XwvfFkBAAAABAlRgAAAABRYgQAAAAQJUYAAAAAUWIEAAAAECVGAAAAAFFGewIAABC3fX4p\ncuwgu032eg+gatu29BoAAAC4Mz//8fvZH5ur7z+SS7mK5rg/+1ldL6tLztW1N4vFYvG+fb3kdLN2\nbm+8pgEAAABEiREAAABAlBgBAAAARIkRAAAAQJQYAQAAAESJEQAAAEDUU+kFAAAA8FgOu02R8ZXb\n55fOz7vGd3JdnowAAAAAosQIAAAAIEqMAAAAAKLECAAAACBKjAAAAACixAgAAAAgqmrbtvQaAAAA\nuDOn08fD/Nis62V1yfftjScjAAAAgDAxAgAAAIgSIwAAAIAoMQIAAACIEiMAAACAKDECAAAAiDLa\nEwAAAIjyZAQAAAAQJUYAAAAAUWIEAAAAECVGAAAAAFFiBAAAABAlRgAAAABRYgQAAAAQJUYAAAAA\nUWIEAAAAECVGAAAAAFFiBAAAABD1VHoBAAAA8EhOp4/23Gfb55fOY5vj/urrmVJdL6vP/u7JCAAA\nACBKjAAAAACixAgAAAAgSowAAAAAosQIAAAAIEqMAAAAAKLECAAAACBKjAAAAACixAgAAAAgSowA\nAAAAosQIAAAAIEqMAAAAAKLECAAAACCqatu29BoAAADgYfzy069Ffog3x338mnW9rD77uycjAAAA\ngCgxAgAAAIgSIwAAAIAoMQIAAACIEiMAAACAKDECAAAAiBIjAAAAgCgxAgAAAIiq2rYtvQYAAAB4\nGKfTx9kf4tvnl7PHHXab3nO/b1+HLWoidb2sPvu7JyMAAACAKDECAAAAiBIjAAAAgCgxAgAAAIgS\nIwAAAIAoMQIAAACIEiMAAACAKDECAAAAiBIjAAAAgCgxAgAAAIgSIwAAAIAoMQIAAACIeiq9AAAA\nAOA61s1b/Jrv29eLjxEjAAAAYCYOu03pJUR4TQMAAACIEiMAAACAKDECAAAAiBIjAAAAgCgxAgAA\nAIgSIwAAAIAoMQIAAACIqtq2Lb0GAAAAeBin08esfoivm7fe77xvXwedu66X1Wd/92QEAAAAECVG\nAAAAAFFiBAAAABAlRgAAAABRYgQAAAAQJUYAAAAAUUZ7AgAAQNDcRntOyWhPAAAAYBbECAAAACBK\njAAAAACixAgAAAAgSowAAAAAosQIAAAAIMpoTwAAACDKkxEAAABAlBgBAAAARIkRAAAAQJQYAQAA\nAESJEQAAAECUGAEAAABE/QuYundya43u3QAAAABJRU5ErkJggg==\n", 266 | "text/plain": [ 267 | "" 268 | ] 269 | }, 270 | "metadata": {}, 271 | "output_type": "display_data" 272 | } 273 | ], 274 | "source": [ 275 | "# WARNING: SLOW! for some reason this plot takes a couple of minutes. Feel free to skip\n", 276 | "fig, axes = plt.subplots(nrows=16, ncols=16, figsize=(20, 20))\n", 277 | "k = 0\n", 278 | "for r in range(16):\n", 279 | " for c in range(16):\n", 280 | " ax = axes[r, c]\n", 281 | " if c <= r:\n", 282 | " ax.set_visible(False)\n", 283 | " continue\n", 284 | " plot_groups(all_pair_groups[0, k], ax)\n", 285 | " k += 1" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "# Save as Dataset" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 6, 298 | "metadata": { 299 | "collapsed": true 300 | }, 301 | "outputs": [], 302 | "source": [ 303 | "import h5py\n", 304 | "import os\n", 305 | "import os.path\n", 306 | "\n", 307 | "data_dir = os.environ.get('BRAINSTORM_DATA_DIR', '.')" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 7, 313 | "metadata": { 314 | "collapsed": true 315 | }, 316 | "outputs": [], 317 | "source": [ 318 | "with h5py.File(os.path.join(data_dir, 'simple_superpos.h5'), 'w') as f:\n", 319 | " single = f.create_group('train_single')\n", 320 | " single.create_dataset('default', data=np.array(objects).reshape(1, 16, 8, 8, 1), compression='gzip')\n", 321 | " single.create_dataset('groups', data=np.array(objects).reshape(1, 16, 8, 8, 1), compression='gzip')\n", 322 | " \n", 323 | " pairs = f.create_group('test')\n", 324 | " pairs.create_dataset('default', data=all_pairs.reshape(1, -1, 8, 8, 1), compression='gzip')\n", 325 | " pairs.create_dataset('groups', data=all_pair_groups.reshape(1, -1, 8, 8, 1), compression='gzip')" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "# References\n", 333 | " * A. Ravishankar Rao and Guillermo A. Cecchi and Charles C. Peck and James R. Kozloski [Unsupervised Segmentation With Dynamical Units](http://www.cnbc.cmu.edu/cns/papers/oscillat_ieee_2008_v19.pdf)\n", 334 | " IEEE TRANSACTIONS ON NEURAL NETWORKS, VOL. 19, NO. 1, JANUARY 2008" 335 | ] 336 | } 337 | ], 338 | "metadata": { 339 | "kernelspec": { 340 | "display_name": "Python 3", 341 | "language": "python", 342 | "name": "python3" 343 | }, 344 | "language_info": { 345 | "codemirror_mode": { 346 | "name": "ipython", 347 | "version": 3 348 | }, 349 | "file_extension": ".py", 350 | "mimetype": "text/x-python", 351 | "name": "python", 352 | "nbconvert_exporter": "python", 353 | "pygments_lexer": "ipython3", 354 | "version": "3.4.3" 355 | } 356 | }, 357 | "nbformat": 4, 358 | "nbformat_minor": 0 359 | } 360 | -------------------------------------------------------------------------------- /Datasets/plot_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | from __future__ import division, print_function, unicode_literals 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | 7 | 8 | def plot_groups(groups, ax): 9 | mask = (groups == 0) 10 | sns.heatmap(groups, mask=mask, square=True, cmap='viridis_r', 11 | xticklabels=False, yticklabels=False, cbar=False, ax=ax) 12 | 13 | 14 | def plot_input_image(img, ax): 15 | mask = (img == 0) 16 | sns.heatmap(img, mask=mask, square=True, xticklabels=False, 17 | yticklabels=False, cmap='Greys', cbar=False, ax=ax) 18 | -------------------------------------------------------------------------------- /Networks/best_bars_dae.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_bars_dae.h5 -------------------------------------------------------------------------------- /Networks/best_bars_dae_train_multi.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_bars_dae_train_multi.h5 -------------------------------------------------------------------------------- /Networks/best_corners_dae.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_corners_dae.h5 -------------------------------------------------------------------------------- /Networks/best_corners_dae_train_multi.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_corners_dae_train_multi.h5 -------------------------------------------------------------------------------- /Networks/best_mnist_shape_dae.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_mnist_shape_dae.h5 -------------------------------------------------------------------------------- /Networks/best_mnist_shape_dae_train_multi.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_mnist_shape_dae_train_multi.h5 -------------------------------------------------------------------------------- /Networks/best_multi_mnist_dae.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_multi_mnist_dae.h5 -------------------------------------------------------------------------------- /Networks/best_multi_mnist_dae_train_multi.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_multi_mnist_dae_train_multi.h5 -------------------------------------------------------------------------------- /Networks/best_shapes_dae.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_shapes_dae.h5 -------------------------------------------------------------------------------- /Networks/best_shapes_dae_train_multi.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_shapes_dae_train_multi.h5 -------------------------------------------------------------------------------- /Networks/best_simple_superpos_dae.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/Networks/best_simple_superpos_dae.h5 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Binding by Reconstruction Clustering 2 | 3 | This is the code repository complementing the paper ["Binding by Reconstruction Clustering"](http://arxiv.org/abs/1511.06418). 4 | Everything from the paper starting from the datasets all the way to the plots can be 5 | reproduced from this repository. 6 | 7 | ![reconstruction clustering animated](animations/RC.gif) 8 | 9 | ## Dependencies and Setup 10 | 11 | * brainstorm == 0.5 12 | * numpy >= 1.8 13 | * matplotlib >= 1.5 14 | * seaborn >= 0.6 15 | * sacred >= 0.6.7 16 | * jupyter 17 | * ipython 18 | * pymongo 19 | * h5py 20 | * sklearn 21 | * pandas 22 | 23 | Make sure you have a MongoDB running locally. 24 | Make sure you have set the `BRAINSTORM_DATA_DIR` environment variable. 25 | 26 | ## Preparing the data 27 | First run all of the jupyter notebooks in the `Dataset` directory. 28 | This will create HDF5 files for all datasets and save them in your `BRAINSTORM_DATA_DIR`. 29 | 30 | 31 | ## Random Search 32 | If you want to run the random search for good hyperparameters yourself you can 33 | run the file ``run_random_search.py``. 34 | It will perform 100 runs of random search for each of the datasets and save the 35 | results to the local MongoDB database. 36 | But be warned: This might take a couple of days! 37 | 38 | You can then look at the results using the `Get_Search_Results.ipynb` notebook. 39 | 40 | ## Train best Networks 41 | To get the best networks we used in the paper for each dataset run `run_best_nets.py`. 42 | It will save a `Networks/DATASET_best_dae.h5` network for each dataset. This shouldn't take more than half an hour. 43 | 44 | Alternatively you can use your own best results from the random search by running the 45 | corresponding cells in the `Get_Search_Results.ipynb`. 46 | 47 | These files are needed for the following steps. 48 | 49 | ## Evaluation 50 | Next we use these networks for Reconstruction Clustering and store all of the results for later analysis. 51 | 52 | run_evaluation.py 53 | 54 | NOTE: This should take about an hour and use ca 21 GBytes of disk space. 55 | 56 | ## Plots 57 | The `Plots.ipynb` notebook generates all the figures used in the paper. 58 | Once you've run all the other steps you should be able to able to generate them 59 | all yourself. 60 | 61 | ## Database Dump 62 | With the file ``dump.zip`` we've included a dump of the MongoDB that contains all the information about 63 | all the experimental runs we did. 64 | 65 | 66 | 67 | ## Demo Images 68 | 69 | ### Regular (soft) Reconstruction Clustering 70 | These animations show the cluster assignment during a run of RC on 120 71 | different test images. 72 | For each dataset we used the best DAE trained on single object images. 73 | 74 | #### Shapes 75 | ![shapes animation](animations/shapes.gif) 76 | 77 | #### Multi MNIST 78 | ![Multi-MNIST animation](animations/multi_mnist.gif) 79 | 80 | #### Corners 81 | ![Corners animation](animations/corners.gif) 82 | 83 | #### Bars 84 | ![Bars animation](animations/bars.gif) 85 | 86 | #### MNIST + Shape 87 | ![MNIST + Shape animation](animations/mnist_shape.gif) 88 | 89 | #### Simple Superposition 90 | ![Simple Superposition animation](animations/simple_superpos.gif) 91 | 92 | 93 | ### Hard Reconstruction Clustering 94 | These animations show the hard cluster assignment during a run of RC on 120 95 | different test images. 96 | To improve visibility we toned down the brightness on the background pixels. 97 | Here we used the best DAE trained on **multi object** images. 98 | 99 | #### Shapes 100 | ![shapes animation](animations/shapes_train_multi.gif) 101 | 102 | #### Multi MNIST 103 | ![Multi-MNIST animation](animations/multi_mnist_train_multi.gif) 104 | 105 | #### Corners 106 | ![Corners animation](animations/corners_train_multi.gif) 107 | 108 | #### Bars 109 | ![Bars animation](animations/bars_train_multi.gif) 110 | 111 | #### MNIST + Shape 112 | ![MNIST + Shape animation](animations/mnist_shape_train_multi.gif) 113 | -------------------------------------------------------------------------------- /Run Random Search.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from __future__ import division, print_function, unicode_literals\n", 12 | "from sacred.observers import MongoObserver\n", 13 | "from dae import ex\n", 14 | "import dispy" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 5, 20 | "metadata": { 21 | "collapsed": true 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "cluster_nodes = ['put', 'your', 'nodes', 'here']\n", 26 | "\n", 27 | "mongo_db = {\n", 28 | " 'url': 'DATABASE_IP_HERE:27017',\n", 29 | " 'db': 'binding_via_rc'\n", 30 | "}" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": { 37 | "collapsed": false 38 | }, 39 | "outputs": [ 40 | { 41 | "name": "stderr", 42 | "output_type": "stream", 43 | "text": [ 44 | "2016-01-16 17:57:08,627 - dispy - Storing fault recovery information in \"_dispy_20160116175708\"\n", 45 | "INFO:dispy:Storing fault recovery information in \"_dispy_20160116175708\"\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "cluster = dispy.JobCluster('dae.py', nodes=cluster_nodes)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": { 57 | "collapsed": false 58 | }, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "ClusterStatus(nodes=[], jobs_pending=0)" 64 | ] 65 | }, 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | } 70 | ], 71 | "source": [ 72 | "cluster.status()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 6, 78 | "metadata": { 79 | "collapsed": false 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "# Random Search\n", 84 | "nr_runs_per_dataset = 100\n", 85 | "datasets = {\n", 86 | " 'bars': 12, \n", 87 | " 'corners': 5,\n", 88 | " 'shapes': 3,\n", 89 | " 'multi_mnist': 3,\n", 90 | " 'mnist_shape': 2,\n", 91 | " 'simple_superpos':2\n", 92 | "}\n", 93 | "\n", 94 | "jobs = []\n", 95 | "for ds, k in datasets.items():\n", 96 | " for i in range(nr_runs_per_dataset):\n", 97 | " job = cluster.submit('-m', '{url}:{db}.random_search'.format(**mongo_db), 'with', \n", 98 | " 'random_search',\n", 99 | " 'dataset.name={}'.format(ds),\n", 100 | " 'verbose=False',\n", 101 | " 'em.k={}'.format(k))\n", 102 | " jobs.append(job)\n", 103 | " \n" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 7, 109 | "metadata": { 110 | "collapsed": false 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "# Multi-Train Runs\n", 115 | "multi_jobs = []\n", 116 | "for ds, k in datasets.items():\n", 117 | " if ds == \"simple_superpos\": continue\n", 118 | " for i in range(nr_runs_per_dataset):\n", 119 | " job = cluster.submit('-m', '{url}:{db}.train_multi'.format(**mongo_db), 'with', \n", 120 | " 'random_search',\n", 121 | " 'dataset.name={}'.format(ds),\n", 122 | " 'dataset.train_set=train_multi',\n", 123 | " 'em.e_step=max',\n", 124 | " 'verbose=False',\n", 125 | " 'em.k={}'.format(k))\n", 126 | " multi_jobs.append(job)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 8, 132 | "metadata": { 133 | "collapsed": true 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "# MSE-Likelihood Runs\n", 138 | "mse_jobs = []\n", 139 | "for ds, k in datasets.items():\n", 140 | " for i in range(nr_runs_per_dataset):\n", 141 | " job = cluster.submit('-m', '{url}:{db}.mse_likelihood'.format(**mongo_db), 'with', \n", 142 | " 'random_search',\n", 143 | " 'dataset.name={}'.format(ds),\n", 144 | " 'dataset.salt_n_pepper=0.3',\n", 145 | " 'network_spec=Fr250',\n", 146 | " 'verbose=False',\n", 147 | " 'em.k={}'.format(k))\n", 148 | " mse_jobs.append(job)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "collapsed": true 156 | }, 157 | "outputs": [], 158 | "source": [] 159 | } 160 | ], 161 | "metadata": { 162 | "kernelspec": { 163 | "display_name": "Python 3", 164 | "language": "python", 165 | "name": "python3" 166 | }, 167 | "language_info": { 168 | "codemirror_mode": { 169 | "name": "ipython", 170 | "version": 3 171 | }, 172 | "file_extension": ".py", 173 | "mimetype": "text/x-python", 174 | "name": "python", 175 | "nbconvert_exporter": "python", 176 | "pygments_lexer": "ipython3", 177 | "version": "3.4.3" 178 | } 179 | }, 180 | "nbformat": 4, 181 | "nbformat_minor": 0 182 | } 183 | -------------------------------------------------------------------------------- /animations/RC.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/RC.gif -------------------------------------------------------------------------------- /animations/bars.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/bars.gif -------------------------------------------------------------------------------- /animations/bars_train_multi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/bars_train_multi.gif -------------------------------------------------------------------------------- /animations/corners.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/corners.gif -------------------------------------------------------------------------------- /animations/corners_train_multi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/corners_train_multi.gif -------------------------------------------------------------------------------- /animations/mnist_shape.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/mnist_shape.gif -------------------------------------------------------------------------------- /animations/mnist_shape_train_multi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/mnist_shape_train_multi.gif -------------------------------------------------------------------------------- /animations/multi_mnist.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/multi_mnist.gif -------------------------------------------------------------------------------- /animations/multi_mnist_train_multi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/multi_mnist_train_multi.gif -------------------------------------------------------------------------------- /animations/shapes.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/shapes.gif -------------------------------------------------------------------------------- /animations/shapes_train_multi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/shapes_train_multi.gif -------------------------------------------------------------------------------- /animations/simple_superpos.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/animations/simple_superpos.gif -------------------------------------------------------------------------------- /dae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | from __future__ import division, print_function, unicode_literals 4 | 5 | import os 6 | 7 | import h5py 8 | import numpy as np 9 | from sklearn.metrics import adjusted_mutual_info_score 10 | 11 | import brainstorm as bs 12 | from brainstorm import optional as opt 13 | from brainstorm.tools import create_net_from_spec 14 | from sacred import Experiment 15 | 16 | if opt.has_pycuda: 17 | from brainstorm.handlers import PyCudaHandler 18 | HANDLER = PyCudaHandler() 19 | else: 20 | from brainstorm.handlers import default_handler 21 | HANDLER = default_handler 22 | 23 | ex = Experiment('binding_dae') 24 | 25 | 26 | @ex.config 27 | def cfg(): 28 | dataset = { 29 | 'name': 'corners', 30 | 'salt_n_pepper': 0.5, 31 | 'train_set': 'train_single' # train_multi or train_single 32 | } 33 | training = { 34 | 'learning_rate': 0.01, 35 | 'patience': 10, 36 | 'max_epochs': 500 37 | } 38 | em = { 39 | 'nr_iters': 10, 40 | 'k': 3, 41 | 'nr_samples': 1000, 42 | 'e_step': 'expectation', # expectation, expectation_pi, max, or max_pi 43 | 'init_type': 'gaussian', # gaussian, uniform, or spatial 44 | 'dump_results': None 45 | } 46 | network_spec = "F64" 47 | net_filename = 'Networks/binding_dae_{}_{}.h5'.format( 48 | dataset['name'], 49 | np.random.randint(0, 1000000)) 50 | verbose = True 51 | 52 | 53 | @ex.named_config 54 | def random_search(): 55 | network_spec = "F{act_func}{size}".format( 56 | act_func=str(np.random.choice(['r', 't', 's'])), 57 | size=np.random.choice([100, 250, 500, 1000])) 58 | training = { 59 | 'learning_rate': float(10**np.random.uniform(-3, 0))} 60 | dataset = { 61 | 'salt_n_pepper': float(np.random.randint(0, 10) / 10)} 62 | 63 | 64 | @ex.capture(prefix='dataset') 65 | def open_dataset(name): 66 | data_dir = os.environ.get('BRAINSTORM_DATA_DIR', './Datasets') 67 | filename = os.path.join(data_dir, name + '.h5') 68 | return h5py.File(filename, 'r') 69 | 70 | 71 | @ex.capture(prefix='dataset') 72 | def get_input_shape(train_set): 73 | with open_dataset() as f: 74 | return f[train_set]['default'].shape[2:] 75 | 76 | 77 | @ex.capture 78 | def create_network(network_spec, dataset): 79 | print("Network Specifications:", network_spec) 80 | with open_dataset() as f: 81 | in_shape = f[dataset['train_set']]['default'].shape[2:] 82 | net = create_net_from_spec('multi-label', in_shape, in_shape, network_spec, 83 | use_conv=('C' in network_spec)) 84 | return net 85 | 86 | 87 | @ex.capture 88 | def create_trainer(training, net_filename, verbose): 89 | import os 90 | import os.path 91 | dirname = os.path.dirname(net_filename) 92 | if not os.path.exists(dirname): 93 | os.makedirs(dirname) 94 | 95 | trainer = bs.Trainer(bs.training.SgdStepper(training['learning_rate']), 96 | verbose=verbose) 97 | trainer.train_scorers = [bs.scorers.Hamming()] 98 | trainer.add_hook(bs.hooks.StopOnNan()) 99 | trainer.add_hook(bs.hooks.StopAfterEpoch(training['max_epochs'])) 100 | trainer.add_hook(bs.hooks.MonitorScores('val_iter', trainer.train_scorers, 101 | name='validation')) 102 | trainer.add_hook(bs.hooks.EarlyStopper('validation.total_loss', 103 | patience=training['patience'])) 104 | trainer.add_hook(bs.hooks.SaveBestNetwork('validation.total_loss', 105 | net_filename, criterion='min')) 106 | trainer.add_hook(bs.hooks.InfoUpdater(ex)) 107 | if verbose: 108 | trainer.add_hook(bs.hooks.StopOnSigQuit()) 109 | trainer.add_hook(bs.hooks.ProgressBar()) 110 | return trainer 111 | 112 | 113 | @ex.capture(prefix='dataset') 114 | def get_data_iters(name, salt_n_pepper, train_set): 115 | with open_dataset(name) as f: 116 | train_size = int(0.9 * f[train_set]['default'].shape[1]) 117 | train_data = f[train_set]['default'][:, :train_size] 118 | val_data = f[train_set]['default'][:, train_size:] 119 | 120 | train_iter = bs.data_iterators.AddSaltNPepper( 121 | bs.data_iterators.Minibatches(default=train_data, targets=train_data, 122 | batch_size=100), 123 | {'default': salt_n_pepper}) 124 | 125 | val_iter = bs.data_iterators.AddSaltNPepper( 126 | bs.data_iterators.Minibatches(default=val_data, targets=val_data, 127 | batch_size=100), 128 | {'default': salt_n_pepper}) 129 | return train_iter, val_iter 130 | 131 | 132 | def get_test_data(): 133 | with open_dataset() as f: 134 | test_groups = f['test']['groups'][:] 135 | test_data = f['test']['default'][:] 136 | return test_data, test_groups 137 | 138 | 139 | def evaluate_groups(true_groups, predicted): 140 | idxs = np.where(true_groups != 0.0) 141 | score = adjusted_mutual_info_score(true_groups[idxs], 142 | predicted.argmax(1)[idxs]) 143 | confidence = np.mean(predicted.max(1)[idxs]) 144 | return score, confidence 145 | 146 | 147 | @ex.capture 148 | def load_best_net(net_filename): 149 | net = bs.Network.from_hdf5(net_filename) 150 | net.output_name = "Output.outputs.predictions" 151 | return net 152 | 153 | 154 | @ex.capture(prefix='em') 155 | def get_initial_groups(k, dims, init_type, _rnd, low=.25, high=.75): 156 | shape = (1, 1, dims[0], dims[1], 1, k) # (T, B, H, W, C, K) 157 | if init_type == 'spatial': 158 | assert k == 3 159 | group_channels = np.zeros((dims[0], dims[1], 3)) 160 | group_channels[:, :, 0] = np.linspace(0, 0.5, dims[0])[:, None] 161 | group_channels[:, :, 1] = np.linspace(0, 0.5, dims[1])[None, :] 162 | group_channels[:, :, 2] = 1.0 - group_channels.sum(2) 163 | group_channels = group_channels.reshape(shape) 164 | elif init_type == 'gaussian': 165 | group_channels = np.abs(_rnd.randn(*shape)) 166 | group_channels /= group_channels.sum(5)[..., None] 167 | elif init_type == 'uniform': 168 | group_channels = _rnd.uniform(low, high, size=shape) 169 | group_channels /= group_channels.sum(5)[..., None] 170 | else: 171 | raise ValueError('Unknown init_type "{}"'.format(init_type)) 172 | return group_channels 173 | 174 | 175 | def get_likelihood(Y, T, group_channels): 176 | log_loss = T * np.log(Y.clip(1e-6, 1 - 1e-6)) + \ 177 | (1 - T) * np.log((1 - Y).clip(1e-6, 1 - 1e-6)) 178 | return np.sum(log_loss * group_channels) 179 | 180 | 181 | @ex.capture(prefix='em') 182 | def perform_e_step(T, Y, mixing_factors, e_step, k): 183 | loss = (T * Y + (1 - T) * (1 - Y)) * mixing_factors 184 | if e_step == 'expectation': 185 | group_channels = loss / loss.sum(5)[..., None] 186 | elif e_step == 'expectation_pi': 187 | group_channels = loss / loss.sum(5)[..., None] 188 | mixing_factors = group_channels.reshape(-1, k).sum(0) 189 | mixing_factors /= mixing_factors.sum() 190 | elif e_step == 'max': 191 | group_channels = (loss == loss.max(5)[..., None]).astype(np.float) 192 | elif e_step == 'max_pi': 193 | group_channels = (loss == loss.max(5)[..., None]).astype(np.float) 194 | mixing_factors = group_channels.reshape(-1, k).sum(0) 195 | mixing_factors /= mixing_factors.sum() 196 | else: 197 | raise ValueError('Unknown e_type: "{}"'.format(e_step)) 198 | 199 | return group_channels, mixing_factors 200 | 201 | 202 | @ex.command(prefix='em') 203 | def reconstruction_clustering(network, input_data, true_groups, k, nr_iters): 204 | T, N, H, W, C = input_data.shape 205 | input_data = input_data[..., None] # add a cluster dimension 206 | 207 | mixing_factors = np.ones((1, 1, 1, 1, k)) / k 208 | gamma = get_initial_groups(dims=(H, W)) 209 | output_prior = np.ones_like(input_data) * 0.5 210 | 211 | gammas = np.zeros((nr_iters + 1, 1, H, W, C, k)) 212 | likelihoods = np.zeros(2 * nr_iters + 1) 213 | scores = np.zeros((nr_iters + 1, 2)) 214 | 215 | gammas[0:1] = gamma 216 | likelihoods[0] = get_likelihood(output_prior, input_data, gamma) 217 | scores[0] = evaluate_groups(true_groups.flatten(), 218 | gamma.reshape(-1, k)) 219 | 220 | for j in range(nr_iters): 221 | X = gamma * input_data 222 | Y = np.zeros_like(X) 223 | 224 | # run the k copies of the autoencoder 225 | for _k in range(k): 226 | network.provide_external_data({'default': X[..., _k], 227 | 'targets': input_data[..., 0]}) 228 | network.forward_pass() 229 | Y[..., _k] = network.get(network.output_name).reshape((1, 1, H, W, C)) 230 | 231 | # save the log-likelihood after the M-step 232 | likelihoods[2*j+1] = get_likelihood(Y, input_data, gamma) 233 | # perform an E-step 234 | gamma, mixing_factors = perform_e_step(input_data, Y, mixing_factors) 235 | # save the log-likelihood after the E-step 236 | likelihoods[2*j+2] = get_likelihood(Y, input_data, gamma) 237 | # save the resulting group-assignments 238 | gammas[j+1] = gamma[0] 239 | # save the score and confidence 240 | scores[j+1] = evaluate_groups(true_groups.flatten(), 241 | gamma.reshape(-1, k)) 242 | return gammas, likelihoods, scores 243 | 244 | 245 | @ex.command(prefix='em') 246 | def evaluate(nr_samples, dump_results=None): 247 | network = load_best_net() 248 | test_data, test_groups = get_test_data() 249 | all_scores = [] 250 | all_likelihoods = [] 251 | all_gammas = [] 252 | nr_samples = min(nr_samples, test_data.shape[1]) 253 | for i in range(nr_samples): 254 | gammas, likelihoods, scores = reconstruction_clustering( 255 | network, test_data[:, i:i+1], test_groups[:, i:i+1]) 256 | all_gammas.append(gammas) 257 | all_likelihoods.append(likelihoods) 258 | all_scores.append(scores) 259 | 260 | all_gammas = np.array(all_gammas) 261 | all_likelihoods = np.array(all_likelihoods) 262 | all_scores = np.array(all_scores) 263 | 264 | print('Average Score: {:.4f}'.format(all_scores[:, -1, 0].mean())) 265 | print('Average Confidence: {:.4f}'.format(all_scores[:, -1, 1].mean())) 266 | 267 | if dump_results is not None: 268 | import pickle 269 | with open(dump_results, 'wb') as f: 270 | pickle.dump((all_scores, all_likelihoods, all_gammas), f) 271 | print('wrote the results to {}'.format(dump_results)) 272 | return all_scores[:, -1, 0].mean() 273 | 274 | 275 | @ex.command 276 | def draw_net(filename='net.png'): 277 | network = create_network() 278 | from brainstorm.tools import draw_network 279 | draw_network(network, filename) 280 | 281 | 282 | @ex.pre_run_hook 283 | def initialize(seed): 284 | bs.global_rnd.set_seed(seed) 285 | 286 | 287 | @ex.automain 288 | def run(net_filename): 289 | network = create_network() 290 | network.set_handler(HANDLER) 291 | trainer = create_trainer() 292 | train_iter, val_iter = get_data_iters() 293 | 294 | trainer.train(network, train_iter, val_iter=val_iter) 295 | 296 | ex.add_artifact(net_filename) 297 | 298 | ex.info['best_val_loss'] = float(np.min(trainer.logs['validation']['total_loss'])) 299 | return evaluate() 300 | -------------------------------------------------------------------------------- /dump.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/dump.zip -------------------------------------------------------------------------------- /extra imgs/DAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/extra imgs/DAE.png -------------------------------------------------------------------------------- /extra imgs/FTW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/extra imgs/FTW.png -------------------------------------------------------------------------------- /extra imgs/NNFTW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/extra imgs/NNFTW.png -------------------------------------------------------------------------------- /extra imgs/Tiefighter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/extra imgs/Tiefighter.png -------------------------------------------------------------------------------- /extra imgs/circles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/extra imgs/circles.png -------------------------------------------------------------------------------- /extra imgs/interlocked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/extra imgs/interlocked.png -------------------------------------------------------------------------------- /extra imgs/interrupted_lines.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/extra imgs/interrupted_lines.png -------------------------------------------------------------------------------- /extra imgs/split_lines.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Qwlouse/Binding/87fbc46e5c146c4b84168b63143b8c9042fa622b/extra imgs/split_lines.png -------------------------------------------------------------------------------- /run_best_nets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | from __future__ import division, print_function, unicode_literals 4 | from dae import ex 5 | 6 | @ex.named_config 7 | def best_bars(): 8 | dataset = { 9 | 'name': 'bars', 10 | 'salt_n_pepper': 0.0 11 | } 12 | training = { 13 | 'learning_rate': 0.768014586935404 14 | } 15 | seed = 459182787 16 | network_spec = "Fr100" 17 | net_filename = 'Networks/best_bars_dae.h5' 18 | 19 | ex.run(named_configs=['best_bars']) 20 | 21 | 22 | @ex.named_config 23 | def best_corners(): 24 | dataset = { 25 | 'name': 'corners', 26 | 'salt_n_pepper': 0.0 27 | } 28 | training = { 29 | 'learning_rate': 0.0019199822609484764 30 | } 31 | seed = 158253144 32 | network_spec = "Fr100" 33 | net_filename = 'Networks/best_corners_dae.h5' 34 | 35 | ex.run(named_configs=['best_corners']) 36 | 37 | 38 | @ex.named_config 39 | def best_shapes(): 40 | dataset = { 41 | 'name': 'shapes', 42 | 'salt_n_pepper': 0.4 43 | } 44 | training = { 45 | 'learning_rate': 0.08314720669724956 46 | } 47 | seed = 845841083 48 | network_spec = "Ft500" 49 | net_filename = 'Networks/best_shapes_dae.h5' 50 | 51 | ex.run(named_configs=['best_shapes']) 52 | 53 | 54 | @ex.named_config 55 | def best_multi_mnist(): 56 | dataset = { 57 | 'name': 'multi_mnist', 58 | 'salt_n_pepper': 0.6 59 | } 60 | training = { 61 | 'learning_rate': 0.011361917579645924 62 | } 63 | seed = 498470020 64 | network_spec = "Fr1000" 65 | net_filename = 'Networks/best_multi_mnist_dae.h5' 66 | 67 | ex.run(named_configs=['best_multi_mnist']) 68 | 69 | 70 | @ex.named_config 71 | def best_mnist_shape(): 72 | dataset = { 73 | 'name': 'mnist_shape', 74 | 'salt_n_pepper': 0.6 75 | } 76 | training = { 77 | 'learning_rate': 0.0316848152096582 78 | } 79 | seed = 166717815 80 | network_spec = "Fs250" 81 | net_filename = 'Networks/best_mnist_shape_dae.h5' 82 | 83 | ex.run(named_configs=['best_mnist_shape']) 84 | 85 | 86 | @ex.named_config 87 | def best_simple_superpos(): 88 | dataset = { 89 | 'name': 'simple_superpos', 90 | 'salt_n_pepper': 0.1 91 | } 92 | training = { 93 | 'learning_rate': 0.36662702472680564 94 | } 95 | seed = 848588405 96 | network_spec = "Fr100" 97 | net_filename = 'Networks/best_simple_superpos_dae.h5' 98 | 99 | ex.run(named_configs=['best_simple_superpos']) 100 | 101 | 102 | @ex.named_config 103 | def best_bars_train_multi(): 104 | dataset = { 105 | 'name': 'bars', 106 | 'train_set': 'train_multi', 107 | 'salt_n_pepper': 0.8 108 | } 109 | training = { 110 | 'learning_rate': 0.01219213699462807 111 | } 112 | seed = 141786426 113 | network_spec = "Fs100" 114 | net_filename = 'Networks/best_bars_dae_train_multi.h5' 115 | 116 | ex.run(named_configs=['best_bars_train_multi']) 117 | 118 | 119 | @ex.named_config 120 | def best_corners_train_multi(): 121 | dataset = { 122 | 'name': 'corners', 123 | 'train_set': 'train_multi', 124 | 'salt_n_pepper': 0.7 125 | } 126 | training = { 127 | 'learning_rate': 0.02603487482829947 128 | } 129 | seed = 872544498 130 | network_spec = "Fr100" 131 | net_filename = 'Networks/best_corners_dae_train_multi.h5' 132 | 133 | ex.run(named_configs=['best_corners_train_multi']) 134 | 135 | 136 | @ex.named_config 137 | def best_shapes_train_multi(): 138 | dataset = { 139 | 'name': 'shapes', 140 | 'train_set': 'train_multi', 141 | 'salt_n_pepper': 0.9 142 | } 143 | training = { 144 | 'learning_rate': 0.049401835193689486 145 | } 146 | seed = 702200962 147 | network_spec = "Fs100" 148 | net_filename = 'Networks/best_shapes_dae_train_multi.h5' 149 | 150 | ex.run(named_configs=['best_shapes_train_multi']) 151 | 152 | 153 | @ex.named_config 154 | def best_multi_mnist_train_multi(): 155 | dataset = { 156 | 'name': 'multi_mnist', 157 | 'train_set': 'train_multi', 158 | 'salt_n_pepper': 0.9 159 | } 160 | training = { 161 | 'learning_rate': 0.001785591525476118 162 | } 163 | seed = 632224571 164 | network_spec = "Fs250" 165 | net_filename = 'Networks/best_multi_mnist_dae_train_multi.h5' 166 | 167 | ex.run(named_configs=['best_multi_mnist_train_multi']) 168 | 169 | 170 | @ex.named_config 171 | def best_mnist_shape_train_multi(): 172 | dataset = { 173 | 'name': 'mnist_shape', 174 | 'train_set': 'train_multi', 175 | 'salt_n_pepper': 0.6 176 | } 177 | training = { 178 | 'learning_rate': 0.033199614969711265 179 | } 180 | seed = 900543563 181 | network_spec = "Fr1000" 182 | net_filename = 'Networks/best_mnist_shape_dae_train_multi.h5' 183 | 184 | ex.run(named_configs=['best_mnist_shape_train_multi']) 185 | 186 | 187 | -------------------------------------------------------------------------------- /run_evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | from __future__ import division, print_function, unicode_literals 4 | from dae import ex 5 | 6 | for ds in ['bars', 'corners', 'shapes', 'multi_mnist', 'mnist_shape', 'simple_superpos']: 7 | for k in [2, 3, 5, 12]: 8 | ex.run_command('evaluate', config_updates={ 9 | 'dataset.name': ds, 10 | 'net_filename': 'Networks/best_{}_dae.h5'.format(ds), 11 | 'em.k': k, 12 | 'em.nr_iters': 10, 13 | 'em.dump_results': 'Results/{}_10_{}.pickle'.format(ds, k), 14 | 'seed': 1337}) 15 | 16 | # Longer results for bars convergence plot 17 | for k in [2, 3, 5, 12]: 18 | ex.run_command('evaluate', config_updates={ 19 | 'dataset.name': 'bars', 20 | 'net_filename': 'Networks/best_{}_dae.h5'.format('bars'), 21 | 'em.k': k, 22 | 'em.nr_iters': 20, 23 | 'em.dump_results': 'Results/{}_20_{}.pickle'.format('bars', k), 24 | 'seed': 42}) 25 | 26 | # Results for multi-object trained networks 27 | for ds in ['bars', 'corners', 'shapes', 'multi_mnist', 'mnist_shape']: 28 | for k in [2, 3, 5, 12]: 29 | ex.run_command('evaluate', config_updates={ 30 | 'dataset.name': ds, 31 | 'net_filename': 'Networks/best_{}_dae_train_multi.h5'.format(ds), 32 | 'em.k': k, 33 | 'em.nr_iters': 10, 34 | 'em.e_step': 'max', 35 | 'em.dump_results': 'Results/{}_10_{}_train_multi.pickle'.format(ds, k), 36 | 'seed': 23}) 37 | 38 | 39 | -------------------------------------------------------------------------------- /run_random_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | from __future__ import division, print_function, unicode_literals 4 | from sacred.observers import MongoObserver 5 | from dae import ex 6 | 7 | nr_runs_per_dataset = 100 8 | datasets = { 9 | 'bars': 12, 10 | 'corners': 5, 11 | 'shapes': 3, 12 | 'multi_mnist': 3, 13 | 'mnist_shape': 2, 14 | 'simple_superpos':2 15 | } 16 | db_name = 'binding_via_rc' 17 | 18 | # Random search 19 | ex.observers = [MongoObserver.create(db_name=db_name, prefix='random_search')] 20 | for ds, k in datasets.items(): 21 | for i in range(nr_runs_per_dataset): 22 | ex.run(config_updates={'dataset.name': ds, 'verbose': False, 'em.k': k}, 23 | named_configs=['random_search']) 24 | 25 | 26 | # Multi-Train Runs 27 | ex.observers = [MongoObserver.create(db_name=db_name, prefix='train_multi')] 28 | for ds, k in datasets.items(): 29 | if ds == "simple_superpos": continue 30 | for i in range(nr_runs_per_dataset): 31 | ex.run(config_updates={ 32 | 'dataset.name': ds, 33 | 'dataset.train_set': 'train_multi', 34 | 'em.k': k, 35 | 'em.e_step': 'max', 36 | 'verbose': False}, named_configs=['random_search']) 37 | 38 | # MSE-Likelihood Runs 39 | ex.observers = [MongoObserver.create(db_name=db_name, prefix='mse_likelihood')] 40 | for ds, k in datasets.items(): 41 | for i in range(nr_runs_per_dataset): 42 | ex.run(config_updates={ 43 | 'dataset.name': ds, 44 | 'dataset.salt_n_pepper': 0.3, 45 | 'network_spec': 'Fr250', 46 | 'em.k': k, 47 | 'verbose': False}, named_configs=['random_search']) 48 | 49 | --------------------------------------------------------------------------------