├── tests ├── __init__.py └── test_conversion.py ├── lpdn ├── layers │ ├── __init__.py │ ├── reshape.py │ ├── dense.py │ ├── activation.py │ ├── convolution.py │ └── pooling.py ├── utils │ ├── __init__.py │ └── conversion.py └── __init__.py ├── example ├── tmp.h5 └── MNIST.ipynb ├── setup.py ├── .gitignore ├── LICENSE.md └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lpdn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lpdn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example/tmp.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/LPDN/HEAD/example/tmp.h5 -------------------------------------------------------------------------------- /lpdn/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils.conversion import convert_to_lpdn 2 | from .layers.activation import filter_activation 3 | from .layers.pooling import LPMaxPooling2D, LPAveragePooling1D 4 | from .layers.dense import LPDense 5 | from .layers.convolution import LPConv1D, LPConv2D 6 | from .layers.activation import LPActivation 7 | from .layers.reshape import LPFlatten 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="lpdn", 8 | version="0.0.2", 9 | author="Marco Ancona", 10 | 
author_email="marco.ancona@inf.ethz.ch", 11 | description="Implementation of Lightweight Probabilistic Deep Network (inference-only) for Keras and Tensorflow", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/marcoancona/LPDN", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | test_suite='nose.collector', 22 | tests_require=['nose'], 23 | ) -------------------------------------------------------------------------------- /lpdn/layers/reshape.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from keras import backend as K 3 | from keras.layers import Flatten 4 | import numpy as np 5 | 6 | 7 | class LPFlatten(Flatten): 8 | """ 9 | Propagate distributions over a Dense layer 10 | """ 11 | def __init__(self, **kwargs): 12 | super(LPFlatten, self).__init__(**kwargs) 13 | self.n_batch = None 14 | self.n_feat = None 15 | 16 | def compute_output_shape(self, input_shape): 17 | self.n_batch = input_shape[0] 18 | self.n_feat = np.prod(input_shape[1:-1]) 19 | return self.n_batch, self.n_feat, 2 20 | 21 | def assert_input_compatibility(self, inputs): 22 | return super(LPFlatten, self).assert_input_compatibility(inputs[..., 0]) 23 | 24 | def call(self, inputs): 25 | n_batch = tf.shape(inputs)[0] 26 | return K.reshape(inputs, (n_batch, -1, 2)) 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | 
lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Django stuff: 52 | *.log 53 | 54 | # Sphinx documentation 55 | docs/_build/ 56 | 57 | # PyBuilder 58 | target/ 59 | 60 | \.idea/ 61 | 62 | example/\.ipynb_checkpoints/ 63 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2019 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /lpdn/layers/dense.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from keras import backend as K 3 | from keras.layers import Dense 4 | from .activation import filter_activation 5 | 6 | 7 | class LPDense(Dense): 8 | """ 9 | Propagate distributions over a probabilistic Dense layer 10 | """ 11 | def __init__(self, units, **kwargs): 12 | super(LPDense, self).__init__(units, **kwargs) 13 | 14 | def build(self, input_shape): 15 | return super(LPDense, self).build(input_shape[:-1]) 16 | 17 | def compute_output_shape(self, input_shape): 18 | original_output_shape = super(LPDense, self).compute_output_shape(input_shape[:-1]) 19 | return original_output_shape + (2,) 20 | 21 | def assert_input_compatibility(self, inputs): 22 | return super(LPDense, self).assert_input_compatibility(inputs[..., 0]) 23 | 24 | def call(self, inputs): 25 | m = inputs[..., 0] 26 | v = inputs[..., 1] 27 | 28 | m = K.dot(m, self.kernel) 29 | v = K.dot(v, self.kernel ** 2) 30 | 31 | if self.use_bias: 32 | m += self.bias 33 | 34 | if self.activation is not None: 35 | m, v = filter_activation(self.activation.__name__, m, v) 36 | 37 | return tf.stack([m, v], -1) -------------------------------------------------------------------------------- /lpdn/layers/activation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow import distributions as dist 4 | from keras.layers import Activation 5 | 6 | exp = np.exp 7 | normal = dist.Normal(loc=0., scale=1.) 
8 | 9 | 10 | def _filter_linear(m, v): 11 | return m, v 12 | 13 | 14 | def _filter_relu(m, v): 15 | v = tf.maximum(v, 0.0001) 16 | s = v**0.5 17 | m_out = m*normal.cdf(m/s) + s*normal.prob(m/s) 18 | v_out = (m**2 + v)*normal.cdf(m/s) + (m*s)*normal.prob(m/s) - m_out**2 19 | return m_out, v_out 20 | 21 | 22 | ACTIVATIONS = { 23 | 'linear' : _filter_linear, 24 | 'relu': _filter_relu 25 | } 26 | 27 | 28 | def filter_activation(activation_name, m, v): 29 | activation_name = activation_name.lower() 30 | if activation_name in ACTIVATIONS: 31 | return ACTIVATIONS[activation_name](m, v) 32 | else: 33 | raise Exception("Activation '%s' not supported" % activation_name) 34 | 35 | 36 | class LPActivation(Activation): 37 | def __init__(self, activation, **kwargs): 38 | self.activation_name = activation if isinstance(activation, str) else activation.__name__ 39 | if self.activation_name not in ACTIVATIONS: 40 | raise Exception("Activation '%s' not supported" % self.activation_name) 41 | super(LPActivation, self).__init__(activation, **kwargs) 42 | 43 | def call(self, inputs): 44 | m = inputs[..., 0] 45 | v = inputs[..., 1] 46 | m, v = filter_activation(self.activation_name, m, v) 47 | return tf.stack([m, v], -1) 48 | -------------------------------------------------------------------------------- /tests/test_conversion.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import pkg_resources 3 | import logging, warnings 4 | import tensorflow as tf 5 | import numpy as np 6 | from keras.models import Sequential, Model 7 | from keras.layers import Dense, Flatten, Activation, Input, Conv2D, MaxPooling2D, Conv1D, AveragePooling1D, Dropout 8 | 9 | from lpdn import convert_to_lpdn 10 | 11 | 12 | def dense_sequential_model(): 13 | model = Sequential() 14 | model.add(Dense(32, input_shape=(10,), activation='relu')) 15 | model.add(Dense(32, activation='relu')) 16 | return model 17 | 18 | 19 | def 
def convert_to_lpdn(keras_model, input_shape=None):
    """Create the lightweight-probabilistic (LP) equivalent of `keras_model`.

    Every supported layer is mirrored by its LP counterpart operating on
    stacked (mean, variance) tensors; layers in IGNORE_LIST (Dropout,
    InputLayer) are skipped. The returned model has fresh random weights —
    transfer them via save_weights/load_weights if needed.

    :param keras_model: source Keras model (Sequential or functional).
    :param input_shape: input shape without the batch axis; when None it is
        inferred from the first layer plus a trailing axis of size 2 for
        (mean, variance).
    :return: a new keras.models.Model.
    :raises RuntimeError: for layer types without an LP equivalent.
    """
    if input_shape is None:
        input_shape = keras_model.layers[0].input_shape[1:] + (2,)
        logging.info("Inferred input shape: " + str(input_shape))

    lp_input = Input(shape=input_shape)
    y = lp_input
    for l in keras_model.layers:
        # Each LP layer reuses the original layer's name; Conv checks come
        # before the Dense/pooling/Flatten cases.
        if isinstance(l, Conv2D):
            y = LPConv2D(l.filters, l.kernel_size, padding=l.padding,
                         activation=l.activation, name=l.name)(y)
        elif isinstance(l, Conv1D):
            y = LPConv1D(l.filters, l.kernel_size, padding=l.padding,
                         activation=l.activation, name=l.name)(y)
        elif isinstance(l, Dense):
            y = LPDense(l.units, activation=l.activation, name=l.name)(y)
        elif isinstance(l, MaxPooling2D):
            y = LPMaxPooling2D(l.pool_size, strides=l.strides, name=l.name)(y)
        elif isinstance(l, AveragePooling1D):
            y = LPAveragePooling1D(l.pool_size, strides=l.strides, name=l.name)(y)
        elif isinstance(l, Flatten):
            y = LPFlatten(name=l.name)(y)
        elif isinstance(l, Activation):
            y = LPActivation(l.activation, name=l.name)(y)
        elif isinstance(l, tuple(IGNORE_LIST)):
            logging.info("Ignoring layer " + l.name)
        else:
            raise RuntimeError("Layer %s not supported" % str(l))

    model = Model(inputs=lp_input, outputs=y)
    return model
5 | 6 | ### Install 7 | In your Python 3 (virtual) environment: 8 | ``` 9 | pip install git+https://github.com/marcoancona/LPDN.git 10 | ``` 11 | Notice that this will not install any additional dependencies but LPDN assumes `numpy`, `tensorflow(-gpu)` and `keras` are available. 12 | 13 | ### How to use 14 | A Keras probabilistic model can be built from scratch or converted from an existing model. 15 | 16 | Let assume we want to build the equivalent of the following model: 17 | ```py 18 | from keras.layers import Dense, Flatten, Activation, Conv2D, MaxPooling2D 19 | 20 | model = Sequential() 21 | model.add(Conv2D(32, kernel_size=(3, 3), 22 | activation='relu', 23 | input_shape=input_shape)) 24 | model.add(Conv2D(64, (3, 3), activation='relu')) 25 | model.add(MaxPooling2D(pool_size=(2, 2))) 26 | model.add(Flatten()) 27 | model.add(Dense(num_classes)) 28 | ``` 29 | 30 | We can either use the *conversion utility* 31 | ```py 32 | from lpdn import convert_to_lpdn 33 | 34 | lp_model = convert_to_lpdn(model) 35 | ``` 36 | 37 | or build the model *from scratch* by replacing the original layers with the Lightweight Propabilistic (LP-) equivalent. 38 | 39 | ```py 40 | from lpdn import LPDense, LPFlatten, LPActivation, LPConv2D, LPMaxPooling2D 41 | 42 | lp_model = Sequential() 43 | lp_model.add(LPConv2D(32, kernel_size=(3, 3), 44 | activation='relu', 45 | input_shape=input_shape)) 46 | lp_model.add(LPConv2D(64, (3, 3), activation='relu')) 47 | lp_model.add(LPMaxPooling2D(pool_size=(2, 2))) 48 | lp_model.add(LPFlatten()) 49 | lp_model.add(LPDense(num_classes)) 50 | ``` 51 | Notice that, in both cases, the probabilistic model is initialized with random weights. 
class LPConv2D(Conv2D):
    """
    Conv2D layer propagating stacked (mean, variance) inputs.

    Both moments are convolved with the layer's kernel (variance with the
    squared kernel, assuming independent inputs), then passed through the
    moment-matched activation filter.
    """
    def __init__(self, filters, kernel_size, **kwargs):
        super(LPConv2D, self).__init__(filters, kernel_size, **kwargs)

    def build(self, input_shape):
        # Build the base layer on the shape without the mean/variance axis.
        return super(LPConv2D, self).build(input_shape[:-1])

    def compute_output_shape(self, input_shape):
        base_shape = super(LPConv2D, self).compute_output_shape(input_shape[:-1])
        return base_shape + (2,)

    def assert_input_compatibility(self, inputs):
        # Validate only the mean component against the base layer's spec.
        return super(LPConv2D, self).assert_input_compatibility(inputs[..., 0])

    def _conv2d(self, input, kernel):
        return K.conv2d(input, kernel, self.strides, self.padding,
                        self.data_format, self.dilation_rate)

    def call(self, inputs):
        mean, variance = inputs[..., 0], inputs[..., 1]

        # Mean through the kernel, variance through the squared kernel.
        mean = self._conv2d(mean, self.kernel)
        variance = self._conv2d(variance, self.kernel ** 2)

        # A constant bias shifts the mean only.
        if self.use_bias:
            mean = K.bias_add(mean, self.bias, data_format=self.data_format)

        if self.activation is not None:
            mean, variance = filter_activation(self.activation.__name__, mean, variance)

        return tf.stack([mean, variance], -1)


class LPConv1D(Conv1D):
    """
    Conv1D layer propagating stacked (mean, variance) inputs.

    Mirrors LPConv2D for one-dimensional convolutions.
    """
    def __init__(self, filters, kernel_size, **kwargs):
        super(LPConv1D, self).__init__(filters, kernel_size, **kwargs)

    def build(self, input_shape):
        # Build the base layer on the shape without the mean/variance axis.
        return super(LPConv1D, self).build(input_shape[:-1])

    def compute_output_shape(self, input_shape):
        base_shape = super(LPConv1D, self).compute_output_shape(input_shape[:-1])
        return base_shape + (2,)

    def assert_input_compatibility(self, inputs):
        # Validate only the mean component against the base layer's spec.
        return super(LPConv1D, self).assert_input_compatibility(inputs[..., 0])

    def _conv1d(self, input, kernel):
        return K.conv1d(input, kernel, self.strides, self.padding,
                        self.data_format)

    def call(self, inputs):
        mean, variance = inputs[..., 0], inputs[..., 1]

        # Mean through the kernel, variance through the squared kernel.
        mean = self._conv1d(mean, self.kernel)
        variance = self._conv1d(variance, self.kernel ** 2)

        # A constant bias shifts the mean only.
        if self.use_bias:
            mean = K.bias_add(mean, self.bias, data_format=self.data_format)

        if self.activation is not None:
            mean, variance = filter_activation(self.activation.__name__, mean, variance)

        return tf.stack([mean, variance], -1)
def _ab_max_pooling(a, b):
    """Gaussian approximation of max(A, B) for two stacked (mean, variance) inputs.

    `a` and `b` each carry mean at [..., 0] and variance at [..., 1]. The
    returned tensor stacks the approximated mean/variance of the maximum,
    computed from the standard-normal pdf/cdf of the standardized mean
    difference. Variances are treated as additive (no covariance term) —
    presumably an independence assumption; confirm against the LPDN paper.
    """
    mu_a = a[..., 0]
    va = a[..., 1]
    mu_b = b[..., 0]
    vb = b[..., 1]
    # Std-dev of the difference, clipped away from zero for stability.
    vavb = tf.maximum(va + vb, 0.00001) ** 0.5

    muamub = mu_a - mu_b
    muamub_p = mu_a + mu_b
    # Standardized mean difference.
    alpha = muamub / vavb

    mu_c = vavb * normal.prob(alpha) + muamub * normal.cdf(alpha) + mu_b
    vc = muamub_p * vavb * normal.prob(alpha)
    # Second moment blend of the two branches, minus the squared mean.
    vc += (mu_a ** 2 + va) * normal.cdf(alpha) + (mu_b ** 2 + vb) * (1. - normal.cdf(alpha)) - mu_c ** 2
    return tf.stack([mu_c, vc], -1)


class LPMaxPooling2D(MaxPooling2D):
    # Max pooling over (mean, variance) inputs: pool elements are folded
    # pairwise with _ab_max_pooling via tf.scan.

    def assert_input_compatibility(self, inputs):
        # Validate only the mean component; the base layer does not know
        # about the extra distribution axis.
        return super(LPMaxPooling2D, self).assert_input_compatibility(inputs[..., 0])

    def compute_output_shape(self, input_shape):
        # Base output shape plus the trailing (mean, variance) axis.
        original_output_shape = super(LPMaxPooling2D, self).compute_output_shape(input_shape[:-1])
        return original_output_shape + (2,)

    def extract_patches(self, x):
        # Gather each pooling window as a flattened patch along the last axis.
        # NOTE(review): padding is hard-coded to 'VALID' regardless of the
        # layer's padding setting — confirm this is intended.
        return tf.extract_image_patches(
            x,
            ksizes=(1,) + self.pool_size + (1,),
            strides=(1,) + self.strides + (1,),
            padding='VALID',
            rates=[1, 1, 1, 1]
        )

    def _pooling_function(self, inputs, pool_size, strides,
                          padding, data_format):

        m = inputs[..., 0]
        v = inputs[..., 1]

        n_channels = m.get_shape().as_list()[-1]
        n_pool = np.prod(self.pool_size)

        # TODO: can all this long pipeline of reshaping, transpositions be simplified or made more efficient?
        # Patches: (batch, out_h, out_w, n_pool * n_channels).
        m = self.extract_patches(m)
        v = self.extract_patches(v)

        patches_shape = m.get_shape().as_list()
        if patches_shape[0] is None:
            # Unknown batch size: use -1 so tf.reshape infers it.
            patches_shape[0] = -1

        # Split the patch axis into (n_pool, n_channels)...
        m = tf.reshape(m, (patches_shape[0:3] + [n_pool, n_channels]))
        v = tf.reshape(v, (patches_shape[0:3] + [n_pool, n_channels]))

        # ...and move the pooling elements to the last axis.
        m = tf.transpose(m, (0, 1, 2, 4, 3))
        v = tf.transpose(v, (0, 1, 2, 4, 3))

        # Everything in batch dimension except dimension with pooling elements
        m = tf.reshape(m, (-1, n_pool))
        v = tf.reshape(v, (-1, n_pool))

        # Transpose because scan is over dimension 0
        m = tf.transpose(m)
        v = tf.transpose(v)

        # Apply max pooling in sequence: reverse scan folds the n_pool
        # elements pairwise; row 0 then holds the accumulated maximum.
        tmp = tf.stack([m, v], -1)
        tmp = tf.scan(_ab_max_pooling, tmp, reverse=True)
        m = tmp[0, :, 0]
        v = tmp[0, :, 1]

        # Start inverting all reshaping to get back to (batch, out_h, out_w, channels)
        m = tf.transpose(m)
        v = tf.transpose(v)

        m = tf.reshape(m, (patches_shape[0:3] + [n_channels,]))
        v = tf.reshape(v, (patches_shape[0:3] + [n_channels,]))

        return tf.stack([m, v], -1)


class LPAveragePooling1D(AveragePooling1D):
    # Average pooling over (mean, variance) inputs.

    def assert_input_compatibility(self, inputs):
        # Validate only the mean component.
        return super(LPAveragePooling1D, self).assert_input_compatibility(inputs[..., 0])

    def compute_output_shape(self, input_shape):
        # Base output shape plus the trailing (mean, variance) axis.
        original_output_shape = super(LPAveragePooling1D, self).compute_output_shape(input_shape[:-1])
        return original_output_shape + (2,)

    def _pooling_function(self, inputs, pool_size, strides,
                          padding, data_format):
        m = inputs[..., 0]
        v = inputs[..., 1]

        # Mean of an average is the average of means; the variance of an
        # average of n independent values is the average variance over n,
        # hence the extra division by the pool size.
        # NOTE(review): self.pool_size is used directly as the divisor and
        # K.pool2d is applied via the base Pooling1D machinery — presumably
        # this relies on broadcasting / the base class's dim expansion;
        # confirm shapes against the Keras version in use.
        m = K.pool2d(m, pool_size, strides,padding, data_format, pool_mode='avg')
        v = K.pool2d(v, pool_size, strides,padding, data_format, pool_mode='avg') / self.pool_size

        return tf.stack([m, v], -1)
/example/MNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%load_ext autoreload\n", 12 | "%autoreload 2" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": { 19 | "collapsed": false 20 | }, 21 | "outputs": [ 22 | { 23 | "name": "stderr", 24 | "output_type": "stream", 25 | "text": [ 26 | "Using TensorFlow backend.\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "'''Trains a simple convnet on the MNIST dataset.\n", 32 | "\n", 33 | "Gets to 99.25% test accuracy after 12 epochs\n", 34 | "(there is still a lot of margin for parameter tuning).\n", 35 | "16 seconds per epoch on a GRID K520 GPU.\n", 36 | "'''\n", 37 | "\n", 38 | "from __future__ import print_function\n", 39 | "import keras\n", 40 | "from keras.datasets import mnist\n", 41 | "from keras.models import Sequential, Model\n", 42 | "from keras.layers import Dense, Dropout, Flatten, Activation\n", 43 | "from keras.layers import Conv2D, MaxPooling2D\n", 44 | "from keras import backend as K\n", 45 | "import numpy as np" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 44, 51 | "metadata": { 52 | "collapsed": false 53 | }, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "x_train shape: (60000, 28, 28, 1)\n", 60 | "60000 train samples\n", 61 | "10000 test samples\n", 62 | "Train on 60000 samples, validate on 10000 samples\n", 63 | "Epoch 1/1\n", 64 | "60000/60000 [==============================] - 7s 119us/step - loss: 0.2632 - acc: 0.9206 - val_loss: 0.0572 - val_acc: 0.9808\n", 65 | "Test loss: 0.05723402953627519\n", 66 | "Test accuracy: 0.9808\n" 67 | ] 68 | }, 69 | { 70 | "data": { 71 | "text/plain": [ 72 | "array([[2.55904268e-07, 1.39071110e-08, 6.85280611e-06, 3.12924954e-06,\n", 73 | " 
1.06494564e-08, 4.46219817e-09, 3.27008554e-09, 9.99987245e-01,\n", 74 | " 4.73717634e-08, 2.50406583e-06],\n", 75 | " [3.20763729e-06, 2.49420300e-05, 9.99965549e-01, 4.06200525e-06,\n", 76 | " 2.91464763e-09, 3.35106165e-09, 1.43034219e-06, 3.49216434e-09,\n", 77 | " 8.12050985e-07, 6.17292883e-11],\n", 78 | " [5.40509473e-06, 9.99128282e-01, 2.28736302e-04, 4.59426656e-06,\n", 79 | " 1.07721025e-04, 3.36976836e-06, 1.36639908e-04, 3.34077515e-04,\n", 80 | " 3.71899878e-05, 1.40207922e-05],\n", 81 | " [9.99922633e-01, 2.51893624e-07, 2.12697705e-05, 8.25799930e-07,\n", 82 | " 3.33046586e-08, 5.25347286e-06, 4.45402002e-05, 2.59837475e-06,\n", 83 | " 9.42924032e-07, 1.64843470e-06],\n", 84 | " [2.08605184e-06, 3.03851812e-06, 2.29683337e-06, 5.74533203e-07,\n", 85 | " 9.98565853e-01, 6.74309604e-07, 1.59546089e-05, 5.38241329e-06,\n", 86 | " 6.03199624e-06, 1.39810436e-03]], dtype=float32)" 87 | ] 88 | }, 89 | "execution_count": 44, 90 | "metadata": {}, 91 | "output_type": "execute_result" 92 | } 93 | ], 94 | "source": [ 95 | "batch_size = 128\n", 96 | "num_classes = 10\n", 97 | "epochs = 1\n", 98 | "\n", 99 | "# input image dimensions\n", 100 | "img_rows, img_cols = 28, 28\n", 101 | "\n", 102 | "# the data, split between train and test sets\n", 103 | "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", 104 | "\n", 105 | "if K.image_data_format() == 'channels_first':\n", 106 | " x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)\n", 107 | " x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)\n", 108 | " input_shape = (1, img_rows, img_cols)\n", 109 | "else:\n", 110 | " x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n", 111 | " x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n", 112 | " input_shape = (img_rows, img_cols, 1)\n", 113 | "\n", 114 | "x_train = x_train.astype('float32')\n", 115 | "x_test = x_test.astype('float32')\n", 116 | "x_train /= 255\n", 117 | "x_test /= 255\n", 118 | 
"print('x_train shape:', x_train.shape)\n", 119 | "print(x_train.shape[0], 'train samples')\n", 120 | "print(x_test.shape[0], 'test samples')\n", 121 | "\n", 122 | "# convert class vectors to binary class matrices\n", 123 | "y_train = keras.utils.to_categorical(y_train, num_classes)\n", 124 | "y_test = keras.utils.to_categorical(y_test, num_classes)\n", 125 | "\n", 126 | "model = Sequential()\n", 127 | "model.add(Conv2D(32, kernel_size=(3, 3),\n", 128 | " activation='relu',\n", 129 | " input_shape=input_shape))\n", 130 | "model.add(Conv2D(64, (3, 3), activation='relu'))\n", 131 | "model.add(MaxPooling2D(pool_size=(2, 2)))\n", 132 | "model.add(Dropout(0.25))\n", 133 | "model.add(Flatten())\n", 134 | "model.add(Dense(128, activation='relu'))\n", 135 | "model.add(Dropout(0.5))\n", 136 | "model.add(Dense(num_classes))\n", 137 | "model.add(Activation('softmax'))\n", 138 | "\n", 139 | "model.compile(loss=keras.losses.categorical_crossentropy,\n", 140 | " optimizer=keras.optimizers.Adadelta(),\n", 141 | " metrics=['accuracy'])\n", 142 | "\n", 143 | "model.fit(x_train, y_train,\n", 144 | " batch_size=batch_size,\n", 145 | " epochs=epochs,\n", 146 | " verbose=1,\n", 147 | " validation_data=(x_test, y_test))\n", 148 | "score = model.evaluate(x_test, y_test, verbose=0)\n", 149 | "print('Test loss:', score[0])\n", 150 | "print('Test accuracy:', score[1])\n" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 45, 156 | "metadata": { 157 | "collapsed": false 158 | }, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "array([7, 2, 1, 0, 4])" 164 | ] 165 | }, 166 | "execution_count": 45, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "np.argmax(model.predict(x_test[:5]), -1)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 47, 178 | "metadata": { 179 | "collapsed": false 180 | }, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": 
"stream", 185 | "text": [ 186 | "Tensor(\"input_2:0\", shape=(?, 28, 28, 1, 2), dtype=float32)\n", 187 | "Tensor(\"input_2:0\", shape=(?, 28, 28, 1, 2), dtype=float32)\n", 188 | "Tensor(\"conv2d_3_1/stack:0\", shape=(?, 26, 26, 32, 2), dtype=float32)\n", 189 | "Tensor(\"conv2d_3_1/stack:0\", shape=(?, 26, 26, 32, 2), dtype=float32)\n", 190 | "_________________________________________________________________\n", 191 | "Layer (type) Output Shape Param # \n", 192 | "=================================================================\n", 193 | "input_2 (InputLayer) (None, 28, 28, 1, 2) 0 \n", 194 | "_________________________________________________________________\n", 195 | "conv2d_3 (LPConv2D) (None, 26, 26, 32, 2) 320 \n", 196 | "_________________________________________________________________\n", 197 | "conv2d_4 (LPConv2D) (None, 24, 24, 64, 2) 18496 \n", 198 | "_________________________________________________________________\n", 199 | "max_pooling2d_2 (LPMaxPoolin (None, 12, 12, 64, 2) 0 \n", 200 | "_________________________________________________________________\n", 201 | "flatten_2 (LPFlatten) (None, 9216, 2) 0 \n", 202 | "_________________________________________________________________\n", 203 | "dense_3 (LPDense) (None, 128, 2) 1179776 \n", 204 | "_________________________________________________________________\n", 205 | "dense_4 (LPDense) (None, 10, 2) 1290 \n", 206 | "=================================================================\n", 207 | "Total params: 1,199,882\n", 208 | "Trainable params: 1,199,882\n", 209 | "Non-trainable params: 0\n", 210 | "_________________________________________________________________\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "import tempfile, sys, os, pickle\n", 216 | "sys.path.insert(0, os.path.abspath('..'))\n", 217 | "\n", 218 | "from lpdn import convert_to_lpdn\n", 219 | "lp_model = convert_to_lpdn(Model(inputs=model.inputs, outputs=model.layers[-2].output))\n", 220 | "model.save_weights('tmp.h5')\n", 221 | 
"lp_model.load_weights('tmp.h5')\n", 222 | "lp_model.summary()" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 53, 228 | "metadata": { 229 | "collapsed": false 230 | }, 231 | "outputs": [ 232 | { 233 | "name": "stdout", 234 | "output_type": "stream", 235 | "text": [ 236 | "(10000, 28, 28, 1, 2)\n", 237 | "(10000, 10, 2)\n" 238 | ] 239 | } 240 | ], 241 | "source": [ 242 | "x_dist = np.stack([x_test, 0.5*np.ones_like(x_test)], -1)\n", 243 | "print(x_dist.shape)\n", 244 | "y_dist = lp_model.predict(x_dist)\n", 245 | "print (y_dist.shape)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 55, 251 | "metadata": { 252 | "collapsed": false 253 | }, 254 | "outputs": [ 255 | { 256 | "name": "stderr", 257 | "output_type": "stream", 258 | "text": [ 259 | "WARNING:matplotlib.legend:No handles with labels found to put in legend.\n" 260 | ] 261 | }, 262 | { 263 | "name": "stdout", 264 | "output_type": "stream", 265 | "text": [ 266 | "[7 2 1 0 4]\n" 267 | ] 268 | }, 269 | { 270 | "data": { 271 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAADHNJREFUeJzt3X/sXXV9x/Hne6WUCRIpSlNLA+gISQVX59dqAnE6hABhA/8h8ofpEmJJJstM/EPC/hhxiSGLYsx+uJTRWJ2im8jaZPiDNVuYGWF8YYyfMhips11pJaAgk9KW9/74HswX+N7z/fbec++5X97PR/LN997zPueed2776jn3fE7vJzITSfX8Wt8NSOqH4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VNQxk9zZsbEqj+P4Se5SKuVFXuClPBhLWXek8EfERcCXgBXA32TmDW3rH8fxvD/OH2WXklrcnbuWvO7Qp/0RsQL4S+BiYANwZURsGPb1JE3WKJ/5NwFPZOaTmfkS8E3gsm7akjRuo4R/HfCTec/3NMteJSK2RMRsRMwe4uAIu5PUpbFf7c/MrZk5k5kzK1k17t1JWqJRwr8XWD/v+anNMknLwCjhvwc4MyLOiIhjgY8BO7tpS9K4DT3Ul5mHI+Ia4PvMDfVty8yHO+tM0liNNM6fmbcDt3fUi6QJ8vZeqSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4qa6BTdqife+66BtX/c+bXWbc/562ta6+v/9N+G6klzPPJLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlEjjfNHxG7geeAIcDgzZ7poSm8cB9534sDaYY60bvum/82u29E8Xdzk8+HMfLqD15E0QZ72S0WNGv4EfhAR90bEli4akjQZo572n5eZeyPiFOCOiPhRZt45f4XmH4UtAMfxphF3J6krIx35M3Nv8/sAcBuwaYF1tmbmTGbOrGTVKLuT1KGhwx8Rx0fEm195DFwIPNRVY5LGa5TT/jXAbRHxyut8IzO/10lXksZu6PBn5pPAb3bYi96Ann334LH8PYcPtm578s13dd2O5nGoTyrK8EtFGX6pKMMvFWX4paIMv1SUX92tkeS5G1vr/3rpjQNrv33nH7Zu+xv8x1A9aWk88ktFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUY7zayTPbPj11vraFYO/um3dt1d23Y6Ogkd+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrKcX6N5Pw/aP967X944S0Dayf8y2Ot27ZP4K1ReeSXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIWHeePiG3ApcCBzDy7WbYa+BZwOrAbuCIznx1fm+rLined1Vr/3Cm3tNZvfu7UgbUjP/v5UD2pG0s58n8FuOg1y64FdmXmmcCu5rmkZWTR8GfmncAzr1l8GbC9ebwduLzjviSN2bCf+ddk5r7m8VPAmo76kTQhI1/wy8wEclA9IrZExGxEzB7i4Ki7k9SRYcO/PyLWAjS/DwxaMTO3ZuZMZs6sZNWQu5PUtWHDvxPY3DzeDOzoph1Jk7Jo+CPiFuAu4KyI2BMRVwE3ABdExOPAR5rnkpaRRcf5M/PKAaXzO+5FU2jvBSePtP29z5/WUv3lSK+t0XiHn1SU4ZeKMvxSUYZfKsrwS0UZfqkov7pbrZ7bcGik7e//i40Da2+h/Wu/NV4e+aWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMf5izt48fta6zsu/PPW+meffm9rffWtDwysvdy6pcbNI79UlOGXijL8UlGGXyrK8EtFGX6pKMMvFeU4f3F7fqf9r8C7jz2utb559zmt9VNe+NFR96TJ8MgvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UtOs4fEduAS4EDmXl2s+x64BP
AT5vVrsvM28fVpMbnbWcfaK0fyfb/dX/MjpO6bEcTtJQj/1eAixZY/sXM3Nj8GHxpmVk0/Jl5J/DMBHqRNEGjfOa/JiIeiIhtEeG5n7TMDBv+LwPvBDYC+4AvDFoxIrZExGxEzB7i4JC7k9S1ocKfmfsz80hmvgzcBGxqWXdrZs5k5sxKVg3bp6SODRX+iFg77+lHgYe6aUfSpCxlqO8W4EPAWyNiD/AnwIciYiOQwG7g6jH2KGkMFg1/Zl65wOKbx9CLxuCYM05rrX/+rL9vrd/08/Wt9dXb7jrqnjQdvMNPKsrwS0UZfqkowy8VZfilogy/VJRf3f0G9/jVb2+tf2CRmy4/cd+HW+vrvb9r2fLILxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFOc7/Bvfy+hdH2v6XP2ufolvLl0d+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrKcf43uL96/9+OtP26767oqBNNG4/8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1TUouP8EbEe+CqwBkhga2Z+KSJWA98CTgd2A1dk5rPja1WDvPi7mwbWzjvu3xfZ2ls9qlrKkf8w8OnM3AB8APhkRGwArgV2ZeaZwK7muaRlYtHwZ+a+zLyvefw88CiwDrgM2N6sth24fFxNSureUX3mj4jTgfcAdwNrMnNfU3qKuY8FkpaJJYc/Ik4AbgU+lZnPza9lZjJ3PWCh7bZExGxEzB7i4EjNSurOksIfESuZC/7XM/M7zeL9EbG2qa8FDiy0bWZuzcyZzJxZySKzQkqamEXDHxEB3Aw8mpk3zivtBDY3jzcDO7pvT9K4LGWc51zg48CDEXF/s+w64Abg7yLiKuDHwBXjaVGL+Z/fW/ATFwCrov2P+LNPn9NaP2HHva31wXvWtFs0/Jn5QyAGlM/vth1Jk+IdflJRhl8qyvBLRRl+qSjDLxVl+KWi/P+cy8CKE09srX/m3NuHfu1vfPeDrfV3HL5r6NfWdPPILxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFOc6/DLx8sP3rzx75v7cPrH1k70zrtmd+7uHW+pHWqpYzj/xSUYZfKsrwS0UZfqkowy8VZfilogy/VJTj/MtALjLO/1jLUP6x/Lh1W8fx6/LILxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFLRr+iFgfEf8cEY9ExMMR8UfN8usjYm9E3N/8XDL+diV1ZSk3+RwGPp2Z90XEm4F7I+KOpvbFzPz8+NqTNC6Lhj8z9wH7msfPR8SjwLpxNyZpvI7qM39EnA68B7i7WXRNRDwQEdsi4qQB22yJiNmImD1E+22qkiZnyeGPiBOAW4FPZeZzwJeBdwIbmTsz+MJC22Xm1sycycyZlazqoGVJXVhS+CNiJXPB/3pmfgcgM/dn5pHMfBm4Cdg0vjYldW0pV/sDuBl4NDNvnLd87bzVPgo81H17ksZlKVf7zwU+DjwYEfc3y64DroyIjUACu4Grx9KhpLFYytX+HwKxQGn4SeEl9c47/KSiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0VFZk5uZxE/hVfNGf1W4OmJNXB0prW3ae0L7G1YXfZ2Wma+bSkrTjT8r9t5xGxmtswu359p7W1a+wJ7G1ZfvXnaLxVl+KWi+g7/1p7332Zae5vWvsDehtVLb71+5pfUn76P/JJ60kv4I+KiiHgsIp6IiGv76GGQiNgdEQ82Mw/P9tzLtog4EBEPzVu2OiLuiIjHm98LTpPWU29TMXNzy8zSvb530zbj9cRP+yNiBfBfwAXAHuAe4MrMfGSijQwQEbuBmczsfUw4Ij4I/AL4amae3Sz7M+CZzLyh+YfzpMz8zJT0dj3wi75nbm4mlFk7f2Zp4HLg9+nxvWvp6wp6eN/6OPJvAp7IzCcz8yXgm8BlPfQx9TLzTuCZ1yy+DNjePN7O3F+eiRvQ21TIzH2ZeV/z+HnglZm
le33vWvrqRR/hXwf8ZN7zPUzXlN8J/CAi7o2ILX03s4A1zbTpAE8Ba/psZgGLztw8Sa+ZWXpq3rthZrzumhf8Xu+8zPwt4GLgk83p7VTKuc9s0zRcs6SZmydlgZmlf6XP927YGa+71kf49wLr5z0/tVk2FTJzb/P7AHAb0zf78P5XJkltfh/ouZ9fmaaZmxeaWZopeO+macbrPsJ/D3BmRJwREccCHwN29tDH60TE8c2FGCLieOBCpm/24Z3A5ubxZmBHj728yrTM3DxoZml6fu+mbsbrzJz4D3AJc1f8/xv44z56GNDXO4D/bH4e7rs34BbmTgMPMXdt5CrgZGAX8DjwT8DqKerta8CDwAPMBW1tT72dx9wp/QPA/c3PJX2/dy199fK+eYefVJQX/KSiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFfX/98XKu6yssugAAAAASUVORK5CYII=\n", 272 | "text/plain": [ 273 | "
" 274 | ] 275 | }, 276 | "metadata": { 277 | "needs_background": "light" 278 | }, 279 | "output_type": "display_data" 280 | }, 281 | { 282 | "data": { 283 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAADKCAYAAACFWKrDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAACqNJREFUeJzt3V9snXUdx/HPZ6cd6zocoBh0XexMCGYSdaZBcNGYIcn4k3ELCRj/JPNCdFMSHN55Z6IheEGIC0w0ENAAiQRRIGGEkOCkDFS2QTL5t+Kgw4FAGdvafb04p7TUlvM0nKe/87XvV7Jk7U5OP3lyzjvPnp72OCIEAMhjSekBAID5IdwAkAzhBoBkCDcAJEO4ASAZwg0AyRBuAEiGcANAMoQbAJLpqeVOl/dH78rT6rjryhrHin759xxfeaL0hK6x9LBLT9D4SeU3SFLvO+UfF2cM/rv0BEnSgUOnl56gaJReIB1/47AmxsYqPUBrCXfvytM0+O0f1XHXla18vvwTQ5JGN71bekLXGLi1t/QEvX5m+Q2S9PHdR0pP0LZbflt6giRp643fLT1B4ytKL5Be/NV1lW/LpRIASIZwA0AyhBsAkiHcAJAM4QaAZAg3ACRDuAEgGcINAMkQbgBIhnADQDKEGwCSqRRu2xttP2t7v+1tdY8CAMytbbhtNyTdIOlCSWslXW57bd3DAACzq3LGfY6k/RHxXEQck3SHpEvrnQUAmEuVcK+SdGDaxyOtzwEACujYNydtb7Y9bHt4/J2xTt0tAGCGKuF+WdLqaR8PtD73PhGxPSKGImKoZ3l/p/YBAGaoEu7HJZ1pe43tpZIuk3RPvbMAAHNp+9ZlETFu+ypJ90tqSNoREXtqXwYAmFWl95yMiPsk3VfzFgBABfzkJAAkQ7gBIBnCDQDJEG4ASIZwA0AyhBsAkiHcAJAM4QaAZAg3ACRDuAEgGcINAMkQbgBIptIvmZqvWCJN9EUdd13Zq+cW/fLv+efXbik9QWvu2Vx6giRp2aEjpSdoRX93nKscuGBZ6Qn62Te/UXqCJOndi8u2QpI+ve2x0hN0MKq/AU13PIoBAJURbgBIhnADQDKEGwCSIdwAkAzhBoBkCDcAJEO4ASAZwg0AyRBuAEiGcANAMoQbAJJpG27bO2yP2n56IQYBAD5YlTPuWyRtrHkHAKCituGOiEckHV6ALQCACrjGDQDJdCzctjfbHrY9PDFW/ReCAwDmp2PhjojtETEUEUON/v5O3S0AYAYulQBAMlVeDni7pMcknWV7xPZ36p8FAJhL2zcLjojLF2IIAKAaLpUAQDKEGwCSIdwAkAzhBoBkCDcAJEO4ASAZwg0AyRBuAEiGcANAMoQbAJIh3ACQjCOi43fa/9HV8dmLtnb8fufj1a+cKPr1J3nCpSeob6RReoIkaembpRdI0SWnKkvGO/+8m6+xgdILmpYcLf8cOfXZ8r14+v7r9fbhA5UORpc8jAEAVRFuAEiGcANAMoQbAJIh3ACQDOEGgGQINwAkQ7gBIBnCDQDJEG4ASIZwA0AyhBsAkmkbbturbe+0vdf2HttbFmIYAGB2PRVuMy7p6ojYbftkSU/YfjAi9ta8DQAwi7Zn3BFxMCJ2t/7+lqR9klbVPQwAMLt5XeO2PShpnaRddYwBALRXOdy2V0i6S9LWiPifX4lve7PtYdvD40fHOrkRADBNpXDb7lUz2rdFxN2z3SYitkfEUEQM9ZzU38mNAIBpqryqxJJulrQvIq6rfxIA4INUOeNeL+lKSRtsP9X6c1HNuwAAc2j7csCIeFRS+XfzBABI4
icnASAdwg0AyRBuAEiGcANAMoQbAJIh3ACQDOEGgGQINwAkQ7gBIBnCDQDJEG4ASIZwA0AyVd5zct5OOeMtbfrxzjruurI//HxD0a8/afS8idIT1DhaekHTR14cLz1Br3ypUXpCUxecMi0b5XfHTXr3tPLH4sQ8atwFDx8AwHwQbgBIhnADQDKEGwCSIdwAkAzhBoBkCDcAJEO4ASAZwg0AyRBuAEiGcANAMoQbAJJpG27by2z/1fbfbO+x/dOFGAYAmF2V30d1VNKGiHjbdq+kR23/KSL+UvM2AMAs2oY7IkLS260Pe1t/os5RAIC5VbrGbbth+ylJo5IejIhd9c4CAMylUrgjYiIiviBpQNI5ts+eeRvbm20P2x4ee/1Yp3cCAFrm9aqSiHhD0k5JG2f5t+0RMRQRQ/2nLu3UPgDADFVeVXK67VNaf++TdIGkZ+oeBgCYXZVXlXxC0m9sN9QM/e8j4t56ZwEA5lLlVSV/l7RuAbYAACrgJycBIBnCDQDJEG4ASIZwA0AyhBsAkiHcAJAM4QaAZAg3ACRDuAEgGcINAMkQbgBIpsovmZq3w6+drN/tOL+Ou65saU93vElP38FaDnFK/VePlJ6gngcGS0+QJPUcKb1AkksPaBrvK71AevNzx0tP0MQfqzeLM24ASIZwA0AyhBsAkiHcAJAM4QaAZAg3ACRDuAEgGcINAMkQbgBIhnADQDKEGwCSIdwAkEzlcNtu2H7S9r11DgIAfLD5nHFvkbSvriEAgGoqhdv2gKSLJd1U7xwAQDtVz7ivl3SNpBM1bgEAVNA23LYvkTQaEU+0ud1m28O2hyeOjHVsIADg/aqcca+XtMn2C5LukLTB9q0zbxQR2yNiKCKGGn39HZ4JAJjUNtwRcW1EDETEoKTLJD0UEVfUvgwAMCtexw0AyczrnWwj4mFJD9eyBABQCWfcAJAM4QaAZAg3ACRDuAEgGcINAMkQbgBIhnADQDKEGwCSIdwAkAzhBoBkCDcAJEO4ASAZR0Tn79Q+JOnFD3EXH5P0WofmZMexmMKxmMKxmPL/ciw+FRGnV7lhLeH+sGwPR8RQ6R3dgGMxhWMxhWMxZTEeCy6VAEAyhBsAkunWcG8vPaCLcCymcCymcCymLLpj0ZXXuAEAc+vWM24AwBy6Lty2N9p+1vZ+29tK7ynF9mrbO23vtb3H9pbSm0qz3bD9pO17S28pyfYptu+0/YztfbbPK72pFNs/bD0/nrZ9u+1lpTcthK4Kt+2GpBskXShpraTLba8tu6qYcUlXR8RaSedK+t4iPhaTtkjaV3pEF/ilpD9HxGckfV6L9JjYXiXpB5KGIuJsSQ1Jl5VdtTC6KtySzpG0PyKei4hjku6QdGnhTUVExMGI2N36+1tqPjlXlV1Vju0BSRdLuqn0lpJsr5T0VUk3S1JEHIuIN8quKqpHUp/tHknLJf2r8J4F0W3hXiXpwLSPR7SIYzXJ9qCkdZJ2lV1S1PWSrpF0ovSQwtZIOiTp163LRjfZ7i89qoSIeFnSLyS9JOmgpP9ExANlVy2Mbgs3ZrC9QtJdkrZGxJul95Rg+xJJoxHxROktXaBH0hcl3RgR6ySNSVqU3wuyfaqa/yNfI+mTkvptX1F21cLotnC/LGn1tI8HWp9blGz3qhnt2yLi7tJ7ClovaZPtF9S8fLbB9q1lJxUzImkkIib/93WnmiFfjL4u6fmIOBQRxyXdLenLhTctiG4L9+OSzrS9xvZSNb/RcE/hTUXYtprXMfdFxHWl95QUEddGxEBEDKr5mHgoIhbFmdVMEfGKpAO2z2p96nxJewtOKuklSefaXt56vpyvRfKN2p7SA6aLiHHbV0m6X83vEO+IiD2FZ5WyXtKVkv5h+6nW534SEfcV3ITu8H1Jt7VObp6T9K3Ce4qIiF2275S0W81XYT2pRfJTlPzkJAAk022XSgAAbRBuAEiGcANAMoQbAJIh3ACQDOEGgGQINwAkQ7gBIJn/ArvESddXV
vtQAAAAAElFTkSuQmCC\n", 284 | "text/plain": [ 285 | "
" 286 | ] 287 | }, 288 | "metadata": { 289 | "needs_background": "light" 290 | }, 291 | "output_type": "display_data" 292 | }, 293 | { 294 | "data": { 295 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAADKCAYAAACFWKrDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAC1xJREFUeJzt3X1oXfUdx/HPxyR9itXWJ6pttQWds9NNR/BhHRvo3DoV/VdBYWNQBnXqEEQ39seQsbENp4wiFHUb+MRQ/3Di5gQrTuacsdZpWwfFaa2rtj7UtOmwafLdH/fW1Jr0nph78rtf836B0KSXkw+H5O3pyU2uI0IAgDwOKz0AADAxhBsAkiHcAJAM4QaAZAg3ACRDuAEgGcINAMkQbgBIhnADQDLddRy0a25vdB89v45DVzbj/aIf/iN755VeIPUMuPQESdLwrNILpOiUS5VZI6UXKPZ1xudF9+7yO4Znl14gDb3/noYHByudjFrC3X30fC34yTV1HLqyJQ90xo/yv3ZZ+VIserz8F4YkvXdaV+kJGplRekHDyGm7S0/Q3nc74P+kkhb8rfzXyLtfLP81svW3v6n82PJnDAAwIYQbAJIh3ACQTC33uAFgOjuip0erzlimE+f2yhq9fx4Kbdk1qNUvbdTA0NCnPj7hBoA2W3XGMp255CT19PbKPiDcETpqcFCrJP183Yuf+vjcKgGANjtxbu8noi1JttXT26sT5/ZO6viEGwDazPInov3R39kfu33yaRBuAEiGcANAMoQbANosFBrvhdgjQqHJ/WQ34QaANtuya1BDg4OfiHdEaGhwUFt2DU7q+JWeDmh7haTbJHVJuiMifjGpjwoAn2GrX9qoVdIhn8c9GS3DbbtL0mpJF0raKuk52w9HxOQ+MgB8Rg0MDU3qedqtVLlVcrakzRHxakTslXS/pMtqWwQAOKQq4V4o6Y0D3t7afB8AoIC2fXPS9krb/bb7hyd54x0AML4q4X5T0uID3l7UfN/HRMSaiOiLiL6uSf44JwBgfFXC/ZykU2wvtT1D0uWSHq53FgBgPC2fVRIR+2xfLekxNZ4OeFdEbKh9GQBgTJWexx0Rj0p6tOYtAIAK+MlJAEiGcANAMoQbAJIh3ACQDOEGgGQINwAkQ7gBIBnCDQDJEG4ASIZwA0AyhBsAkiHcAJBMpV8yNVHHH/6BfvzVP9Vx6Mp+trczXl3tivOeKT1Bx319oPQESdK3esu/TOnNb15ceoIk6d6la0tP0Mn3fr/0BEnS28uj9YNq9qtv3lt6gm68+73Kj+WKGwCSIdwAkAzhBoBkCDcAJEO4ASAZwg0AyRBuAEiGcANAMoQbAJIh3ACQDOEGgGQINwAk0zLctu+yvd32y1MxCABwaFWuuH8vaUXNOwAAFbUMd0Q8Jan67xsEANSKe9wAkEzbwm17pe1+2/273h9q12EBAAdpW7gjYk1E9EVE39z5Pe06LADgINwqAYBkqjwd8D5Jz0g61fZW29+rfxYAYDwtXyw4Iq6YiiEAgGq4VQIAyRBuAEiGcANAMoQbAJIh3ACQDOEGgGQINwAkQ7gBIBnCDQDJEG4ASIZwA0Ayjoi2H7T36MXxhYuua/txJ+Ko/neKfvz9hhbMLT1BGik9oGHPghmlJ+iITR+UniBJ2nHO/NITdNyDG0tPkCSNDP6v9ATtW3566Qnq71+tgYGtrvJYrrgBIBnCDQDJEG4ASIZwA0AyhBsAkiHcAJAM4QaAZAg3ACRDuAEgGcINAMkQbgBIhnADQDItw217se21tjfa3mD72qkYBgAYW3eFx+yTdH1Er
LM9V9Lzth+PiM741WIAMM20vOKOiG0Rsa75512SNklaWPcwAMDYJnSP2/YSSWdJeraOMQCA1iqH2/bhkh6UdF1EDIzx9ytt99vu3/fhYDs3AgAOUCnctnvUiPY9EfHQWI+JiDUR0RcRfd0ze9u5EQBwgCrPKrGkOyVtiohb6p8EADiUKlfcyyVdJel82+ub/11U8y4AwDhaPh0wIp6WVOkFLAEA9eMnJwEgGcINAMkQbgBIhnADQDKEGwCSIdwAkAzhBoBkCDcAJEO4ASAZwg0AyRBuAEiGcANAMlVec3LCDjtmSLO/s62OQ1f2+okd8upqfR+UXqA9Ozrk96P3DJdeoOGZ80pPkCR5pPQCafPtS0pPkCTNfKH85+fuk4dKT9CHr1Z/LFfcAJAM4QaAZAg3ACRDuAEgGcINAMkQbgBIhnADQDKEGwCSIdwAkAzhBoBkCDcAJEO4ASCZluG2Pcv2P22/aHuD7Z9OxTAAwNiq/HbADyWdHxG7bfdIetr2nyPiHzVvAwCMoWW4IyIk7W6+2dP8L+ocBQAYX6V73La7bK+XtF3S4xHxbL2zAADjqRTuiBiOiDMlLZJ0tu3TD36M7ZW2+233D+3c0+6dAICmCT2rJCJ2SloracUYf7cmIvoioq9n3px27QMAHKTKs0qOtT2v+efZki6U9ErdwwAAY6vyrJLjJf3Bdpcaof9jRDxS7ywAwHiqPKvkX5LOmoItAIAK+MlJAEiGcANAMoQbAJIh3ACQDOEGgGQINwAkQ7gBIBnCDQDJEG4ASIZwA0AyhBsAkqnyS6YmbGhXj95+amEdh65sZFbRD/+RmU8fWXqCemeWXtCwZ/Fw6Qna+TmXniBJOu6ct0pP0M71C0pPkCSd8Mu/l56g124+r/QEeaj65yZX3ACQDOEGgGQINwAkQ7gBIBnCDQDJEG4ASIZwA0AyhBsAkiHcAJAM4QaAZAg3ACRDuAEgmcrhtt1l+wXbj9Q5CABwaBO54r5W0qa6hgAAqqkUbtuLJF0s6Y565wAAWql6xX2rpBskjdS4BQBQQctw275E0vaIeL7F41ba7rfdP7xnsG0DAQAfV+WKe7mkS22/Jul+SefbvvvgB0XEmojoi4i+rjm9bZ4JANivZbgj4qaIWBQRSyRdLumJiLiy9mUAgDHxPG4ASGZCLxYcEU9KerKWJQCASrjiBoBkCDcAJEO4ASAZwg0AyRBuAEiGcANAMoQbAJIh3ACQDOEGgGQINwAkQ7gBIBnCDQDJOCLaf1B7h6TXJ3GIYyS906Y52XEuRnEuRnEuRn1WzsVJEXFslQfWEu7Jst0fEX2ld3QCzsUozsUozsWo6XguuFUCAMkQbgBIplPDvab0gA7CuRjFuRjFuRg17c5FR97jBgCMr1OvuAEA4+i4cNteYfvftjfbvrH0nlJsL7a91vZG2xtsX1t6U2m2u2y/YPuR0ltKsj3P9gO2X7G9yfZ5pTeVYvuHza+Pl23fZ3tW6U1ToaPCbbtL0mpJ35a0TNIVtpeVXVXMPknXR8QySedKWjWNz8V+10raVHpEB7hN0l8i4vOSvqRpek5sL5R0jaS+iDhdUpeky8uumhodFW5JZ0vaHBGvRsReSfdLuqzwpiIiYltErGv+eZcaX5wLy64qx/YiSRdLuqP0lpJsHynpa5LulKSI2BsRO8uuKqpb0mzb3ZLmSPpv4T1TotPCvVDSGwe8vVXTOFb72V4i6SxJz5ZdUtStkm6QNFJ6SGFLJe2Q9LvmbaM7bPeWHlVCRLwp6deStkjaJumDiPhr2VVTo9PCjYPYPlzSg5Kui4iB0ntKsH2JpO0R8XzpLR2gW9KXJd0eEWdJGpQ0Lb8XZHu+Gv8iXyrpBEm9tq8su2pqdFq435S0+IC3FzXfNy3Z7lEj2vdExEOl9xS0XNKltl9T4/bZ+bbvLjupmK2StkbE/n99PaBGyKejb0j6T0TsiIghSQ9J+krhTVOi08L9nKRTbC+1PUONbzQ8XHhTE
batxn3MTRFxS+k9JUXETRGxKCKWqPE58URETIsrq4NFxFuS3rB9avNdF0jaWHBSSVsknWt7TvPr5QJNk2/UdpcecKCI2Gf7akmPqfEd4rsiYkPhWaUsl3SVpJdsr2++70cR8WjBTegMP5B0T/Pi5lVJ3y28p4iIeNb2A5LWqfEsrBc0TX6Kkp+cBIBkOu1WCQCgBcINAMkQbgBIhnADQDKEGwCSIdwAkAzhBoBkCDcAJPN/cR2V3Wt9doUAAAAASUVORK5CYII=\n", 296 | "text/plain": [ 297 | "
" 298 | ] 299 | }, 300 | "metadata": { 301 | "needs_background": "light" 302 | }, 303 | "output_type": "display_data" 304 | } 305 | ], 306 | "source": [ 307 | "%matplotlib inline\n", 308 | "import matplotlib.pyplot as plt\n", 309 | "import numpy as np\n", 310 | "from matplotlib import colors\n", 311 | "from matplotlib.ticker import PercentFormatter\n", 312 | "import matplotlib.mlab as mlab\n", 313 | "plt.imshow(x_dist[2, :, :, 0, 0])\n", 314 | "plt.figure()\n", 315 | "plt.imshow(y_dist[0:5, :, 0])\n", 316 | "plt.figure()\n", 317 | "plt.imshow(y_dist[0:5, :, 1])\n", 318 | "plt.legend()\n", 319 | "print (np.argmax(y_dist[0:5, :, 0], 1))" 320 | ] 321 | } 322 | ], 323 | "metadata": { 324 | "kernelspec": { 325 | "display_name": "Python 3", 326 | "language": "python", 327 | "name": "python3" 328 | }, 329 | "language_info": { 330 | "codemirror_mode": { 331 | "name": "ipython", 332 | "version": 3 333 | }, 334 | "file_extension": ".py", 335 | "mimetype": "text/x-python", 336 | "name": "python", 337 | "nbconvert_exporter": "python", 338 | "pygments_lexer": "ipython3", 339 | "version": "3.6.7" 340 | } 341 | }, 342 | "nbformat": 4, 343 | "nbformat_minor": 2 344 | } 345 | --------------------------------------------------------------------------------