├── LICENSE ├── README.md ├── SET-MLP-Keras-Weights-Mask ├── cifar10_models_performance.pdf ├── dense_mlp_keras_cifar10.py ├── fixprob_mlp_keras_cifar10.py ├── plot_performance.py ├── results │ ├── dense_mlp_srelu_sgd_cifar10_acc.txt │ ├── fixprob_mlp_srelu_sgd_cifar10_acc.txt │ └── set_mlp_srelu_sgd_cifar10_acc.txt └── set_mlp_keras_cifar10.py ├── SET-MLP-Sparse-Python-Data-Structures ├── Results │ ├── mlp_fixprob.txt │ └── set_mlp.txt ├── data │ └── lung.mat ├── fixprob_mlp_sparse_data_structures.py ├── set_mlp_sparse_data_structures.py ├── sparseoperations.c ├── sparseoperations.cpython-35m-x86_64-linux-gnu.so ├── sparseoperations.html └── sparseoperations.pyx ├── SET-RBM-Sparse-Python-Data-Structures ├── Results │ ├── rbm_fixprob.txt │ └── set_rbm.txt ├── data │ └── COIL20.mat ├── fixprob_rbm_sparse_data_structures.py ├── set_rbm_sparse_data_structures.py ├── sparseoperations.c ├── sparseoperations.cpython-35m-x86_64-linux-gnu.so ├── sparseoperations.html └── sparseoperations.pyx ├── Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning ├── Pretrained_results │ ├── fashion_mnist_connections_evolution_per_input_pixel_rand0.gif │ ├── fc_mlp_2000_training_samples_rand0.txt │ ├── fc_mlp_2000_training_samples_rand1.txt │ ├── fc_mlp_2000_training_samples_rand2.txt │ ├── fc_mlp_2000_training_samples_rand3.txt │ ├── fc_mlp_2000_training_samples_rand4.txt │ ├── fixprob_mlp_2000_training_samples_e13_rand0.txt │ ├── fixprob_mlp_2000_training_samples_e13_rand1.txt │ ├── fixprob_mlp_2000_training_samples_e13_rand2.txt │ ├── fixprob_mlp_2000_training_samples_e13_rand3.txt │ ├── fixprob_mlp_2000_training_samples_e13_rand4.txt │ ├── mnist_learning_curves_samples2000.pdf │ ├── set_mlp_2000_training_samples_e13_rand0.txt │ ├── set_mlp_2000_training_samples_e13_rand0_input_connections.npz │ ├── set_mlp_2000_training_samples_e13_rand1.txt │ ├── set_mlp_2000_training_samples_e13_rand1_input_connections.npz │ ├── set_mlp_2000_training_samples_e13_rand2.txt │ ├── set_mlp_2000_training_samples_e13_rand2_input_connections.npz │ ├── set_mlp_2000_training_samples_e13_rand3.txt │ ├── set_mlp_2000_training_samples_e13_rand3_input_connections.npz │ ├── set_mlp_2000_training_samples_e13_rand4.txt │ └── set_mlp_2000_training_samples_e13_rand4_input_connections.npz ├── Results │ └── mnist_learning_curves_samples2000.pdf ├── fc_mlp.py ├── fixprob_mlp.py ├── plot_input_layer_connectivity.py ├── plot_learning_curve.py ├── set_mlp.py ├── sparseoperations.c ├── sparseoperations.html ├── sparseoperations.pyx └── sparseoperations.so └── Tutorial-IJCAI-2019-Scalable-Deep-Learning ├── Pretrained_results ├── fashion_mnist_connections_evolution_per_input_pixel_rand0.gif ├── fc_mlp_2000_training_samples_rand0.txt ├── fc_mlp_2000_training_samples_rand1.txt ├── fc_mlp_2000_training_samples_rand2.txt ├── fc_mlp_2000_training_samples_rand3.txt ├── fc_mlp_2000_training_samples_rand4.txt ├── fixprob_mlp_2000_training_samples_e13_rand0.txt ├── fixprob_mlp_2000_training_samples_e13_rand1.txt ├── fixprob_mlp_2000_training_samples_e13_rand2.txt ├── fixprob_mlp_2000_training_samples_e13_rand3.txt ├── fixprob_mlp_2000_training_samples_e13_rand4.txt ├── mnist_learning_curves_samples2000.pdf ├── set_mlp_2000_training_samples_e13_rand0.txt ├── set_mlp_2000_training_samples_e13_rand0_input_connections.npz ├── set_mlp_2000_training_samples_e13_rand1.txt ├── set_mlp_2000_training_samples_e13_rand1_input_connections.npz ├── set_mlp_2000_training_samples_e13_rand2.txt ├── set_mlp_2000_training_samples_e13_rand2_input_connections.npz ├── 
set_mlp_2000_training_samples_e13_rand3.txt ├── set_mlp_2000_training_samples_e13_rand3_input_connections.npz ├── set_mlp_2000_training_samples_e13_rand4.txt └── set_mlp_2000_training_samples_e13_rand4_input_connections.npz ├── Results └── mnist_learning_curves_samples2000.pdf ├── data └── fashion_mnist.npz ├── fc_mlp.py ├── fixprob_mlp.py ├── plot_input_layer_connectivity.py ├── plot_learning_curve.py ├── set_mlp.py ├── sparseoperations.c ├── sparseoperations.cpython-35m-x86_64-linux-gnu.so ├── sparseoperations.html └── sparseoperations.pyx /LICENSE: -------------------------------------------------------------------------------- 
1 | MIT License 
2 | 
3 | Copyright (c) 2018 dcmocanu 
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy 
6 | of this software and associated documentation files (the "Software"), to deal 
7 | in the Software without restriction, including without limitation the rights 
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
9 | copies of the Software, and to permit persons to whom the Software is 
10 | furnished to do so, subject to the following conditions: 
11 | 
12 | The above copyright notice and this permission notice shall be included in all 
13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 
21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 
1 | # sparse-evolutionary-artificial-neural-networks 
2 | * Proof of concept implementations of various sparse artificial neural network models with adaptive sparse connectivity trained with the Sparse Evolutionary Training (SET) algorithm - https://arxiv.org/abs/1707.04780, 15 July 2017 
3 | * **SET** was the first algorithm which demonstrated that **sparse neural networks** can be trained from scratch to **outperform dense neural networks** within the framework of gradient descent, and it introduced the idea of optimizing the sparse connections between neurons together with the weights during training. 
4 | * In short, **SET** laid the groundwork for what is today known as **sparse training** with **dynamic sparsity** (also referred to in some papers as dynamic sparse training, pruning and growth strategies, and so on). 
5 | 
6 | * The following implementations are distributed in the hope that they may be useful, but without any warranties; their use is entirely at the user's own risk. 
7 | 
8 | ###### Implementation 1 - using binary masks - SET-MLP with Keras and Tensorflow (SET-MLP-Keras-Weights-Mask) 
9 | 
10 | * Proof of concept implementation of Sparse Evolutionary Training (SET) for a Multi Layer Perceptron (MLP) on CIFAR10 using Keras and a mask over weights. 
11 | * This implementation can be used to test SET in varying conditions, using the versatility of the Keras framework, e.g. various optimizers, activation layers, TensorFlow. 
12 | * It can also be easily adapted for Convolutional Neural Networks or other models which have dense layers. 
13 | * Variants of this implementation have been used to perform the experiments from Reference 1 with MLP and CNN. 
14 | * However, due to the fact that the weights are stored in the standard Keras format (dense matrices), this implementation cannot scale properly. 
15 | * If you would like to build a SET-MLP with over 100,000 neurons, please use Implementation 2. 
16 | 
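For readers who want the gist of Implementation 1 without opening the full files, below is a condensed sketch of the "mask over weights" mechanism, assembled from the `MaskWeights` constraint and `createWeightsMask` function defined in `SET-MLP-Keras-Weights-Mask/set_mlp_keras_cifar10.py`. The layer size, epsilon value, and ReLU activation here are example choices only; the repository itself uses three hidden layers with SReLU activations.

```python
import numpy as np
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Flatten

class MaskWeights(object):
    # Keras kernel constraint: after every weight update the kernel is
    # multiplied element-wise by a fixed binary mask, so connections outside
    # the mask stay at zero and the layer remains sparse during training.
    def __init__(self, mask):
        self.mask = K.cast(mask, K.floatx())
    def __call__(self, w):
        return w * self.mask
    def get_config(self):
        return {'mask': self.mask}

def create_weights_mask(epsilon, no_rows, no_cols):
    # Erdos-Renyi sparsity: keep each connection with probability
    # epsilon * (no_rows + no_cols) / (no_rows * no_cols), as in createWeightsMask()
    prob = 1 - (epsilon * (no_rows + no_cols)) / (no_rows * no_cols)
    return (np.random.rand(no_rows, no_cols) >= prob).astype(K.floatx())

mask = create_weights_mask(epsilon=20, no_rows=32 * 32 * 3, no_cols=4000)
model = Sequential()
model.add(Flatten(input_shape=(32, 32, 3)))
model.add(Dense(4000, kernel_constraint=MaskWeights(mask), activation='relu'))
```

During SET training the mask itself is also updated after every epoch; see the `weightsEvolution` and `rewireMask` methods in the same file.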
17 | ###### Implementation 2 - truly sparse implementation - SET-MLP using just sparse data structures from pure Python 3 (SET-MLP-Sparse-Python-Data-Structures) 
18 | 
19 | * An improved version of this Implementation can be found here https://github.com/SelimaC/Tutorial-SCADS-Summer-School-2020-Scalable-Deep-Learning 
20 | 
21 | * Proof of concept implementation of Sparse Evolutionary Training (SET) for a Multi Layer Perceptron (MLP) on the lung dataset using Python, SciPy sparse data structures, and (optionally) Cython. 
22 | * This implementation was developed just in the last stages of the reviewing process, and we briefly discuss it in the "Peer Review File", which can be downloaded from the Reference 1 website. 
23 | * This implementation can be used to create SET-MLPs with hundreds of thousands of neurons on a standard laptop. It was made starting from the vanilla fully connected MLP implementation of Ritchie Vink (https://www.ritchievink.com/) and we would like to acknowledge his work and thank him. Also, we would like to thank Thomas Hagebols for analyzing the performance of SciPy sparse matrix operations. We also thank Amarsagar Reddy Ramapuram Matavalam from Iowa State University (amar@iastate.edu), who provided us with a faster implementation of the "weightsEvolution" method after the initial release of this code. 
24 | * If you would like to try large SET-MLP models, below are the expected running times measured on my laptop (16 GB RAM) using the original implementation of the "weightsEvolution" method. I used exactly the model and the dataset from the file "set_mlp_sparse_data_structures.py" and just changed the number of hidden neurons per layer (a back-of-the-envelope sketch of the corresponding weight counts follows after this list): 
25 | - 3,000 neurons/hidden layer, 12,317 neurons in total 
26 | 0.3 minutes/epoch 
27 | - 30,000 neurons/hidden layer, 93,317 neurons in total 
28 | 3 minutes/epoch 
29 | - 300,000 neurons/hidden layer, 903,317 neurons in total 
30 | 49 minutes/epoch 
31 | - 600,000 neurons/hidden layer, 1,803,317 neurons in total 
32 | 112 minutes/epoch 
33 | * If you would like to try out SET-MLP with various activation functions, optimization methods, and so on (at the expense of scalability) please use Implementation 1. 
34 | 
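To give a feeling for why such models fit in memory, here is a back-of-the-envelope sketch of the number of stored weights per layer, assuming the same Erdos-Renyi connection probability `epsilon * (n + m) / (n * m)` used by `createWeightsMask` in the Keras implementation; the hidden-layer sizes mirror the timing list above.

```python
# Expected number of connections in an epsilon-sparsified n x m layer versus
# the fully connected count. With the Erdos-Renyi rule, the expected number of
# kept connections is roughly epsilon * (n + m).
epsilon = 20
for n, m in [(3000, 3000), (30000, 30000), (300000, 300000)]:
    sparse_weights = epsilon * (n + m)
    dense_weights = n * m
    print("{:>7} x {:<7}: ~{:>12,} sparse weights vs {:>15,} dense ({:.4f}% density)".format(
        n, m, sparse_weights, dense_weights, 100.0 * sparse_weights / dense_weights))
```

Even the 300,000-neuron hidden layers therefore hold on the order of 10^7 stored weights each, which is what keeps the SciPy-based implementation within reach of a standard laptop.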
45 | * In the "Pretrained_results" folder there is a nice animation "fashion_mnist_connections_evolution_per_input_pixel_rand0.gif" of the input layer connectivity evolution during training. 46 | 47 | ###### Implementation 5 - ECMLPKDD 2019 tutorial - light hands-on experience code (Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning) 48 | 49 | * Tutorial details - "Scalable Deep Learning: from theory to practice" 50 | https://sites.google.com/view/sdl-ecmlpkdd-2019-tutorial 51 | * The code is based on Implementation 2 of SET-MLP to which Dropout is added. 52 | * In the "Pretrained_results" folder there is a nice animation "fashion_mnist_connections_evolution_per_input_pixel_rand0.gif" of the input layer connectivity evolution during training. 53 | 54 | 55 | ###### References 56 | 57 | For an easy understanding of these implementations please read the following articles. Also, if you use parts of this code in your work, please cite the corresponding ones: 58 | 59 | 1. @article{Mocanu2018SET, 60 | author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 61 | journal = {Nature Communications}, 62 | title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 63 | year = {2018}, 64 | doi = {10.1038/s41467-018-04316-3}, 65 | url = {https://www.nature.com/articles/s41467-018-04316-3 }} 66 | 67 | 2. @article{Mocanu2016XBM, 68 | author={Mocanu, Decebal Constantin and Mocanu, Elena and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 69 | title={A topological insight into restricted Boltzmann machines}, 70 | journal={Machine Learning}, 71 | year={2016}, 72 | volume={104}, 73 | number={2}, 74 | pages={243--270}, 75 | doi={10.1007/s10994-016-5570-z}, 76 | url={https://doi.org/10.1007/s10994-016-5570-z }} 77 | 78 | 3. @phdthesis{Mocanu2017PhDthesis, 79 | title = {Network computations in artificial intelligence}, 80 | author = {Mocanu, Decebal Constantin}, 81 | year = {2017}, 82 | isbn = {978-90-386-4305-2}, 83 | publisher = {Eindhoven University of Technology}, 84 | url={https://pure.tue.nl/ws/files/69949254/20170629_CO_Mocanu.pdf } 85 | } 86 | 87 | 4. @article{Liu2019onemillion, 88 | author = {Liu, Shiwei and Mocanu, Decebal Constantin and Mocanu and Ramapuram Matavalam, Amarsagar Reddy and Pei, Yulong Pei and Pechenizkiy, Mykola}, 89 | journal = {arXiv:1901.09181}, 90 | title = {Sparse evolutionary Deep Learning with over one million artificial neurons on commodity hardware}, 91 | year = {2019}, 92 | url={https://arxiv.org/abs/1901.09181 } 93 | } 94 | 95 | SET shows that large sparse neural networks can be built if topological sparsity is created from the design phase, before training. There are many algorithmic and implementation improvements which can be made. If you find this work interesting, please share the links to this Github page and to Reference 1. For any question, suggestion, feedback please feel free to contact me by email. 96 | 97 | ###### Community 98 | 99 | Some time ago, I had a very pleasant unexpected surprise when I found out that Michael Klear released "Synapses". This library implements SET layers in PyTorch and as Michael says it is "truly sparse". 
For more details please read his article: 100 | 101 | https://towardsdatascience.com/the-sparse-future-of-deep-learning-bce05e8e094a 102 | 103 | And try out "Synapses" yourself: 104 | 105 | https://github.com/AlliedToasters/synapses 106 | 107 | Many things can be improved in "Synapses". If interested, please contact and help Michael in developing further the project. 108 | 109 | ###### Update 4 June 2020 110 | 111 | Our paper "Topological insights into sparse neural networks" https://arxiv.org/pdf/2006.14085.pdf has been accepted at ECMLPKDD 2020. It proposes Neural Network Sparse Topology Distance (NNSTD) to measure the distance between different sparse neural networks. The code is here https://github.com/Shiweiliuiiiiiii/Sparse_Topology_Distance. Also, it shows in a principled manner that sparse training easily unveils a plenitude of sparse sub-networks with very different topologies which outperform the dense networks. 112 | 113 | ###### Update 30 November 2020 114 | 115 | For an interesting quick read about sparse training, please have a look on this blog https://numenta.com/blog/2020/10/30/case-for-sparsity-in-neural-networks-part-2-dynamic-sparsity 116 | 117 | ###### Update 14 December 2020 118 | 119 | To see how sparse training can be used for feature selection please check our latest paper, titled "Quick and Robust Feature Selection: the Strength of Energy-efficient Sparse Training for Autoencoders", here: 120 | https://arxiv.org/abs/2012.00560 121 | 122 | and the corresponding truly sparse implementation here: 123 | https://github.com/zahraatashgahi/QuickSelection 124 | 125 | Many thanks, 126 | Decebal 127 | -------------------------------------------------------------------------------- /SET-MLP-Keras-Weights-Mask/cifar10_models_performance.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/SET-MLP-Keras-Weights-Mask/cifar10_models_performance.pdf -------------------------------------------------------------------------------- /SET-MLP-Keras-Weights-Mask/dense_mlp_keras_cifar10.py: -------------------------------------------------------------------------------- 1 | # Author: Decebal Constantin Mocanu et al.; 2 | # Proof of concept implementation of a standard dense Multi Layer Perceptron (MLP) on CIFAR10 using Keras and a mask over weights. 3 | # This implementation serves just as a comparison for the SET-MLP and MLP-FixProb models 4 | 5 | # This is a pre-alpha free software and was tested with Python 3.5.2, Keras 2.1.3, Keras_Contrib 0.0.2, Tensorflow 1.5.0, Numpy 1.14; 6 | # The code is distributed in the hope that it may be useful, but WITHOUT ANY WARRANTIES; The use of this software is entirely at the user's own risk; 7 | # For an easy understanding of the code functionality please read the following articles. 8 | 9 | # If you use parts of this code please cite the following articles: 10 | #@article{Mocanu2018SET, 11 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 12 | # journal = {Nature Communications}, 13 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 14 | # year = {2018}, 15 | # doi = {10.1038/s41467-018-04316-3} 16 | #} 17 | 18 | #@Article{Mocanu2016XBM, 19 | #author="Mocanu, Decebal Constantin and Mocanu, Elena and Nguyen, Phuong H. 
and Gibescu, Madeleine and Liotta, Antonio", 20 | #title="A topological insight into restricted Boltzmann machines", 21 | #journal="Machine Learning", 22 | #year="2016", 23 | #volume="104", 24 | #number="2", 25 | #pages="243--270", 26 | #doi="10.1007/s10994-016-5570-z", 27 | #url="https://doi.org/10.1007/s10994-016-5570-z" 28 | #} 29 | 30 | #@phdthesis{Mocanu2017PhDthesis, 31 | #title = "Network computations in artificial intelligence", 32 | #author = "D.C. Mocanu", 33 | #year = "2017", 34 | #isbn = "978-90-386-4305-2", 35 | #publisher = "Eindhoven University of Technology", 36 | #} 37 | 38 | from __future__ import division 39 | from __future__ import print_function 40 | from keras.preprocessing.image import ImageDataGenerator 41 | from keras.models import Sequential 42 | from keras.layers import Dense, Dropout, Activation, Flatten 43 | from keras import optimizers 44 | import numpy as np 45 | from keras import backend as K 46 | #Please note that in newer versions of keras_contrib you may encounter some import errors. You can find a fix for it on the Internet, or as an alternative you can try other activations functions. 47 | from keras_contrib.layers.advanced_activations import SReLU 48 | from keras.datasets import cifar10 49 | from keras.utils import np_utils 50 | 51 | 52 | class MLP_CIFAR10: 53 | def __init__(self): 54 | # set model parameters 55 | self.epsilon = 20 # control the sparsity level as discussed in the paper 56 | self.batch_size = 100 # batch size 57 | self.maxepoches = 1000 # number of epochs 58 | self.learning_rate = 0.01 # SGD learning rate 59 | self.num_classes = 10 # number of classes 60 | self.momentum=0.9 # SGD momentum 61 | 62 | # initialize layers weights 63 | self.w1 = None 64 | self.w2 = None 65 | self.w3 = None 66 | self.w4 = None 67 | 68 | # initialize weights for SReLu activation function 69 | self.wSRelu1 = None 70 | self.wSRelu2 = None 71 | self.wSRelu3 = None 72 | 73 | # create a MLP-FixProb model 74 | self.create_model() 75 | 76 | # train the MLP-FixProb model 77 | self.train() 78 | 79 | 80 | def create_model(self): 81 | 82 | # create a dense MLP model for CIFAR10 with 3 hidden layers 83 | self.model = Sequential() 84 | self.model.add(Flatten(input_shape=(32, 32, 3))) 85 | self.model.add(Dense(4000, name="dense_1", weights=self.w1)) 86 | self.model.add(SReLU(name="srelu1", weights=self.wSRelu1)) 87 | self.model.add(Dropout(0.3)) 88 | self.model.add(Dense(1000, name="dense_2", weights=self.w2)) 89 | self.model.add(SReLU(name="srelu2", weights=self.wSRelu2)) 90 | self.model.add(Dropout(0.3)) 91 | self.model.add(Dense(4000, name="dense_3", weights=self.w3)) 92 | self.model.add(SReLU(name="srelu3", weights=self.wSRelu3)) 93 | self.model.add(Dropout(0.3)) 94 | self.model.add(Dense(self.num_classes, name="dense_4", weights=self.w4)) 95 | self.model.add(Activation('softmax')) 96 | 97 | def train(self): 98 | 99 | # read CIFAR10 data 100 | [x_train,x_test,y_train,y_test]=self.read_data() 101 | 102 | #data augmentation 103 | datagen = ImageDataGenerator( 104 | featurewise_center=False, # set input mean to 0 over the dataset 105 | samplewise_center=False, # set each sample mean to 0 106 | featurewise_std_normalization=False, # divide inputs by std of the dataset 107 | samplewise_std_normalization=False, # divide each input by its std 108 | zca_whitening=False, # apply ZCA whitening 109 | rotation_range=10, # randomly rotate images in the range (degrees, 0 to 180) 110 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 111 | 
height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 112 | horizontal_flip=True, # randomly flip images 113 | vertical_flip=False) # randomly flip images 114 | datagen.fit(x_train) 115 | 116 | self.model.summary() 117 | 118 | sgd = optimizers.SGD(lr=self.learning_rate, momentum=self.momentum) 119 | self.model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy']) 120 | 121 | historytemp = self.model.fit_generator(datagen.flow(x_train, y_train, 122 | batch_size=self.batch_size), 123 | steps_per_epoch=x_train.shape[0] // self.batch_size, 124 | epochs=self.maxepoches, 125 | validation_data=(x_test, y_test), 126 | ) 127 | 128 | self.accuracies_per_epoch = historytemp.history['val_acc'] 129 | 130 | 131 | def read_data(self): 132 | 133 | #read CIFAR10 data 134 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 135 | y_train = np_utils.to_categorical(y_train, self.num_classes) 136 | y_test = np_utils.to_categorical(y_test, self.num_classes) 137 | x_train = x_train.astype('float32') 138 | x_test = x_test.astype('float32') 139 | 140 | #normalize data 141 | xTrainMean = np.mean(x_train, axis=0) 142 | xTtrainStd = np.std(x_train, axis=0) 143 | x_train = (x_train - xTrainMean) / xTtrainStd 144 | x_test = (x_test - xTrainMean) / xTtrainStd 145 | 146 | return [x_train, x_test, y_train, y_test] 147 | 148 | if __name__ == '__main__': 149 | 150 | # create and run a dense MLP model on CIFAR10 151 | model=MLP_CIFAR10() 152 | 153 | # save accuracies over for all training epochs 154 | # in "results" folder you can find the output of running this file 155 | np.savetxt("results/dense_mlp_srelu_sgd_cifar10_acc.txt", np.asarray(model.accuracies_per_epoch)) 156 | 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /SET-MLP-Keras-Weights-Mask/fixprob_mlp_keras_cifar10.py: -------------------------------------------------------------------------------- 1 | # Author: Decebal Constantin Mocanu et al.; 2 | # Proof of concept implementation of a Multi Layer Perceptron (MLP) with a fix sparsity pattern (FixProb) on CIFAR10 using Keras and a mask over weights. 3 | # This implementation can be used to test the model in varying conditions, using the Keras framework versatility, e.g. various optimizers, activation layers, tensorflow 4 | # Also it can be easily adapted for Convolutional Neural Networks or other models which have dense layers 5 | # However, due the fact that the weights are stored in the standard Keras format (dense matrices), this implementation can not scale properly. 6 | # If you would like to build and MLP-FixProb with over 100000 neurons, please use the pure Python implementation from the folder "SET-MLP-Sparse-Python-Data-Structures" 7 | 8 | # This is a pre-alpha free software and was tested with Python 3.5.2, Keras 2.1.3, Keras_Contrib 0.0.2, Tensorflow 1.5.0, Numpy 1.14; 9 | # The code is distributed in the hope that it may be useful, but WITHOUT ANY WARRANTIES; The use of this software is entirely at the user's own risk; 10 | # For an easy understanding of the code functionality please read the following articles. 11 | 12 | # If you use parts of this code please cite the following articles: 13 | #@article{Mocanu2018SET, 14 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. 
and Gibescu, Madeleine and Liotta, Antonio}, 15 | # journal = {Nature Communications}, 16 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 17 | # year = {2018}, 18 | # doi = {10.1038/s41467-018-04316-3} 19 | #} 20 | 21 | #@Article{Mocanu2016XBM, 22 | #author="Mocanu, Decebal Constantin and Mocanu, Elena and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio", 23 | #title="A topological insight into restricted Boltzmann machines", 24 | #journal="Machine Learning", 25 | #year="2016", 26 | #volume="104", 27 | #number="2", 28 | #pages="243--270", 29 | #doi="10.1007/s10994-016-5570-z", 30 | #url="https://doi.org/10.1007/s10994-016-5570-z" 31 | #} 32 | 33 | #@phdthesis{Mocanu2017PhDthesis, 34 | #title = "Network computations in artificial intelligence", 35 | #author = "D.C. Mocanu", 36 | #year = "2017", 37 | #isbn = "978-90-386-4305-2", 38 | #publisher = "Eindhoven University of Technology", 39 | #} 40 | 41 | from __future__ import division 42 | from __future__ import print_function 43 | from keras.preprocessing.image import ImageDataGenerator 44 | from keras.models import Sequential 45 | from keras.layers import Dense, Dropout, Activation, Flatten 46 | from keras import optimizers 47 | import numpy as np 48 | from keras import backend as K 49 | #Please note that in newer versions of keras_contrib you may encounter some import errors. You can find a fix for it on the Internet, or as an alternative you can try other activations functions. 50 | from keras_contrib.layers.advanced_activations import SReLU 51 | from keras.datasets import cifar10 52 | from keras.utils import np_utils 53 | 54 | class Constraint(object): 55 | 56 | def __call__(self, w): 57 | return w 58 | 59 | def get_config(self): 60 | return {} 61 | 62 | class MaskWeights(Constraint): 63 | 64 | def __init__(self, mask): 65 | self.mask = mask 66 | self.mask = K.cast(self.mask, K.floatx()) 67 | 68 | def __call__(self, w): 69 | w *= self.mask 70 | return w 71 | 72 | def get_config(self): 73 | return {'mask': self.mask} 74 | 75 | 76 | def createWeightsMask(epsilon,noRows, noCols): 77 | # generate an Erdos Renyi sparse weights mask 78 | mask_weights = np.random.rand(noRows, noCols) 79 | prob = 1 - (epsilon * (noRows + noCols)) / (noRows * noCols) # normal tp have 8x connections 80 | mask_weights[mask_weights < prob] = 0 81 | mask_weights[mask_weights >= prob] = 1 82 | noParameters = np.sum(mask_weights) 83 | print ("Create Sparse Matrix: No parameters, NoRows, NoCols ",noParameters,noRows,noCols) 84 | return [noParameters,mask_weights] 85 | 86 | 87 | class MLP_FixProb_CIFAR10: 88 | def __init__(self): 89 | # set model parameters 90 | self.epsilon = 20 # control the sparsity level as discussed in the paper 91 | self.batch_size = 100 # batch size 92 | self.maxepoches = 1000 # number of epochs 93 | self.learning_rate = 0.01 # SGD learning rate 94 | self.num_classes = 10 # number of classes 95 | self.momentum=0.9 # SGD momentum 96 | 97 | # generate an Erdos Renyi sparse weights mask for each layer 98 | [self.noPar1, self.wm1] = createWeightsMask(self.epsilon,32 * 32 *3, 4000) 99 | [self.noPar2, self.wm2] = createWeightsMask(self.epsilon,4000, 1000) 100 | [self.noPar3, self.wm3] = createWeightsMask(self.epsilon,1000, 4000) 101 | 102 | # initialize layers weights 103 | self.w1 = None 104 | self.w2 = None 105 | self.w3 = None 106 | self.w4 = None 107 | 108 | # initialize weights for SReLu activation function 109 | self.wSRelu1 = None 110 | self.wSRelu2 = None 
111 | self.wSRelu3 = None 112 | 113 | # create a MLP-FixProb model 114 | self.create_model() 115 | 116 | # train the SMLP-FixProb model 117 | self.train() 118 | 119 | 120 | def create_model(self): 121 | 122 | # create a MLP-FixProb model for CIFAR10 with 3 hidden layers 123 | self.model = Sequential() 124 | self.model.add(Flatten(input_shape=(32, 32, 3))) 125 | self.model.add(Dense(4000, name="sparse_1",kernel_constraint=MaskWeights(self.wm1),weights=self.w1)) 126 | self.model.add(SReLU(name="srelu1",weights=self.wSRelu1)) 127 | self.model.add(Dropout(0.3)) 128 | self.model.add(Dense(1000, name="sparse_2",kernel_constraint=MaskWeights(self.wm2),weights=self.w2)) 129 | self.model.add(SReLU(name="srelu2",weights=self.wSRelu2)) 130 | self.model.add(Dropout(0.3)) 131 | self.model.add(Dense(4000, name="sparse_3",kernel_constraint=MaskWeights(self.wm3),weights=self.w3)) 132 | self.model.add(SReLU(name="srelu3",weights=self.wSRelu3)) 133 | self.model.add(Dropout(0.3)) 134 | self.model.add(Dense(self.num_classes, name="dense_4",weights=self.w4)) #please note that there is no need for a sparse output layer as the number of classes is much smaller than the number of input hidden neurons 135 | self.model.add(Activation('softmax')) 136 | 137 | def train(self): 138 | 139 | # read CIFAR10 data 140 | [x_train,x_test,y_train,y_test]=self.read_data() 141 | 142 | #data augmentation 143 | datagen = ImageDataGenerator( 144 | featurewise_center=False, # set input mean to 0 over the dataset 145 | samplewise_center=False, # set each sample mean to 0 146 | featurewise_std_normalization=False, # divide inputs by std of the dataset 147 | samplewise_std_normalization=False, # divide each input by its std 148 | zca_whitening=False, # apply ZCA whitening 149 | rotation_range=10, # randomly rotate images in the range (degrees, 0 to 180) 150 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 151 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 152 | horizontal_flip=True, # randomly flip images 153 | vertical_flip=False) # randomly flip images 154 | datagen.fit(x_train) 155 | 156 | self.model.summary() 157 | 158 | sgd = optimizers.SGD(lr=self.learning_rate, momentum=self.momentum) 159 | self.model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy']) 160 | 161 | historytemp = self.model.fit_generator(datagen.flow(x_train, y_train, 162 | batch_size=self.batch_size), 163 | steps_per_epoch=x_train.shape[0]//self.batch_size, 164 | epochs=self.maxepoches, 165 | validation_data=(x_test, y_test), 166 | ) 167 | 168 | self.accuracies_per_epoch=historytemp.history['val_acc'] 169 | 170 | def read_data(self): 171 | 172 | #read CIFAR10 data 173 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 174 | y_train = np_utils.to_categorical(y_train, self.num_classes) 175 | y_test = np_utils.to_categorical(y_test, self.num_classes) 176 | x_train = x_train.astype('float32') 177 | x_test = x_test.astype('float32') 178 | 179 | #normalize data 180 | xTrainMean = np.mean(x_train, axis=0) 181 | xTtrainStd = np.std(x_train, axis=0) 182 | x_train = (x_train - xTrainMean) / xTtrainStd 183 | x_test = (x_test - xTrainMean) / xTtrainStd 184 | 185 | return [x_train, x_test, y_train, y_test] 186 | 187 | if __name__ == '__main__': 188 | 189 | # create and run a MLP-FixProb model on CIFAR10 190 | model=MLP_FixProb_CIFAR10() 191 | 192 | # save accuracies over for all training epochs 193 | # in "results" folder you can find the output of running this 
file 194 | np.savetxt("results/fixprob_mlp_srelu_sgd_cifar10_acc.txt", np.asarray(model.accuracies_per_epoch)) 195 | 196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /SET-MLP-Keras-Weights-Mask/plot_performance.py: -------------------------------------------------------------------------------- 1 | # Author: Decebal Constantin Mocanu et al.; 2 | # Plot performance of all three models on CIFAR10 3 | 4 | # This is a pre-alpha free software and was tested with Python 3.5.2, Keras 2.1.3, Keras_Contrib 0.0.2, Tensorflow 1.5.0, Numpy 1.14; 5 | # The code is distributed in the hope that it may be useful, but WITHOUT ANY WARRANTIES; The use of this software is entirely at the user's own risk; 6 | # For an easy understanding of the code functionality please read the following articles. 7 | 8 | # If you use parts of this code please cite the following articles: 9 | #@article{Mocanu2018SET, 10 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 11 | # journal = {Nature Communications}, 12 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 13 | # year = {2018}, 14 | # doi = {10.1038/s41467-018-04316-3} 15 | #} 16 | 17 | #@Article{Mocanu2016XBM, 18 | #author="Mocanu, Decebal Constantin and Mocanu, Elena and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio", 19 | #title="A topological insight into restricted Boltzmann machines", 20 | #journal="Machine Learning", 21 | #year="2016", 22 | #volume="104", 23 | #number="2", 24 | #pages="243--270", 25 | #doi="10.1007/s10994-016-5570-z", 26 | #url="https://doi.org/10.1007/s10994-016-5570-z" 27 | #} 28 | 29 | #@phdthesis{Mocanu2017PhDthesis, 30 | #title = "Network computations in artificial intelligence", 31 | #author = "D.C. Mocanu", 32 | #year = "2017", 33 | #isbn = "978-90-386-4305-2", 34 | #publisher = "Eindhoven University of Technology", 35 | #} 36 | 37 | import matplotlib.pyplot as plt 38 | import numpy as np 39 | 40 | 41 | ev=np.loadtxt("results/set_mlp_srelu_sgd_cifar10_acc.txt") 42 | fix=np.loadtxt("results/fixprob_mlp_srelu_sgd_cifar10_acc.txt") 43 | dense=np.loadtxt("results/dense_mlp_srelu_sgd_cifar10_acc.txt") 44 | 45 | plt.xlabel("Epochs[#]") 46 | plt.ylabel("CIFAR10\nAccuracy [%]") 47 | 48 | 49 | plt.plot(dense*100,'b',label="MLP") 50 | plt.plot(fix*100,'y',label="MLP$_{FixProb}$") 51 | plt.plot(ev*100,'r',label="SET-MLP") 52 | 53 | plt.legend(loc=4) 54 | plt.grid(True) 55 | plt.tight_layout() 56 | plt.savefig("cifar10_models_performance.pdf") 57 | plt.close() -------------------------------------------------------------------------------- /SET-MLP-Keras-Weights-Mask/set_mlp_keras_cifar10.py: -------------------------------------------------------------------------------- 1 | # Author: Decebal Constantin Mocanu et al.; 2 | # Proof of concept implementation of Sparse Evolutionary Training (SET) of Multi Layer Perceptron (MLP) on CIFAR10 using Keras and a mask over weights. 3 | # This implementation can be used to test SET in varying conditions, using the Keras framework versatility, e.g. various optimizers, activation layers, tensorflow 4 | # Also it can be easily adapted for Convolutional Neural Networks or other models which have dense layers 5 | # However, due the fact that the weights are stored in the standard Keras format (dense matrices), this implementation can not scale properly. 
6 | # If you would like to build and SET-MLP with over 100000 neurons, please use the pure Python implementation from the folder "SET-MLP-Sparse-Python-Data-Structures" 7 | 8 | # This is a pre-alpha free software and was tested with Python 3.5.2, Keras 2.1.3, Keras_Contrib 0.0.2, Tensorflow 1.5.0, Numpy 1.14; 9 | # The code is distributed in the hope that it may be useful, but WITHOUT ANY WARRANTIES; The use of this software is entirely at the user's own risk; 10 | # For an easy understanding of the code functionality please read the following articles. 11 | 12 | # If you use parts of this code please cite the following articles: 13 | #@article{Mocanu2018SET, 14 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 15 | # journal = {Nature Communications}, 16 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 17 | # year = {2018}, 18 | # doi = {10.1038/s41467-018-04316-3} 19 | #} 20 | 21 | #@Article{Mocanu2016XBM, 22 | #author="Mocanu, Decebal Constantin and Mocanu, Elena and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio", 23 | #title="A topological insight into restricted Boltzmann machines", 24 | #journal="Machine Learning", 25 | #year="2016", 26 | #volume="104", 27 | #number="2", 28 | #pages="243--270", 29 | #doi="10.1007/s10994-016-5570-z", 30 | #url="https://doi.org/10.1007/s10994-016-5570-z" 31 | #} 32 | 33 | #@phdthesis{Mocanu2017PhDthesis, 34 | #title = "Network computations in artificial intelligence", 35 | #author = "D.C. Mocanu", 36 | #year = "2017", 37 | #isbn = "978-90-386-4305-2", 38 | #publisher = "Eindhoven University of Technology", 39 | #} 40 | 41 | from __future__ import division 42 | from __future__ import print_function 43 | from keras.preprocessing.image import ImageDataGenerator 44 | from keras.models import Sequential 45 | from keras.layers import Dense, Dropout, Activation, Flatten 46 | from keras import optimizers 47 | import numpy as np 48 | from keras import backend as K 49 | #Please note that in newer versions of keras_contrib you may encounter some import errors. You can find a fix for it on the Internet, or as an alternative you can try other activations functions. 
50 | from keras_contrib.layers.advanced_activations import SReLU 51 | from keras.datasets import cifar10 52 | from keras.utils import np_utils 53 | 54 | class Constraint(object): 55 | 56 | def __call__(self, w): 57 | return w 58 | 59 | def get_config(self): 60 | return {} 61 | 62 | class MaskWeights(Constraint): 63 | 64 | def __init__(self, mask): 65 | self.mask = mask 66 | self.mask = K.cast(self.mask, K.floatx()) 67 | 68 | def __call__(self, w): 69 | w *= self.mask 70 | return w 71 | 72 | def get_config(self): 73 | return {'mask': self.mask} 74 | 75 | 76 | def find_first_pos(array, value): 77 | idx = (np.abs(array - value)).argmin() 78 | return idx 79 | 80 | 81 | def find_last_pos(array, value): 82 | idx = (np.abs(array - value))[::-1].argmin() 83 | return array.shape[0] - idx 84 | 85 | 86 | def createWeightsMask(epsilon,noRows, noCols): 87 | # generate an Erdos Renyi sparse weights mask 88 | mask_weights = np.random.rand(noRows, noCols) 89 | prob = 1 - (epsilon * (noRows + noCols)) / (noRows * noCols) # normal tp have 8x connections 90 | mask_weights[mask_weights < prob] = 0 91 | mask_weights[mask_weights >= prob] = 1 92 | noParameters = np.sum(mask_weights) 93 | print ("Create Sparse Matrix: No parameters, NoRows, NoCols ",noParameters,noRows,noCols) 94 | return [noParameters,mask_weights] 95 | 96 | 97 | class SET_MLP_CIFAR10: 98 | def __init__(self): 99 | # set model parameters 100 | self.epsilon = 20 # control the sparsity level as discussed in the paper 101 | self.zeta = 0.3 # the fraction of the weights removed 102 | self.batch_size = 100 # batch size 103 | self.maxepoches = 1000 # number of epochs 104 | self.learning_rate = 0.01 # SGD learning rate 105 | self.num_classes = 10 # number of classes 106 | self.momentum=0.9 # SGD momentum 107 | 108 | # generate an Erdos Renyi sparse weights mask for each layer 109 | [self.noPar1, self.wm1] = createWeightsMask(self.epsilon,32 * 32 *3, 4000) 110 | [self.noPar2, self.wm2] = createWeightsMask(self.epsilon,4000, 1000) 111 | [self.noPar3, self.wm3] = createWeightsMask(self.epsilon,1000, 4000) 112 | 113 | # initialize layers weights 114 | self.w1 = None 115 | self.w2 = None 116 | self.w3 = None 117 | self.w4 = None 118 | 119 | # initialize weights for SReLu activation function 120 | self.wSRelu1 = None 121 | self.wSRelu2 = None 122 | self.wSRelu3 = None 123 | 124 | # create a SET-MLP model 125 | self.create_model() 126 | 127 | # train the SET-MLP model 128 | self.train() 129 | 130 | 131 | def create_model(self): 132 | 133 | # create a SET-MLP model for CIFAR10 with 3 hidden layers 134 | self.model = Sequential() 135 | self.model.add(Flatten(input_shape=(32, 32, 3))) 136 | self.model.add(Dense(4000, name="sparse_1",kernel_constraint=MaskWeights(self.wm1),weights=self.w1)) 137 | self.model.add(SReLU(name="srelu1",weights=self.wSRelu1)) 138 | self.model.add(Dropout(0.3)) 139 | self.model.add(Dense(1000, name="sparse_2",kernel_constraint=MaskWeights(self.wm2),weights=self.w2)) 140 | self.model.add(SReLU(name="srelu2",weights=self.wSRelu2)) 141 | self.model.add(Dropout(0.3)) 142 | self.model.add(Dense(4000, name="sparse_3",kernel_constraint=MaskWeights(self.wm3),weights=self.w3)) 143 | self.model.add(SReLU(name="srelu3",weights=self.wSRelu3)) 144 | self.model.add(Dropout(0.3)) 145 | self.model.add(Dense(self.num_classes, name="dense_4",weights=self.w4)) #please note that there is no need for a sparse output layer as the number of classes is much smaller than the number of input hidden neurons 146 | self.model.add(Activation('softmax')) 147 | 148 | 
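    # The two methods below implement the SET topology update that train()
    # applies after every epoch:
    #   * rewireMask(weights, noWeights) keeps the connections whose weights
    #     survived pruning: it drops the fraction zeta of positive weights and
    #     the fraction zeta of negative weights that are closest to zero, then
    #     adds the same number of new connections at random free positions so
    #     the total number of connections per layer stays constant.
    #   * weightsEvolution() applies rewireMask to every sparse layer, keeps
    #     the trained values of the surviving weights (by multiplying them with
    #     the "core" mask), and leaves the newly added connections at zero;
    #     train() then rebuilds the Keras model with the new masks so the fresh
    #     connections can start learning in the next epoch.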
def rewireMask(self,weights, noWeights): 149 | # rewire weight matrix 150 | 151 | # remove zeta largest negative and smallest positive weights 152 | values = np.sort(weights.ravel()) 153 | firstZeroPos = find_first_pos(values, 0) 154 | lastZeroPos = find_last_pos(values, 0) 155 | largestNegative = values[int((1-self.zeta) * firstZeroPos)] 156 | smallestPositive = values[int(min(values.shape[0] - 1, lastZeroPos +self.zeta * (values.shape[0] - lastZeroPos)))] 157 | rewiredWeights = weights.copy(); 158 | rewiredWeights[rewiredWeights > smallestPositive] = 1; 159 | rewiredWeights[rewiredWeights < largestNegative] = 1; 160 | rewiredWeights[rewiredWeights != 1] = 0; 161 | weightMaskCore = rewiredWeights.copy() 162 | 163 | # add zeta random weights 164 | nrAdd = 0 165 | noRewires = noWeights - np.sum(rewiredWeights) 166 | while (nrAdd < noRewires): 167 | i = np.random.randint(0, rewiredWeights.shape[0]) 168 | j = np.random.randint(0, rewiredWeights.shape[1]) 169 | if (rewiredWeights[i, j] == 0): 170 | rewiredWeights[i, j] = 1 171 | nrAdd += 1 172 | 173 | return [rewiredWeights, weightMaskCore] 174 | 175 | def weightsEvolution(self): 176 | # this represents the core of the SET procedure. It removes the weights closest to zero in each layer and add new random weights 177 | self.w1 = self.model.get_layer("sparse_1").get_weights() 178 | self.w2 = self.model.get_layer("sparse_2").get_weights() 179 | self.w3 = self.model.get_layer("sparse_3").get_weights() 180 | self.w4 = self.model.get_layer("dense_4").get_weights() 181 | 182 | self.wSRelu1 = self.model.get_layer("srelu1").get_weights() 183 | self.wSRelu2 = self.model.get_layer("srelu2").get_weights() 184 | self.wSRelu3 = self.model.get_layer("srelu3").get_weights() 185 | 186 | [self.wm1, self.wm1Core] = self.rewireMask(self.w1[0], self.noPar1) 187 | [self.wm2, self.wm2Core] = self.rewireMask(self.w2[0], self.noPar2) 188 | [self.wm3, self.wm3Core] = self.rewireMask(self.w3[0], self.noPar3) 189 | 190 | self.w1[0] = self.w1[0] * self.wm1Core 191 | self.w2[0] = self.w2[0] * self.wm2Core 192 | self.w3[0] = self.w3[0] * self.wm3Core 193 | 194 | def train(self): 195 | 196 | # read CIFAR10 data 197 | [x_train,x_test,y_train,y_test]=self.read_data() 198 | 199 | #data augmentation 200 | datagen = ImageDataGenerator( 201 | featurewise_center=False, # set input mean to 0 over the dataset 202 | samplewise_center=False, # set each sample mean to 0 203 | featurewise_std_normalization=False, # divide inputs by std of the dataset 204 | samplewise_std_normalization=False, # divide each input by its std 205 | zca_whitening=False, # apply ZCA whitening 206 | rotation_range=10, # randomly rotate images in the range (degrees, 0 to 180) 207 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 208 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 209 | horizontal_flip=True, # randomly flip images 210 | vertical_flip=False) # randomly flip images 211 | datagen.fit(x_train) 212 | 213 | self.model.summary() 214 | 215 | # training process in a for loop 216 | self.accuracies_per_epoch=[] 217 | for epoch in range(0,self.maxepoches): 218 | 219 | sgd = optimizers.SGD(lr=self.learning_rate, momentum=self.momentum) 220 | self.model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy']) 221 | 222 | historytemp = self.model.fit_generator(datagen.flow(x_train, y_train, 223 | batch_size=self.batch_size), 224 | steps_per_epoch=x_train.shape[0]//self.batch_size, 225 | epochs=epoch, 226 | 
validation_data=(x_test, y_test), 227 | initial_epoch=epoch-1) 228 | 229 | self.accuracies_per_epoch.append(historytemp.history['val_acc'][0]) 230 | 231 | #ugly hack to avoid tensorflow memory increase for multiple fit_generator calls. Theano shall work more nicely this but it is outdated in general 232 | self.weightsEvolution() 233 | K.clear_session() 234 | self.create_model() 235 | 236 | self.accuracies_per_epoch=np.asarray(self.accuracies_per_epoch) 237 | 238 | def read_data(self): 239 | 240 | #read CIFAR10 data 241 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 242 | y_train = np_utils.to_categorical(y_train, self.num_classes) 243 | y_test = np_utils.to_categorical(y_test, self.num_classes) 244 | x_train = x_train.astype('float32') 245 | x_test = x_test.astype('float32') 246 | 247 | #normalize data 248 | xTrainMean = np.mean(x_train, axis=0) 249 | xTtrainStd = np.std(x_train, axis=0) 250 | x_train = (x_train - xTrainMean) / xTtrainStd 251 | x_test = (x_test - xTrainMean) / xTtrainStd 252 | 253 | return [x_train, x_test, y_train, y_test] 254 | 255 | if __name__ == '__main__': 256 | 257 | # create and run a SET-MLP model on CIFAR10 258 | model=SET_MLP_CIFAR10() 259 | 260 | # save accuracies over for all training epochs 261 | # in "results" folder you can find the output of running this file 262 | np.savetxt("results/set_mlp_srelu_sgd_cifar10_acc.txt", np.asarray(model.accuracies_per_epoch)) 263 | 264 | 265 | 266 | 267 | -------------------------------------------------------------------------------- /SET-MLP-Sparse-Python-Data-Structures/Results/mlp_fixprob.txt: -------------------------------------------------------------------------------- 1 | 9.370427778949437514e-02 7.647058823529411242e-01 6.485865253082562232e-02 2 | 2.990566237701250407e-02 7.941176470588234837e-01 5.289329321548029428e-02 3 | 1.808531381982177513e-02 8.235294117647058432e-01 4.273766584386964190e-02 4 | 1.208181273664610776e-02 8.529411764705882026e-01 3.645269391300137024e-02 5 | 7.590874231280928496e-03 8.235294117647058432e-01 4.121463282270809020e-02 6 | 5.455140866234472856e-03 8.382352941176470784e-01 3.523145821157960822e-02 7 | 4.544125684660249055e-03 8.676470588235294379e-01 3.144921917225548813e-02 8 | 3.706253296406840977e-03 8.382352941176470784e-01 3.462587244079876753e-02 9 | 3.509075129345544213e-03 8.676470588235294379e-01 3.047637533349903741e-02 10 | 2.978261005411392531e-03 8.676470588235294379e-01 3.007273080372085960e-02 11 | 2.453568559764474511e-03 8.676470588235294379e-01 3.139917504337062099e-02 12 | 2.216098289660652348e-03 8.529411764705882026e-01 3.127505485344873187e-02 13 | 2.201347571135006675e-03 8.676470588235294379e-01 3.018367884434295170e-02 14 | 2.118468529281089963e-03 8.529411764705882026e-01 3.045138192510846858e-02 15 | 1.997824800047246808e-03 8.529411764705882026e-01 3.085831359953564862e-02 16 | 1.974175888724010938e-03 8.676470588235294379e-01 2.851719447472244109e-02 17 | 1.485997547701601055e-03 8.529411764705882026e-01 3.163415360234013429e-02 18 | 1.963653303266168354e-03 8.676470588235294379e-01 3.073916426689785042e-02 19 | 1.819841236860323449e-03 8.529411764705882026e-01 2.927771977061319891e-02 20 | 1.806099403981731993e-03 8.676470588235294379e-01 2.719515498876038176e-02 21 | 1.833540840822662620e-03 8.676470588235294379e-01 2.832851180274456593e-02 22 | 1.775251281193437067e-03 8.676470588235294379e-01 2.815641821453016697e-02 23 | 1.666843275279470326e-03 8.676470588235294379e-01 2.847121883398552947e-02 24 | 1.839263666290472717e-03 
8.676470588235294379e-01 2.816629100364313742e-02 25 | 1.677458828817012795e-03 8.676470588235294379e-01 2.795626106719619267e-02 26 | 1.864737011101259877e-03 8.676470588235294379e-01 2.803509445802833747e-02 27 | 1.734451876501213975e-03 8.676470588235294379e-01 2.776945891080647225e-02 28 | 1.763688909048011580e-03 8.676470588235294379e-01 2.888314193532836444e-02 29 | 1.653968052123833696e-03 8.529411764705882026e-01 2.832135050067947565e-02 30 | 1.897565403330728553e-03 8.676470588235294379e-01 2.690315205429851669e-02 31 | 1.899610748040095288e-03 8.676470588235294379e-01 2.688748988560279995e-02 32 | 1.742424531678075494e-03 8.676470588235294379e-01 2.761049122629385261e-02 33 | 1.895714846136553784e-03 8.529411764705882026e-01 2.785481789111167797e-02 34 | 1.781696610275431249e-03 8.676470588235294379e-01 2.791015544385500730e-02 35 | 1.982917805092706751e-03 8.823529411764705621e-01 2.724333727759760057e-02 36 | 1.991215240495495219e-03 8.823529411764705621e-01 2.630092827584427223e-02 37 | 1.982623884199796911e-03 8.823529411764705621e-01 2.699187004197275280e-02 38 | 1.836642849280627036e-03 8.823529411764705621e-01 2.672287717008684296e-02 39 | 1.898465255281468929e-03 8.823529411764705621e-01 2.684087766267759939e-02 40 | 2.093757753044122628e-03 8.823529411764705621e-01 2.600914357923880815e-02 41 | 1.929086774963580282e-03 8.823529411764705621e-01 2.691276558611204600e-02 42 | 1.902195363950479534e-03 8.823529411764705621e-01 2.727373793948351458e-02 43 | 2.099997987671855908e-03 8.823529411764705621e-01 2.671448216522267702e-02 44 | 2.003036106286974587e-03 8.823529411764705621e-01 2.676901385255396615e-02 45 | 2.077517485952746438e-03 8.823529411764705621e-01 2.693702066329385289e-02 46 | 2.101702134485219399e-03 8.823529411764705621e-01 2.763213259726215315e-02 47 | 2.215889635089279743e-03 8.823529411764705621e-01 2.709806509408569167e-02 48 | 2.058834112313140830e-03 8.823529411764705621e-01 2.630215252313008664e-02 49 | 2.209116020698025009e-03 8.970588235294117974e-01 2.533157711485883173e-02 50 | 1.943961960379528633e-03 8.823529411764705621e-01 2.697910249747836192e-02 51 | 2.196111960818908099e-03 8.970588235294117974e-01 2.588961536457305007e-02 52 | 2.060002608913613555e-03 8.823529411764705621e-01 2.635094354828472832e-02 53 | 2.162622173836040076e-03 8.970588235294117974e-01 2.529314864109079297e-02 54 | 2.230255141760402026e-03 8.970588235294117974e-01 2.566746178884456511e-02 55 | 1.972866026606156717e-03 8.970588235294117974e-01 2.510261950434496819e-02 56 | 2.243091950738693510e-03 8.970588235294117974e-01 2.511644625537515566e-02 57 | 2.053648699608041570e-03 8.970588235294117974e-01 2.586011018809827519e-02 58 | 2.226460304946835129e-03 8.970588235294117974e-01 2.552319653855881013e-02 59 | 2.088769683117555880e-03 8.970588235294117974e-01 2.470761038575459284e-02 60 | 2.183215245385187932e-03 8.970588235294117974e-01 2.515369396258242823e-02 61 | 2.178972436411604711e-03 8.970588235294117974e-01 2.517890887141517597e-02 62 | 2.144626007047922313e-03 8.970588235294117974e-01 2.549653309468304790e-02 63 | 2.202168364129981383e-03 8.970588235294117974e-01 2.430632725414298353e-02 64 | 2.133370901114112900e-03 8.970588235294117974e-01 2.433987164302411910e-02 65 | 2.113177547526500204e-03 8.970588235294117974e-01 2.450181781475286799e-02 66 | 1.960059466262140115e-03 8.970588235294117974e-01 2.455734953723699016e-02 67 | 2.230847600181151951e-03 8.970588235294117974e-01 2.544483334037609193e-02 68 | 2.048851110134288812e-03 8.970588235294117974e-01 
2.392154010580202148e-02 69 | 2.037550254635224316e-03 8.970588235294117974e-01 2.419407273031155131e-02 70 | 2.013303370003655893e-03 8.970588235294117974e-01 2.417144147645908583e-02 71 | 1.938377368043572207e-03 8.970588235294117974e-01 2.386571672702962263e-02 72 | 1.968240566939691814e-03 8.970588235294117974e-01 2.394207933254784149e-02 73 | 1.936026584917161127e-03 8.970588235294117974e-01 2.382615907255532239e-02 74 | 1.799120545214456237e-03 8.970588235294117974e-01 2.341943615954473687e-02 75 | 1.756372606320146185e-03 8.970588235294117974e-01 2.388540311277941378e-02 76 | 1.875022716975046234e-03 8.970588235294117974e-01 2.290715569324327031e-02 77 | 1.737386454852573410e-03 8.970588235294117974e-01 2.362009228624023169e-02 78 | 1.802103151688178723e-03 8.970588235294117974e-01 2.334498475415364496e-02 79 | 1.749144058635809318e-03 8.970588235294117974e-01 2.303281065192464352e-02 80 | 1.780810903845691942e-03 8.970588235294117974e-01 2.298271689075906746e-02 81 | 1.683040373438603621e-03 8.970588235294117974e-01 2.270994765955456490e-02 82 | 1.586397456066852276e-03 8.970588235294117974e-01 2.314411499207923162e-02 83 | 1.579745088448012607e-03 8.970588235294117974e-01 2.295241198117768319e-02 84 | 1.577512900415723633e-03 8.970588235294117974e-01 2.263011155009354153e-02 85 | 1.571729775629211257e-03 8.970588235294117974e-01 2.270163014204882798e-02 86 | 1.491446124447282345e-03 8.970588235294117974e-01 2.174183641037686840e-02 87 | 1.446451903421914825e-03 8.970588235294117974e-01 2.191430279477164789e-02 88 | 1.434296946029943887e-03 8.970588235294117974e-01 2.194443405156356122e-02 89 | 1.306178749099410176e-03 8.970588235294117974e-01 2.175988027692596014e-02 90 | 1.387599646165148075e-03 8.970588235294117974e-01 2.272661409211653924e-02 91 | 1.400957221261988099e-03 8.970588235294117974e-01 2.152970656661614898e-02 92 | 1.201026342080276435e-03 8.970588235294117974e-01 2.221274911108393410e-02 93 | 1.321005810754082889e-03 8.970588235294117974e-01 2.139615167299690340e-02 94 | 1.116051573094814075e-03 9.117647058823529216e-01 2.133304691189665675e-02 95 | 1.199302825836159403e-03 8.970588235294117974e-01 2.145185563570089701e-02 96 | 1.216855018288525185e-03 9.117647058823529216e-01 2.167410137368236053e-02 97 | 1.045198225476677047e-03 8.970588235294117974e-01 2.149171982947635853e-02 98 | 1.038532606047024411e-03 9.117647058823529216e-01 2.115473224529234100e-02 99 | 1.005824583390936902e-03 9.117647058823529216e-01 2.106150016478744100e-02 100 | 9.900586422353589017e-04 9.117647058823529216e-01 2.129278542625050175e-02 101 | -------------------------------------------------------------------------------- /SET-MLP-Sparse-Python-Data-Structures/Results/set_mlp.txt: -------------------------------------------------------------------------------- 1 | 9.370427778949437514e-02 7.647058823529411242e-01 6.485865253082562232e-02 2 | 4.025222951461484427e-02 7.941176470588234837e-01 5.352891392219837063e-02 3 | 1.894106881007486323e-02 8.529411764705882026e-01 4.069932390754919110e-02 4 | 1.289793668748648141e-02 9.411764705882352811e-01 2.533928070997832488e-02 5 | 1.428705048005861518e-02 8.823529411764705621e-01 3.753370528217242486e-02 6 | 1.668647607543679093e-02 8.823529411764705621e-01 3.190498689833786566e-02 7 | 1.117074226100120415e-02 9.411764705882352811e-01 2.242267796345778427e-02 8 | 1.745924265642895268e-02 8.970588235294117974e-01 2.792405558041894179e-02 9 | 6.777336146358801071e-03 9.264705882352941568e-01 2.763505613901023200e-02 10 | 3.614138345065465597e-03 
9.705882352941176405e-01 1.332783881038931278e-02 11 | 1.117658265261379201e-02 9.411764705882352811e-01 1.778667006522142643e-02 12 | 5.718160021938606223e-03 9.264705882352941568e-01 2.840776291138898818e-02 13 | 3.385996294716384601e-03 8.970588235294117974e-01 2.920437516307306663e-02 14 | 8.207910999811903960e-03 8.970588235294117974e-01 3.118788709529601161e-02 15 | 4.087512401733529865e-03 9.117647058823529216e-01 3.406480263926536800e-02 16 | 8.569430440708940971e-03 8.529411764705882026e-01 4.441640592892157763e-02 17 | 8.757237359385390552e-03 9.558823529411765163e-01 2.842577234957531854e-02 18 | 6.313471432832127098e-03 9.558823529411765163e-01 2.586143841359001688e-02 19 | 3.708427422442906862e-03 9.411764705882352811e-01 1.957598454609688760e-02 20 | 1.107948784936932211e-02 8.676470588235294379e-01 3.255658454390753659e-02 21 | 4.873807890575290562e-03 8.970588235294117974e-01 3.522191602296322055e-02 22 | 6.819704632375966913e-03 9.264705882352941568e-01 2.545177450446305112e-02 23 | 1.466170834587134271e-02 8.970588235294117974e-01 3.387975669499439929e-02 24 | 7.114501257080478938e-03 8.823529411764705621e-01 3.500879444813957686e-02 25 | 4.005597185708207990e-03 8.823529411764705621e-01 4.116375011755805252e-02 26 | 2.752159313066930135e-03 9.411764705882352811e-01 2.726167302048681032e-02 27 | 4.108975109958147576e-03 9.117647058823529216e-01 2.824978605575174939e-02 28 | 1.781481306586311150e-03 9.117647058823529216e-01 2.575286255849389139e-02 29 | 4.953197436589701352e-03 9.558823529411765163e-01 2.504110921952216670e-02 30 | 6.836389358445915318e-03 8.823529411764705621e-01 3.160912276144569460e-02 31 | 8.176181025323839627e-03 9.411764705882352811e-01 1.941382467370602466e-02 32 | 2.701404035494791266e-03 8.823529411764705621e-01 3.123374213520260442e-02 33 | 7.446118225714009817e-03 9.264705882352941568e-01 2.481362777360949373e-02 34 | 6.657190188537839316e-03 9.264705882352941568e-01 2.706513956305285282e-02 35 | 9.932801165807284352e-03 9.411764705882352811e-01 2.474888600366515271e-02 36 | 7.632398049259369460e-03 9.411764705882352811e-01 2.566153379509360927e-02 37 | 5.559213928637979363e-03 9.411764705882352811e-01 2.464937961765557503e-02 38 | 1.803210959303191540e-03 9.411764705882352811e-01 3.319456226777481311e-02 39 | 1.531127989358390968e-03 8.529411764705882026e-01 4.256019534946732519e-02 40 | 5.179181069542515559e-03 8.676470588235294379e-01 3.597276056197020594e-02 41 | 2.501723665693573583e-03 8.970588235294117974e-01 2.605061553491564916e-02 42 | 5.833167695203025568e-03 9.117647058823529216e-01 2.429254077302044415e-02 43 | 1.140658042355418775e-03 8.970588235294117974e-01 3.636323727131872408e-02 44 | 4.589927564953341643e-03 8.970588235294117974e-01 3.754584261067868411e-02 45 | 6.644542410784343850e-03 8.970588235294117974e-01 3.714771390553940622e-02 46 | 2.247120804958482702e-03 9.411764705882352811e-01 2.414823101682950918e-02 47 | 5.874632472013632464e-03 9.264705882352941568e-01 2.177443525754345355e-02 48 | 1.988800466934117876e-03 9.411764705882352811e-01 2.433210148626839889e-02 49 | 3.713013831715492735e-03 8.529411764705882026e-01 4.279074204563937389e-02 50 | 2.101159822205048416e-03 9.558823529411765163e-01 2.447750487181118983e-02 51 | 1.530950066688507860e-02 9.705882352941176405e-01 2.543987357513725206e-02 52 | 8.233003310200289956e-03 8.823529411764705621e-01 3.569417868491518708e-02 53 | 2.899610156525488699e-03 9.558823529411765163e-01 2.437263485317332432e-02 54 | 5.032016779289911514e-03 9.411764705882352811e-01 
2.765746408813718471e-02 55 | 4.079818425086709568e-03 9.117647058823529216e-01 3.322763541917118230e-02 56 | 3.468789458565822367e-03 8.970588235294117974e-01 3.062679719996267910e-02 57 | 4.993386466903472314e-03 9.117647058823529216e-01 2.532387703637685450e-02 58 | 3.012547903132577806e-03 8.970588235294117974e-01 2.973278762954569346e-02 59 | 7.085354656060964540e-03 9.117647058823529216e-01 3.666642705923982282e-02 60 | 2.837844642227250937e-03 9.411764705882352811e-01 2.310982210074681575e-02 61 | 4.227909502071343882e-03 9.411764705882352811e-01 2.354007546152751817e-02 62 | 7.606891613003079027e-03 9.411764705882352811e-01 3.023642881955907954e-02 63 | 1.748809262059299215e-03 9.264705882352941568e-01 2.976375393726208088e-02 64 | 6.411169404880130707e-03 9.411764705882352811e-01 2.646012478071117432e-02 65 | 8.068710560824331190e-03 9.117647058823529216e-01 2.226224909003681265e-02 66 | 2.155909426127144545e-03 9.411764705882352811e-01 2.070727249046141333e-02 67 | 2.582206350471075423e-03 8.970588235294117974e-01 3.191265141814027956e-02 68 | 1.134864862883988405e-02 9.411764705882352811e-01 2.441129414382127785e-02 69 | 7.222415592685625663e-03 8.970588235294117974e-01 3.001561186205261439e-02 70 | 3.725561963633620011e-03 8.529411764705882026e-01 3.673642956566346152e-02 71 | 2.878865709469798431e-03 8.970588235294117974e-01 2.513080645091519613e-02 72 | 1.250806594077965130e-03 8.970588235294117974e-01 2.699387302187733831e-02 73 | 4.165086345345089466e-03 9.264705882352941568e-01 2.764867015571294967e-02 74 | 4.297485195430119330e-03 9.117647058823529216e-01 2.821850193639589485e-02 75 | 5.690610455882215250e-03 9.117647058823529216e-01 3.266883460127951999e-02 76 | 7.058118529344691722e-03 8.823529411764705621e-01 3.178518468282034298e-02 77 | 1.049384241639464421e-02 9.117647058823529216e-01 3.122164655477336487e-02 78 | 4.180904783077775411e-03 8.970588235294117974e-01 2.669300320881755340e-02 79 | 3.336413398286511289e-03 9.705882352941176405e-01 1.633953593630161488e-02 80 | 2.703457160765811509e-03 9.411764705882352811e-01 2.182190334588682279e-02 81 | 1.151675481542062396e-02 9.411764705882352811e-01 2.483915108302321012e-02 82 | 9.003054830375958884e-03 9.411764705882352811e-01 2.243826001048757068e-02 83 | 4.984757688977107387e-03 8.676470588235294379e-01 3.636980711338535371e-02 84 | 3.851720368132310109e-03 9.411764705882352811e-01 1.622232683249998086e-02 85 | 3.238551151372425115e-03 9.411764705882352811e-01 2.241500064015450669e-02 86 | 1.096302712562420567e-02 9.264705882352941568e-01 3.121825621164978429e-02 87 | 5.366276372525504920e-03 9.264705882352941568e-01 3.153748251114553425e-02 88 | 6.248181662747201957e-03 9.264705882352941568e-01 4.451779579564429240e-02 89 | 4.678876089786562949e-03 9.411764705882352811e-01 2.754622021567621620e-02 90 | 5.467388823840614760e-03 9.411764705882352811e-01 2.560703996353769565e-02 91 | 4.230332351641865642e-03 9.411764705882352811e-01 1.925519542470659207e-02 92 | 8.726550634628927119e-03 9.117647058823529216e-01 3.278946263529348287e-02 93 | 3.404269429677956541e-03 8.970588235294117974e-01 3.210063730784318975e-02 94 | 1.514643451061161057e-03 9.558823529411765163e-01 1.701755815481571094e-02 95 | 5.628569726471905135e-03 9.411764705882352811e-01 2.408284699357311764e-02 96 | 6.250654650469755769e-03 9.558823529411765163e-01 1.629259149797093048e-02 97 | 3.933199256265677152e-03 9.411764705882352811e-01 2.965823223631105024e-02 98 | 5.680364593144376743e-03 9.705882352941176405e-01 1.840473538545816037e-02 99 | 
4.861410976584094878e-03 9.264705882352941568e-01 2.181059397090440166e-02 100 | 1.187263206438650236e-02 9.411764705882352811e-01 2.506163903296542914e-02 101 | -------------------------------------------------------------------------------- /SET-MLP-Sparse-Python-Data-Structures/data/lung.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/SET-MLP-Sparse-Python-Data-Structures/data/lung.mat -------------------------------------------------------------------------------- /SET-MLP-Sparse-Python-Data-Structures/fixprob_mlp_sparse_data_structures.py: -------------------------------------------------------------------------------- 1 | # Author: Decebal Constantin Mocanu et al.; 2 | # Proof of concept implementation of a Multi Layer Perceptron (MLP) with a fix sparsity pattern (FixProb) on lung dataset using Python, SciPy sparse data structures, and (optionally) Cython. 3 | # This implementation can be used to create MLP-FixProb with hundred of thousands of neurons. 4 | # If you would like to try out MLP-FixProb with various activation functions, optimization methods and so on (in the detriment of scalability) please use the Keras implementation from the folder "SET-MLP-Keras-Weights-Mask". 5 | 6 | # This is a pre-alpha free software and was tested in Ubuntu 16.04 with Python 3.5.2, Numpy 1.14, SciPy 0.19.1, and (optionally) Cython 0.27.3; 7 | # The code is distributed in the hope that it may be useful, but WITHOUT ANY WARRANTIES; The use of this software is entirely at the user's own risk; 8 | # For an easy understanding of the code functionality please read the following articles. 9 | 10 | # If you use parts of this code please cite the following articles: 11 | #@article{Mocanu2018SET, 12 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 13 | # journal = {Nature Communications}, 14 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 15 | # year = {2018}, 16 | # doi = {10.1038/s41467-018-04316-3} 17 | #} 18 | 19 | #@Article{Mocanu2016XBM, 20 | #author="Mocanu, Decebal Constantin and Mocanu, Elena and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio", 21 | #title="A topological insight into restricted Boltzmann machines", 22 | #journal="Machine Learning", 23 | #year="2016", 24 | #volume="104", 25 | #number="2", 26 | #pages="243--270", 27 | #doi="10.1007/s10994-016-5570-z", 28 | #url="https://doi.org/10.1007/s10994-016-5570-z" 29 | #} 30 | 31 | #@phdthesis{Mocanu2017PhDthesis, 32 | #title = "Network computations in artificial intelligence", 33 | #author = "D.C. Mocanu", 34 | #year = "2017", 35 | #isbn = "978-90-386-4305-2", 36 | #publisher = "Eindhoven University of Technology", 37 | #} 38 | 39 | # We thank to: 40 | # Thomas Hagebols: for performing a thorough analyze on the performance of SciPy sparse matrix operations 41 | # Ritchie Vink (https://www.ritchievink.com): for making available on Github a nice Python implementation of fully connected MLPs. 
This MLP-FixProb implementation was built on top of his MLP code: 42 | # https://github.com/ritchie46/vanilla-machine-learning/blob/master/vanilla_mlp.py 43 | 44 | import numpy as np 45 | from scipy.sparse import csr_matrix 46 | from scipy.sparse import csc_matrix 47 | from scipy.sparse import lil_matrix 48 | from scipy.sparse import coo_matrix 49 | from scipy.sparse import dok_matrix 50 | import scipy.io as sio 51 | #the "sparseoperations" Cython library was tested in Ubuntu 16.04. Please note that you may encounter some "solvable" issues if you compile it in Windows. 52 | import sparseoperations 53 | import datetime 54 | 55 | def backpropagation_updates_Numpy(a, delta, rows, cols, out): 56 | for i in range (out.shape[0]): 57 | s=0 58 | for j in range(a.shape[0]): 59 | s+=a[j,rows[i]]*delta[j, cols[i]] 60 | out[i]=s/a.shape[0] 61 | 62 | def createSparseWeights(epsilon,noRows,noCols): 63 | # generate an Erdos Renyi sparse weights mask 64 | weights=lil_matrix((noRows, noCols)) 65 | for i in range(epsilon * (noRows + noCols)): 66 | weights[np.random.randint(0,noRows),np.random.randint(0,noCols)]=np.float64(np.random.randn()/10) 67 | print ("Create sparse matrix with ",weights.getnnz()," connections and ",(weights.getnnz()/(noRows * noCols))*100,"% density level") 68 | weights=weights.tocsr() 69 | return weights 70 | 71 | 72 | 73 | class Relu: 74 | @staticmethod 75 | def activation(z): 76 | z[z < 0] = 0 77 | return z 78 | 79 | @staticmethod 80 | def prime(z): 81 | z[z < 0] = 0 82 | z[z > 0] = 1 83 | return z 84 | 85 | class Sigmoid: 86 | @staticmethod 87 | def activation(z): 88 | return 1 / (1 + np.exp(-z)) 89 | 90 | @staticmethod 91 | def prime(z): 92 | return Sigmoid.activation(z) * (1 - Sigmoid.activation(z)) 93 | 94 | 95 | class MSE: 96 | def __init__(self, activation_fn=None): 97 | """ 98 | 99 | :param activation_fn: Class object of the activation function. 100 | """ 101 | if activation_fn: 102 | self.activation_fn = activation_fn 103 | else: 104 | self.activation_fn = NoActivation 105 | 106 | def activation(self, z): 107 | return self.activation_fn.activation(z) 108 | 109 | @staticmethod 110 | def loss(y_true, y_pred): 111 | """ 112 | :param y_true: (array) One hot encoded truth vector. 113 | :param y_pred: (array) Prediction vector 114 | :return: (flt) 115 | """ 116 | return np.mean((y_pred - y_true)**2) 117 | 118 | @staticmethod 119 | def prime(y_true, y_pred): 120 | return y_pred - y_true 121 | 122 | def delta(self, y_true, y_pred): 123 | """ 124 | Back propagation error delta 125 | :return: (array) 126 | """ 127 | return self.prime(y_true, y_pred) * self.activation_fn.prime(y_pred) 128 | 129 | 130 | class NoActivation: 131 | """ 132 | This is a plugin function for no activation. 133 | 134 | f(x) = x * 1 135 | """ 136 | @staticmethod 137 | def activation(z): 138 | """ 139 | :param z: (array) w(x) + b 140 | :return: z (array) 141 | """ 142 | return z 143 | 144 | @staticmethod 145 | def prime(z): 146 | """ 147 | The prime of z * 1 = 1 148 | :param z: (array) 149 | :return: z': (array) 150 | """ 151 | return np.ones_like(z) 152 | 153 | 154 | class MLP_FixProb: 155 | def __init__(self, dimensions, activations,epsilon=20): 156 | """ 157 | :param dimensions: (tpl/ list) Dimensions of the neural net. (input, hidden layer, output) 158 | :param activations: (tpl/ list) Activations functions. 
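        :param epsilon: (int) Sparsity control parameter; each sparse weight matrix is built with
                        roughly epsilon * (noRows + noCols) non-zero connections (see createSparseWeights),
                        e.g. with the default epsilon=20 the 3312 x 3000 layer of the example below starts
                        with roughly 126 thousand weights instead of the almost 10 million of a dense layer.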
159 | 160 | Example of three hidden layer with 161 | - 3312 input features 162 | - 3000 hidden neurons 163 | - 3000 hidden neurons 164 | - 3000 hidden neurons 165 | - 5 output classes 166 | 167 | 168 | layers --> [1, 2, 3, 4, 5] 169 | ---------------------------------------- 170 | 171 | dimensions = (3312, 3000, 3000, 3000, 5) 172 | activations = ( Relu, Relu, Relu, Sigmoid) 173 | """ 174 | self.n_layers = len(dimensions) 175 | self.loss = None 176 | self.learning_rate = None 177 | self.momentum=None 178 | self.weight_decay = None 179 | self.epsilon = epsilon # control the sparsity level as discussed in the paper 180 | self.dimensions=dimensions 181 | 182 | # Weights and biases are initiated by index. For a one hidden layer net you will have a w[1] and w[2] 183 | self.w = {} 184 | self.b = {} 185 | self.pdw={} 186 | self.pdd={} 187 | 188 | # Activations are also initiated by index. For the example we will have activations[2] and activations[3] 189 | self.activations = {} 190 | for i in range(len(dimensions) - 1): 191 | self.w[i + 1] = createSparseWeights(self.epsilon, dimensions[i], dimensions[i + 1])#create sparse weight matrices 192 | self.b[i + 1] = np.zeros(dimensions[i + 1]) 193 | self.activations[i + 2] = activations[i] 194 | 195 | def _feed_forward(self, x): 196 | """ 197 | Execute a forward feed through the network. 198 | :param x: (array) Batch of input data vectors. 199 | :return: (tpl) Node outputs and activations per layer. The numbering of the output is equivalent to the layer numbers. 200 | """ 201 | 202 | # w(x) + b 203 | z = {} 204 | 205 | # activations: f(z) 206 | a = {1: x} # First layer has no activations as input. The input x is the input. 207 | 208 | for i in range(1, self.n_layers): 209 | # current layer = i 210 | # activation layer = i + 1 211 | z[i + 1] = a[i]@self.w[i] + self.b[i] 212 | a[i + 1] = self.activations[i + 1].activation(z[i + 1]) 213 | 214 | return z, a 215 | 216 | def _back_prop(self, z, a, y_true): 217 | """ 218 | The input dicts keys represent the layers of the net. 219 | 220 | a = { 1: x, 221 | 2: f(w1(x) + b1) 222 | 3: f(w2(a2) + b2) 223 | 4: f(w3(a3) + b3) 224 | 5: f(w4(a4) + b4) 225 | } 226 | 227 | :param z: (dict) w(x) + b 228 | :param a: (dict) f(z) 229 | :param y_true: (array) One hot encoded truth vector. 230 | :return: 231 | """ 232 | 233 | # Determine partial derivative and delta for the output layer. 234 | # delta output layer 235 | delta = self.loss.delta(y_true, a[self.n_layers]) 236 | dw=coo_matrix(self.w[self.n_layers-1]) 237 | 238 | # compute backpropagation updates 239 | sparseoperations.backpropagation_updates_Cython(a[self.n_layers - 1],delta,dw.row,dw.col,dw.data) 240 | # If you have problems with Cython please use the backpropagation_updates_Numpy method by uncommenting the line below and commenting the one above. Please note that the running time will be much higher 241 | #backpropagation_updates_Numpy(a[self.n_layers - 1], delta, dw.row, dw.col, dw.data) 242 | 243 | update_params = { 244 | self.n_layers - 1: (dw.tocsr(), delta) 245 | } 246 | 247 | # In case of three layer net will iterate over i = 2 and i = 1 248 | # Determine partial derivative and delta for the rest of the layers. 249 | # Each iteration requires the delta from the previous layer, propagating backwards. 
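        # Concretely, each pass of the loop below computes
        #     delta = (delta @ w[i].T) * activations[i].prime(z[i])
        # and then accumulates the gradient dw only on the connections that already exist
        # in the sparse matrix w[i-1], so the backward pass never materializes a dense
        # gradient matrix.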
250 | for i in reversed(range(2, self.n_layers)): 251 | delta = (delta@self.w[i].transpose()) * self.activations[i].prime(z[i]) 252 | dw = coo_matrix(self.w[i - 1]) 253 | 254 | # compute backpropagation updates 255 | sparseoperations.backpropagation_updates_Cython(a[i - 1], delta, dw.row, dw.col, dw.data) 256 | # If you have problems with Cython please use the backpropagation_updates_Numpy method by uncommenting the line below and commenting the one above. Please note that the running time will be much higher 257 | #backpropagation_updates_Numpy(a[i - 1], delta, dw.row, dw.col, dw.data) 258 | 259 | update_params[i - 1] = (dw.tocsr(), delta) 260 | for k, v in update_params.items(): 261 | self._update_w_b(k, v[0], v[1]) 262 | 263 | def _update_w_b(self, index, dw, delta): 264 | """ 265 | Update weights and biases. 266 | 267 | :param index: (int) Number of the layer 268 | :param dw: (array) Partial derivatives 269 | :param delta: (array) Delta error. 270 | """ 271 | 272 | #perform the update with momentum 273 | if (index not in self.pdw): 274 | self.pdw[index]=-self.learning_rate * dw 275 | self.pdd[index] = - self.learning_rate * np.mean(delta, 0) 276 | else: 277 | self.pdw[index]= self.momentum*self.pdw[index]-self.learning_rate * dw 278 | self.pdd[index] = self.momentum * self.pdd[index] - self.learning_rate * np.mean(delta, 0) 279 | 280 | self.w[index] += self.pdw[index]-self.weight_decay*self.w[index] 281 | self.b[index] += self.pdd[index]-self.weight_decay*self.b[index] 282 | 283 | 284 | def fit(self, x, y_true, x_test,y_test,loss, epochs, batch_size, learning_rate=1e-3, momentum=0.9, weight_decay=0.0002, testing=True, save_filename=""): 285 | """ 286 | :param x: (array) Containing parameters 287 | :param y_true: (array) Containing one hot encoded labels. 288 | :param loss: Loss class (MSE, CrossEntropy etc.) 289 | :param epochs: (int) Number of epochs. 290 | :param batch_size: (int) 291 | :param learning_rate: (flt) 292 | :param momentum: (flt) 293 | :param weight_decay: (flt) 294 | :return (array) A 2D array of metrics (epochs, 3). 
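        Each row of the returned metrics holds the training loss, the test accuracy, and the
        test loss of one epoch; these are the three columns written to the file given by
        save_filename (e.g. Results/mlp_fixprob.txt).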
295 | """ 296 | if not x.shape[0] == y_true.shape[0]: 297 | raise ValueError("Length of x and y arrays don't match") 298 | # Initiate the loss object with the final activation function 299 | self.loss = loss(self.activations[self.n_layers]) 300 | self.learning_rate = learning_rate 301 | self.momentum = momentum 302 | self.weight_decay = weight_decay 303 | 304 | maximum_accuracy=0 305 | 306 | metrics=np.zeros((epochs,3)) 307 | 308 | for i in range(epochs): 309 | # Shuffle the data 310 | seed = np.arange(x.shape[0]) 311 | np.random.shuffle(seed) 312 | x_=x[seed] 313 | y_=y_true[seed] 314 | 315 | #training 316 | t1 = datetime.datetime.now() 317 | losstrain=0 318 | for j in range(x.shape[0] // batch_size): 319 | k = j * batch_size 320 | l = (j + 1) * batch_size 321 | z, a = self._feed_forward(x_[k:l]) 322 | losstrain+=self.loss.loss(y_[k:l], a[self.n_layers]) 323 | self._back_prop(z, a, y_[k:l]) 324 | # ToDo: adding dropout would improve the performance and decrease overfitting 325 | t2 = datetime.datetime.now() 326 | metrics[i, 0]=losstrain / (x.shape[0] // batch_size) 327 | print ("\nMLP-FixProb Epoch ",i) 328 | print ("Training time: ",t2-t1,"; Loss train: ",losstrain / (x.shape[0] // batch_size)) 329 | 330 | # test model performance on the test data at each epoch 331 | # this part is useful to understand model performance and can be commented for production settings 332 | if (testing): 333 | t3 = datetime.datetime.now() 334 | accuracy,activations=self.predict(x_test,y_test,batch_size) 335 | t4 = datetime.datetime.now() 336 | maximum_accuracy=max(maximum_accuracy,accuracy) 337 | losstest=self.loss.loss(y_test, activations) 338 | metrics[i, 1] = accuracy 339 | metrics[i, 2] = losstest 340 | print("Testing time: ", t4 - t3, "; Loss test: ", losstest,"; Accuracy: ", accuracy,"; Maximum accuracy: ", maximum_accuracy) 341 | 342 | #save performance metrics values in a file 343 | if (save_filename!=""): 344 | np.savetxt(save_filename,metrics) 345 | 346 | return metrics 347 | 348 | def predict(self, x_test,y_test,batch_size=1): 349 | """ 350 | :param x_test: (array) Test input 351 | :param y_test: (array) Correct test output 352 | :param batch_size: 353 | :return: (flt) Classification accuracy 354 | :return: (array) A 2D array of shape (n_cases, n_classes). 355 | """ 356 | activations = np.zeros((y_test.shape[0], y_test.shape[1])) 357 | for j in range(x_test.shape[0] // batch_size): 358 | k = j * batch_size 359 | l = (j + 1) * batch_size 360 | _, a_test = self._feed_forward(x_test[k:l]) 361 | activations[k:l] = a_test[self.n_layers] 362 | correctClassification = 0 363 | for j in range(y_test.shape[0]): 364 | if (np.argmax(activations[j]) == np.argmax(y_test[j])): 365 | correctClassification += 1 366 | accuracy= correctClassification/y_test.shape[0] 367 | return accuracy, activations 368 | 369 | if __name__ == "__main__": 370 | # Comment this if you would like to use the full power of randomization. I use it to have repeatable results. 
371 | np.random.seed(0) 372 | 373 | # load data 374 | mat = sio.loadmat('data/lung.mat') #lung dataset was downloaded from http://featureselection.asu.edu/ 375 | X = mat['X'] 376 | # one hot encoding 377 | noClasses = np.max(mat['Y']) 378 | Y=np.zeros((mat['Y'].shape[0],noClasses)) 379 | for i in range(Y.shape[0]): 380 | Y[i,mat['Y'][i]-1]=1 381 | 382 | #split data in training and testing 383 | indices=np.arange(X.shape[0]) 384 | np.random.shuffle(indices) 385 | X_train=X[indices[0:int(X.shape[0]*2/3)]] 386 | Y_train=Y[indices[0:int(X.shape[0]*2/3)]] 387 | X_test=X[indices[int(X.shape[0]*2/3):]] 388 | Y_test=Y[indices[int(X.shape[0]*2/3):]] 389 | 390 | #normalize data 391 | X_train = X_train.astype('float64') 392 | X_test = X_test.astype('float64') 393 | xTrainMean = np.mean(X_train, axis=0) 394 | xTtrainStd = np.std(X_train, axis=0) 395 | X_train = (X_train - xTrainMean) / (xTtrainStd+0.0001) 396 | X_test = (X_test - xTrainMean) / (xTtrainStd+0.0001) 397 | 398 | # create MLP-FixProb 399 | mlp_fixprob = MLP_FixProb(( X_train.shape[1], 3000, 3000, 3000,Y_train.shape[1]), (Relu, Relu,Relu,Sigmoid),epsilon=20) 400 | 401 | # train MLP-FixProb 402 | mlp_fixprob.fit(X_train, Y_train, X_test,Y_test,loss=MSE, epochs=100, batch_size=2, learning_rate=0.01, momentum=0.9, weight_decay=0.0002, testing=True,save_filename="Results/mlp_fixprob.txt") 403 | 404 | # test MLP-FixProb 405 | accuracy,_=mlp_fixprob.predict(X_test,Y_test,batch_size=1) 406 | 407 | print ("\nAccuracy of the last epoch on the testing data: ",accuracy) 408 | -------------------------------------------------------------------------------- /SET-MLP-Sparse-Python-Data-Structures/set_mlp_sparse_data_structures.py: -------------------------------------------------------------------------------- 1 | # Author: Decebal Constantin Mocanu et al.; 2 | # Proof of concept implementation of Sparse Evolutionary Training (SET) of Multi Layer Perceptron (MLP) on lung dataset using Python, SciPy sparse data structures, and (optionally) Cython. 3 | # This implementation can be used to create SET-MLP with hundred of thousands of neurons. 4 | # If you would like to try out SET-MLP with various activation functions, optimization methods and so on (in the detriment of scalability) please use the Keras implementation from the folder "SET-MLP-Keras-Weights-Mask". 5 | 6 | # This is a pre-alpha free software and was tested in Ubuntu 16.04 with Python 3.5.2, Numpy 1.14, SciPy 0.19.1, and (optionally) Cython 0.27.3; 7 | # The code is distributed in the hope that it may be useful, but WITHOUT ANY WARRANTIES; The use of this software is entirely at the user's own risk; 8 | # For an easy understanding of the code functionality please read the following articles. 9 | 10 | # If you use parts of this code please cite the following articles: 11 | #@article{Mocanu2018SET, 12 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 13 | # journal = {Nature Communications}, 14 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 15 | # year = {2018}, 16 | # doi = {10.1038/s41467-018-04316-3} 17 | #} 18 | 19 | #@Article{Mocanu2016XBM, 20 | #author="Mocanu, Decebal Constantin and Mocanu, Elena and Nguyen, Phuong H. 
and Gibescu, Madeleine and Liotta, Antonio", 21 | #title="A topological insight into restricted Boltzmann machines", 22 | #journal="Machine Learning", 23 | #year="2016", 24 | #volume="104", 25 | #number="2", 26 | #pages="243--270", 27 | #doi="10.1007/s10994-016-5570-z", 28 | #url="https://doi.org/10.1007/s10994-016-5570-z" 29 | #} 30 | 31 | #@phdthesis{Mocanu2017PhDthesis, 32 | #title = "Network computations in artificial intelligence", 33 | #author = "D.C. Mocanu", 34 | #year = "2017", 35 | #isbn = "978-90-386-4305-2", 36 | #publisher = "Eindhoven University of Technology", 37 | #} 38 | 39 | # We thank to: 40 | # Thomas Hagebols: for performing a thorough analyze on the performance of SciPy sparse matrix operations 41 | # Ritchie Vink (https://www.ritchievink.com): for making available on Github a nice Python implementation of fully connected MLPs. This SET-MLP implementation was built on top of his MLP code: 42 | # https://github.com/ritchie46/vanilla-machine-learning/blob/master/vanilla_mlp.py 43 | # Amarsagar Reddy Ramapuram Matavalam: for provided a fast fast implementation for the "weightsEvolution" method, after the initial release of this code. 44 | 45 | import numpy as np 46 | from scipy.sparse import csr_matrix 47 | from scipy.sparse import csc_matrix 48 | from scipy.sparse import lil_matrix 49 | from scipy.sparse import coo_matrix 50 | from scipy.sparse import dok_matrix 51 | import scipy.io as sio 52 | #the "sparseoperations" Cython library was tested in Ubuntu 16.04. Please note that you may encounter some "solvable" issues if you compile it in Windows. 53 | import sparseoperations 54 | import datetime 55 | 56 | def backpropagation_updates_Numpy(a, delta, rows, cols, out): 57 | for i in range (out.shape[0]): 58 | s=0 59 | for j in range(a.shape[0]): 60 | s+=a[j,rows[i]]*delta[j, cols[i]] 61 | out[i]=s/a.shape[0] 62 | 63 | def find_first_pos(array, value): 64 | idx = (np.abs(array - value)).argmin() 65 | return idx 66 | 67 | 68 | def find_last_pos(array, value): 69 | idx = (np.abs(array - value))[::-1].argmin() 70 | return array.shape[0] - idx 71 | 72 | def createSparseWeights(epsilon,noRows,noCols): 73 | # generate an Erdos Renyi sparse weights mask 74 | weights=lil_matrix((noRows, noCols)) 75 | for i in range(epsilon * (noRows + noCols)): 76 | weights[np.random.randint(0,noRows),np.random.randint(0,noCols)]=np.float64(np.random.randn()/10) 77 | print ("Create sparse matrix with ",weights.getnnz()," connections and ",(weights.getnnz()/(noRows * noCols))*100,"% density level") 78 | weights=weights.tocsr() 79 | return weights 80 | 81 | 82 | def array_intersect(A, B): 83 | # added by Amarsagar Reddy Ramapuram Matavalam (amar@iastate.edu) 84 | # this are for array intersection 85 | # inspired by https://stackoverflow.com/questions/8317022/get-intersecting-rows-across-two-2d-numpy-arrays 86 | nrows, ncols = A.shape 87 | dtype = {'names': ['f{}'.format(i) for i in range(ncols)], 'formats': ncols * [A.dtype]} 88 | return np.in1d(A.view(dtype), B.view(dtype)) # boolean return 89 | 90 | class Relu: 91 | @staticmethod 92 | def activation(z): 93 | z[z < 0] = 0 94 | return z 95 | 96 | @staticmethod 97 | def prime(z): 98 | z[z < 0] = 0 99 | z[z > 0] = 1 100 | return z 101 | 102 | class Sigmoid: 103 | @staticmethod 104 | def activation(z): 105 | return 1 / (1 + np.exp(-z)) 106 | 107 | @staticmethod 108 | def prime(z): 109 | return Sigmoid.activation(z) * (1 - Sigmoid.activation(z)) 110 | 111 | 112 | class MSE: 113 | def __init__(self, activation_fn=None): 114 | """ 115 | 116 | :param 
activation_fn: Class object of the activation function. 117 | """ 118 | if activation_fn: 119 | self.activation_fn = activation_fn 120 | else: 121 | self.activation_fn = NoActivation 122 | 123 | def activation(self, z): 124 | return self.activation_fn.activation(z) 125 | 126 | @staticmethod 127 | def loss(y_true, y_pred): 128 | """ 129 | :param y_true: (array) One hot encoded truth vector. 130 | :param y_pred: (array) Prediction vector 131 | :return: (flt) 132 | """ 133 | return np.mean((y_pred - y_true)**2) 134 | 135 | @staticmethod 136 | def prime(y_true, y_pred): 137 | return y_pred - y_true 138 | 139 | def delta(self, y_true, y_pred): 140 | """ 141 | Back propagation error delta 142 | :return: (array) 143 | """ 144 | return self.prime(y_true, y_pred) * self.activation_fn.prime(y_pred) 145 | 146 | 147 | class NoActivation: 148 | """ 149 | This is a plugin function for no activation. 150 | 151 | f(x) = x * 1 152 | """ 153 | @staticmethod 154 | def activation(z): 155 | """ 156 | :param z: (array) w(x) + b 157 | :return: z (array) 158 | """ 159 | return z 160 | 161 | @staticmethod 162 | def prime(z): 163 | """ 164 | The prime of z * 1 = 1 165 | :param z: (array) 166 | :return: z': (array) 167 | """ 168 | return np.ones_like(z) 169 | 170 | 171 | class SET_MLP: 172 | def __init__(self, dimensions, activations,epsilon=20): 173 | """ 174 | :param dimensions: (tpl/ list) Dimensions of the neural net. (input, hidden layer, output) 175 | :param activations: (tpl/ list) Activations functions. 176 | 177 | Example of three hidden layer with 178 | - 3312 input features 179 | - 3000 hidden neurons 180 | - 3000 hidden neurons 181 | - 3000 hidden neurons 182 | - 5 output classes 183 | 184 | 185 | layers --> [1, 2, 3, 4, 5] 186 | ---------------------------------------- 187 | 188 | dimensions = (3312, 3000, 3000, 3000, 5) 189 | activations = ( Relu, Relu, Relu, Sigmoid) 190 | """ 191 | self.n_layers = len(dimensions) 192 | self.loss = None 193 | self.learning_rate = None 194 | self.momentum=None 195 | self.weight_decay = None 196 | self.epsilon = epsilon # control the sparsity level as discussed in the paper 197 | self.zeta = None # the fraction of the weights removed 198 | self.dimensions=dimensions 199 | 200 | # Weights and biases are initiated by index. For a one hidden layer net you will have a w[1] and w[2] 201 | self.w = {} 202 | self.b = {} 203 | self.pdw={} 204 | self.pdd={} 205 | 206 | # Activations are also initiated by index. For the example we will have activations[2] and activations[3] 207 | self.activations = {} 208 | for i in range(len(dimensions) - 1): 209 | self.w[i + 1] = createSparseWeights(self.epsilon, dimensions[i], dimensions[i + 1])#create sparse weight matrices 210 | self.b[i + 1] = np.zeros(dimensions[i + 1]) 211 | self.activations[i + 2] = activations[i] 212 | 213 | def _feed_forward(self, x): 214 | """ 215 | Execute a forward feed through the network. 216 | :param x: (array) Batch of input data vectors. 217 | :return: (tpl) Node outputs and activations per layer. The numbering of the output is equivalent to the layer numbers. 218 | """ 219 | 220 | # w(x) + b 221 | z = {} 222 | 223 | # activations: f(z) 224 | a = {1: x} # First layer has no activations as input. The input x is the input. 
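        # For the example dimensions (3312, 3000, 3000, 3000, 5) the loop below computes, for a
        # mini-batch of "batch" input vectors:
        #   z[2] = a[1] @ w[1] + b[1]    # (batch, 3312) x (3312, 3000) -> (batch, 3000)
        #   z[3] = a[2] @ w[2] + b[2]    # (batch, 3000) x (3000, 3000) -> (batch, 3000)
        #   z[4] = a[3] @ w[3] + b[3]    # (batch, 3000) x (3000, 3000) -> (batch, 3000)
        #   z[5] = a[4] @ w[4] + b[4]    # (batch, 3000) x (3000, 5)    -> (batch, 5)
        # where each w[i] is a SciPy CSR matrix, so every product only touches the stored connections.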
225 | 226 | for i in range(1, self.n_layers): 227 | # current layer = i 228 | # activation layer = i + 1 229 | z[i + 1] = a[i]@self.w[i] + self.b[i] 230 | a[i + 1] = self.activations[i + 1].activation(z[i + 1]) 231 | 232 | return z, a 233 | 234 | def _back_prop(self, z, a, y_true): 235 | """ 236 | The input dicts keys represent the layers of the net. 237 | 238 | a = { 1: x, 239 | 2: f(w1(x) + b1) 240 | 3: f(w2(a2) + b2) 241 | 4: f(w3(a3) + b3) 242 | 5: f(w4(a4) + b4) 243 | } 244 | 245 | :param z: (dict) w(x) + b 246 | :param a: (dict) f(z) 247 | :param y_true: (array) One hot encoded truth vector. 248 | :return: 249 | """ 250 | 251 | # Determine partial derivative and delta for the output layer. 252 | # delta output layer 253 | delta = self.loss.delta(y_true, a[self.n_layers]) 254 | dw=coo_matrix(self.w[self.n_layers-1]) 255 | 256 | # compute backpropagation updates 257 | sparseoperations.backpropagation_updates_Cython(a[self.n_layers - 1],delta,dw.row,dw.col,dw.data) 258 | # If you have problems with Cython please use the backpropagation_updates_Numpy method by uncommenting the line below and commenting the one above. Please note that the running time will be much higher 259 | #backpropagation_updates_Numpy(a[self.n_layers - 1], delta, dw.row, dw.col, dw.data) 260 | 261 | update_params = { 262 | self.n_layers - 1: (dw.tocsr(), delta) 263 | } 264 | 265 | # In case of three layer net will iterate over i = 2 and i = 1 266 | # Determine partial derivative and delta for the rest of the layers. 267 | # Each iteration requires the delta from the previous layer, propagating backwards. 268 | for i in reversed(range(2, self.n_layers)): 269 | delta = (delta@self.w[i].transpose()) * self.activations[i].prime(z[i]) 270 | dw = coo_matrix(self.w[i - 1]) 271 | 272 | # compute backpropagation updates 273 | sparseoperations.backpropagation_updates_Cython(a[i - 1], delta, dw.row, dw.col, dw.data) 274 | # If you have problems with Cython please use the backpropagation_updates_Numpy method by uncommenting the line below and commenting the one above. Please note that the running time will be much higher 275 | #backpropagation_updates_Numpy(a[i - 1], delta, dw.row, dw.col, dw.data) 276 | 277 | update_params[i - 1] = (dw.tocsr(), delta) 278 | for k, v in update_params.items(): 279 | self._update_w_b(k, v[0], v[1]) 280 | 281 | def _update_w_b(self, index, dw, delta): 282 | """ 283 | Update weights and biases. 284 | 285 | :param index: (int) Number of the layer 286 | :param dw: (array) Partial derivatives 287 | :param delta: (array) Delta error. 288 | """ 289 | 290 | #perform the update with momentum 291 | if (index not in self.pdw): 292 | self.pdw[index]=-self.learning_rate * dw 293 | self.pdd[index] = - self.learning_rate * np.mean(delta, 0) 294 | else: 295 | self.pdw[index]= self.momentum*self.pdw[index]-self.learning_rate * dw 296 | self.pdd[index] = self.momentum * self.pdd[index] - self.learning_rate * np.mean(delta, 0) 297 | 298 | self.w[index] += self.pdw[index]-self.weight_decay*self.w[index] 299 | self.b[index] += self.pdd[index]-self.weight_decay*self.b[index] 300 | 301 | 302 | def fit(self, x, y_true, x_test,y_test,loss, epochs, batch_size, learning_rate=1e-3, momentum=0.9, weight_decay=0.0002, zeta=0.3, testing=True, save_filename=""): 303 | """ 304 | :param x: (array) Containing parameters 305 | :param y_true: (array) Containing one hot encoded labels. 306 | :param loss: Loss class (MSE, CrossEntropy etc.) 307 | :param epochs: (int) Number of epochs. 
308 | :param batch_size: (int) 309 | :param learning_rate: (flt) 310 | :param momentum: (flt) 311 | :param weight_decay: (flt) 312 | :param zeta: (flt) #control the fraction of weights removed 313 | :return (array) A 2D array of metrics (epochs, 3). 314 | """ 315 | if not x.shape[0] == y_true.shape[0]: 316 | raise ValueError("Length of x and y arrays don't match") 317 | # Initiate the loss object with the final activation function 318 | self.loss = loss(self.activations[self.n_layers]) 319 | self.learning_rate = learning_rate 320 | self.momentum = momentum 321 | self.weight_decay = weight_decay 322 | self.zeta = zeta 323 | 324 | maximum_accuracy=0 325 | 326 | metrics=np.zeros((epochs,3)) 327 | 328 | for i in range(epochs): 329 | # Shuffle the data 330 | seed = np.arange(x.shape[0]) 331 | np.random.shuffle(seed) 332 | x_=x[seed] 333 | y_=y_true[seed] 334 | 335 | #training 336 | t1 = datetime.datetime.now() 337 | losstrain=0 338 | for j in range(x.shape[0] // batch_size): 339 | k = j * batch_size 340 | l = (j + 1) * batch_size 341 | z, a = self._feed_forward(x_[k:l]) 342 | losstrain+=self.loss.loss(y_[k:l], a[self.n_layers]) 343 | self._back_prop(z, a, y_[k:l]) 344 | #ToDo: adding dropout would improve the performance and decrease overfitting 345 | t2 = datetime.datetime.now() 346 | metrics[i, 0]=losstrain / (x.shape[0] // batch_size) 347 | print ("\nSET-MLP Epoch ",i) 348 | print ("Training time: ",t2-t1,"; Loss train: ",losstrain / (x.shape[0] // batch_size)) 349 | 350 | # test model performance on the test data at each epoch 351 | # this part is useful to understand model performance and can be commented for production settings 352 | if (testing): 353 | t3 = datetime.datetime.now() 354 | accuracy,activations=self.predict(x_test,y_test,batch_size) 355 | t4 = datetime.datetime.now() 356 | maximum_accuracy=max(maximum_accuracy,accuracy) 357 | losstest=self.loss.loss(y_test, activations) 358 | metrics[i, 1] = accuracy 359 | metrics[i, 2] = losstest 360 | print("Testing time: ", t4 - t3, "; Loss test: ", losstest,"; Accuracy: ", accuracy,"; Maximum accuracy: ", maximum_accuracy) 361 | 362 | t5 = datetime.datetime.now() 363 | if (i smallestPositive)): 399 | wdok[ik,jk]=val 400 | pdwdok[ik,jk]=pdwlil[ik,jk] 401 | keepConnections+=1 402 | 403 | # add new random connections 404 | for kk in range(self.w[i].data.shape[0]-keepConnections): 405 | ik = np.random.randint(0, self.dimensions[i - 1]) 406 | jk = np.random.randint(0, self.dimensions[i]) 407 | while (wdok[ik,jk]!=0): 408 | ik = np.random.randint(0, self.dimensions[i - 1]) 409 | jk = np.random.randint(0, self.dimensions[i]) 410 | wdok[ik, jk]=np.random.randn() / 10 411 | pdwdok[ik, jk] = 0 412 | 413 | self.pdw[i]=pdwdok.tocsr() 414 | self.w[i]=wdok.tocsr() 415 | 416 | def weightsEvolution_Amar(self): 417 | # this represents the core of the SET procedure. 
It removes the weights closest to zero in each layer and add new random weights 418 | # improved running time using numpy routines - Amarsagar Reddy Ramapuram Matavalam (amar@iastate.edu) 419 | for i in range(1,self.n_layers): 420 | # uncomment line below to stop evolution of dense weights more than 80% non-zeros 421 | #if(self.w[i].count_nonzero()/(self.w[i].get_shape()[0]*self.w[i].get_shape()[1]) < 0.8): 422 | t_ev_1 = datetime.datetime.now() 423 | # converting to COO form - Added by Amar 424 | wcoo=self.w[i].tocoo() 425 | valsW=wcoo.data 426 | rowsW=wcoo.row 427 | colsW=wcoo.col 428 | 429 | pdcoo=self.pdw[i].tocoo() 430 | valsPD=pdcoo.data 431 | rowsPD=pdcoo.row 432 | colsPD=pdcoo.col 433 | #print("Number of non zeros in W and PD matrix before evolution in layer",i,[np.size(valsW), np.size(valsPD)]) 434 | values=np.sort(self.w[i].data) 435 | firstZeroPos = find_first_pos(values, 0) 436 | lastZeroPos = find_last_pos(values, 0) 437 | 438 | largestNegative = values[int((1-self.zeta) * firstZeroPos)] 439 | smallestPositive = values[int(min(values.shape[0] - 1, lastZeroPos + self.zeta * (values.shape[0] - lastZeroPos)))] 440 | 441 | #remove the weights (W) closest to zero and modify PD as well 442 | valsWNew=valsW[(valsW > smallestPositive) | (valsW < largestNegative)] 443 | rowsWNew=rowsW[(valsW > smallestPositive) | (valsW < largestNegative)] 444 | colsWNew=colsW[(valsW > smallestPositive) | (valsW < largestNegative)] 445 | 446 | newWRowColIndex=np.stack((rowsWNew,colsWNew) , axis=-1) 447 | oldPDRowColIndex=np.stack((rowsPD,colsPD) , axis=-1) 448 | 449 | 450 | newPDRowColIndexFlag=array_intersect(oldPDRowColIndex,newWRowColIndex) # careful about order 451 | 452 | valsPDNew=valsPD[newPDRowColIndexFlag] 453 | rowsPDNew=rowsPD[newPDRowColIndexFlag] 454 | colsPDNew=colsPD[newPDRowColIndexFlag] 455 | 456 | self.pdw[i] = coo_matrix((valsPDNew, (rowsPDNew, colsPDNew)),(self.dimensions[i - 1], self.dimensions[i])).tocsr() 457 | 458 | # add new random connections 459 | keepConnections=np.size(rowsWNew) 460 | lengthRandom=valsW.shape[0]-keepConnections 461 | randomVals=np.random.randn(lengthRandom) / 10 # to avoid multiple whiles, can we call 3*rand? 462 | zeroVals=0*randomVals # explicit zeros 463 | 464 | # adding (wdok[ik,jk]!=0): condition 465 | while (lengthRandom>0): 466 | ik = np.random.randint(0, self.dimensions[i - 1],size=lengthRandom,dtype='int32') 467 | jk = np.random.randint(0, self.dimensions[i],size=lengthRandom,dtype='int32') 468 | 469 | randomWRowColIndex=np.stack((ik,jk) , axis=-1) 470 | randomWRowColIndex=np.unique(randomWRowColIndex,axis=0) # removing duplicates in new rows&cols 471 | oldWRowColIndex=np.stack((rowsWNew,colsWNew) , axis=-1) 472 | 473 | uniqueFlag=~array_intersect(randomWRowColIndex,oldWRowColIndex) # careful about order & tilda 474 | 475 | 476 | ikNew=randomWRowColIndex[uniqueFlag][:,0] 477 | jkNew=randomWRowColIndex[uniqueFlag][:,1] 478 | # be careful - row size and col size needs to be verified 479 | rowsWNew=np.append(rowsWNew, ikNew) 480 | colsWNew=np.append(colsWNew, jkNew) 481 | 482 | lengthRandom=valsW.shape[0]-np.size(rowsWNew) # this will constantly reduce lengthRandom 483 | 484 | # adding all the values along with corresponding row and column indices - Added by Amar 485 | valsWNew=np.append(valsWNew, randomVals) # be careful - we can add to an existing link ? 486 | #valsPDNew=np.append(valsPDNew, zeroVals) # be careful - adding explicit zeros - any reason?? 
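            # sanity check: every stored value must have a matching (row, col) index pair
            # before the sparse matrix is rebuilt below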
487 | if (valsWNew.shape[0] != rowsWNew.shape[0]): 488 | print("not good") 489 | self.w[i]=coo_matrix((valsWNew , (rowsWNew , colsWNew)),(self.dimensions[i-1],self.dimensions[i])).tocsr() 490 | 491 | #print("Number of non zeros in W and PD matrix after evolution in layer",i,[(self.w[i].data.shape[0]), (self.pdw[i].data.shape[0])]) 492 | 493 | t_ev_2 = datetime.datetime.now() 494 | #print("Weights evolution time for layer",i,"is", t_ev_2 - t_ev_1) 495 | 496 | 497 | def predict(self, x_test,y_test,batch_size=1): 498 | """ 499 | :param x_test: (array) Test input 500 | :param y_test: (array) Correct test output 501 | :param batch_size: 502 | :return: (flt) Classification accuracy 503 | :return: (array) A 2D array of shape (n_cases, n_classes). 504 | """ 505 | activations = np.zeros((y_test.shape[0], y_test.shape[1])) 506 | for j in range(x_test.shape[0] // batch_size): 507 | k = j * batch_size 508 | l = (j + 1) * batch_size 509 | _, a_test = self._feed_forward(x_test[k:l]) 510 | activations[k:l] = a_test[self.n_layers] 511 | correctClassification = 0 512 | for j in range(y_test.shape[0]): 513 | if (np.argmax(activations[j]) == np.argmax(y_test[j])): 514 | correctClassification += 1 515 | accuracy= correctClassification/y_test.shape[0] 516 | return accuracy, activations 517 | 518 | if __name__ == "__main__": 519 | # Comment this if you would like to use the full power of randomization. I use it to have repeatable results. 520 | np.random.seed(0) 521 | 522 | # load data 523 | mat = sio.loadmat('data/lung.mat') #lung dataset was downloaded from http://featureselection.asu.edu/ 524 | # As the lung dataset has just few hundred samples, and few thousands features, you will observe a high variance in the accuracy from one epoch to another. 525 | # We chose this dataset to show how the SET-MLP model handles overfitting. 526 | # To see a much more stable behaviour of the model please experiment with datasets with a higher amount of samples, e.g. COIL-100 can be a really nice one as it has 100 classes and also the last layer will be sparse. 
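    # The fit() call below records, for every epoch, the training loss, the test accuracy and the
    # test loss, and saves them to Results/set_mlp.txt (the result file included in this folder).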
527 | X = mat['X'] 528 | # one hot encoding 529 | noClasses = np.max(mat['Y']) 530 | Y=np.zeros((mat['Y'].shape[0],noClasses)) 531 | for i in range(Y.shape[0]): 532 | Y[i,mat['Y'][i]-1]=1 533 | 534 | #split data in training and testing 535 | indices=np.arange(X.shape[0]) 536 | np.random.shuffle(indices) 537 | X_train=X[indices[0:int(X.shape[0]*2/3)]] 538 | Y_train=Y[indices[0:int(X.shape[0]*2/3)]] 539 | X_test=X[indices[int(X.shape[0]*2/3):]] 540 | Y_test=Y[indices[int(X.shape[0]*2/3):]] 541 | 542 | #normalize data to have 0 mean and unit variance 543 | X_train = X_train.astype('float64') 544 | X_test = X_test.astype('float64') 545 | xTrainMean = np.mean(X_train, axis=0) 546 | xTtrainStd = np.std(X_train, axis=0) 547 | X_train = (X_train - xTrainMean) / (xTtrainStd+0.0001) 548 | X_test = (X_test - xTrainMean) / (xTtrainStd+0.0001) 549 | 550 | # create SET-MLP 551 | set_mlp = SET_MLP((X_train.shape[1], 3000, 3000, 3000, Y_train.shape[1]), (Relu, Relu, Relu, Sigmoid), epsilon=20) 552 | 553 | # train SET-MLP 554 | set_mlp.fit(X_train, Y_train, X_test, Y_test, loss=MSE, epochs=100, batch_size=2, learning_rate=0.01, momentum=0.9, weight_decay=0.0002, zeta=0.3, testing=True, save_filename="Results/set_mlp.txt") 555 | 556 | # test SET-MLP 557 | accuracy,_=set_mlp.predict(X_test,Y_test,batch_size=1) 558 | 559 | print ("\nAccuracy of the last epoch on the testing data: ",accuracy) 560 | -------------------------------------------------------------------------------- /SET-MLP-Sparse-Python-Data-Structures/sparseoperations.cpython-35m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/SET-MLP-Sparse-Python-Data-Structures/sparseoperations.cpython-35m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /SET-MLP-Sparse-Python-Data-Structures/sparseoperations.pyx: -------------------------------------------------------------------------------- 1 | # compile this file with: "cythonize -a -i sparseoperations.pyx" 2 | # I have tested this method in Linux (Ubuntu). If you compile it in Windows you may need some work around. 
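# As a possible work-around, a minimal setup.py (a sketch, not shipped with this folder) can be used
# instead of the cythonize command; this route may be easier on Windows once a C compiler is installed:
#
#   from setuptools import setup, Extension
#   from Cython.Build import cythonize
#   import numpy
#
#   setup(ext_modules=cythonize(Extension("sparseoperations", ["sparseoperations.pyx"],
#                                          include_dirs=[numpy.get_include()])))
#
# and then run: python setup.py build_ext --inplace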
3 | 4 | cimport numpy as np 5 | 6 | def backpropagation_updates_Cython(np.ndarray[np.float64_t,ndim=2] a, np.ndarray[np.float64_t,ndim=2] delta, np.ndarray[int,ndim=1] rows, np.ndarray[int,ndim=1] cols,np.ndarray[np.float64_t,ndim=1] out): 7 | cdef: 8 | size_t i,j 9 | double s 10 | for i in range (out.shape[0]): 11 | s=0 12 | for j in range(a.shape[0]): 13 | s+=a[j,rows[i]]*delta[j, cols[i]] 14 | out[i]=s/a.shape[0] 15 | #return out 16 | 17 | -------------------------------------------------------------------------------- /SET-RBM-Sparse-Python-Data-Structures/data/COIL20.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/SET-RBM-Sparse-Python-Data-Structures/data/COIL20.mat -------------------------------------------------------------------------------- /SET-RBM-Sparse-Python-Data-Structures/fixprob_rbm_sparse_data_structures.py: -------------------------------------------------------------------------------- 1 | # Author: Decebal Constantin Mocanu et al.; 2 | # Proof of concept implementation of a Restricted Boltzmann Machine (RBM) with a fix sparsity pattern (FixProb) on COIL20 dataset using Python, SciPy sparse data structures, and (optionally) Cython. 3 | # This implementation can be used to create RBM-FixProb with hundred of thousands of neurons. 4 | 5 | # This is a pre-alpha free software and was tested in Ubuntu 16.04 with Python 3.5.2, Numpy 1.14, SciPy 0.19.1, and (optionally) Cython 0.27.3; 6 | # The code is distributed in the hope that it may be useful, but WITHOUT ANY WARRANTIES; The use of this software is entirely at the user's own risk; 7 | # For an easy understanding of the code functionality please read the following articles. 8 | 9 | # If you use parts of this code please cite the following articles: 10 | #@article{Mocanu2018SET, 11 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 12 | # journal = {Nature Communications}, 13 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 14 | # year = {2018}, 15 | # doi = {10.1038/s41467-018-04316-3} 16 | #} 17 | 18 | #@Article{Mocanu2016XBM, 19 | #author="Mocanu, Decebal Constantin and Mocanu, Elena and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio", 20 | #title="A topological insight into restricted Boltzmann machines", 21 | #journal="Machine Learning", 22 | #year="2016", 23 | #volume="104", 24 | #number="2", 25 | #pages="243--270", 26 | #doi="10.1007/s10994-016-5570-z", 27 | #url="https://doi.org/10.1007/s10994-016-5570-z" 28 | #} 29 | 30 | #@phdthesis{Mocanu2017PhDthesis, 31 | #title = "Network computations in artificial intelligence", 32 | #author = "D.C. Mocanu", 33 | #year = "2017", 34 | #isbn = "978-90-386-4305-2", 35 | #publisher = "Eindhoven University of Technology", 36 | #} 37 | 38 | # We thank to: 39 | # Thomas Hagebols: for performing a thorough analyze on the performance of SciPy sparse matrix operations 40 | 41 | import numpy as np 42 | from scipy.sparse import lil_matrix 43 | from scipy.sparse import dok_matrix 44 | #the "sparseoperations" Cython library was tested in Ubuntu 16.04. Please note that you may encounter some "solvable" issues if you compile it in Windows. 
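# If this import fails, the pre-built .so shipped in this folder probably does not match your Python
# version; recompile sparseoperations.pyx (e.g. "cythonize -a -i sparseoperations.pyx") or switch to
# the contrastive_divergence_updates_Numpy fallback as noted in the comments further down.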
45 | import sparseoperations 46 | import datetime 47 | import scipy.io as sio 48 | import matplotlib.pyplot as plt 49 | 50 | def contrastive_divergence_updates_Numpy(wDecay, lr, DV, DH, MV, MH, rows, cols, out): 51 | for i in range (out.shape[0]): 52 | s1=0 53 | s2=0 54 | for j in range(DV.shape[0]): 55 | s1+=DV[j,rows[i]]*DH[j, cols[i]] 56 | s2+=MV[j,rows[i]]*MH[j, cols[i]] 57 | out[i]+=lr*(s1/DV.shape[0]-s2/DV.shape[0])-wDecay*out[i] 58 | #return out 59 | 60 | def createSparseWeights(epsilon,noRows,noCols): 61 | # generate an Erdos Renyi sparse weights mask 62 | weights=lil_matrix((noRows, noCols)) 63 | for i in range(epsilon * (noRows + noCols)): 64 | weights[np.random.randint(0,noRows),np.random.randint(0,noCols)]=np.float64(np.random.randn()/20) 65 | print ("Create sparse matrix with ",weights.getnnz()," connections and ",(weights.getnnz()/(noRows * noCols))*100,"% density level") 66 | weights=weights.tocsr() 67 | return weights 68 | 69 | class Sigmoid: 70 | @staticmethod 71 | def activation(z): 72 | 73 | return 1 / (1 + np.exp(-z)) 74 | 75 | def activationStochastic(z): 76 | z=Sigmoid.activation(z) 77 | za=z.copy() 78 | prob=np.random.uniform(0,1,(z.shape[0],z.shape[1])) 79 | za[za>prob]=1 80 | za[za<=prob]=0 81 | return za 82 | 83 | 84 | class RBM_FixProb: 85 | def __init__(self, noVisible, noHiddens,epsilon=10): 86 | self.noVisible = noVisible #number of visible neurons 87 | self.noHiddens=noHiddens # number of hidden neurons 88 | self.epsilon = epsilon # control the sparsity level as discussed in the paper 89 | 90 | self.learning_rate = None 91 | self.weight_decay = None 92 | 93 | self.W=createSparseWeights(self.epsilon,self.noVisible,self.noHiddens) # create weights sparse matrix 94 | self.bV=np.zeros(self.noVisible) #biases of the visible neurons 95 | self.bH = np.zeros(self.noHiddens) #biases of the hidden neurons 96 | 97 | def fit(self, X_train, X_test, batch_size,epochs,lengthMarkovChain=2,weight_decay=0.0000002,learning_rate=0.1, testing=True, save_filename=""): 98 | 99 | # set learning parameters 100 | self.lengthMarkovChain=lengthMarkovChain 101 | self.weight_decay=weight_decay 102 | self.learning_rate=learning_rate 103 | 104 | 105 | minimum_reconstructin_error=100000 106 | metrics=np.zeros((epochs,2)) 107 | reconstruction_error_train=0 108 | 109 | for i in range (epochs): 110 | # Shuffle the data 111 | seed = np.arange(X_train.shape[0]) 112 | np.random.shuffle(seed) 113 | x_ = X_train[seed] 114 | 115 | # training 116 | t1 = datetime.datetime.now() 117 | for j in range(x_.shape[0] // batch_size): 118 | k = j * batch_size 119 | l = (j + 1) * batch_size 120 | reconstruction_error_train+=self.learn(x_[k:l]) 121 | t2 = datetime.datetime.now() 122 | 123 | reconstruction_error_train=reconstruction_error_train/(x_.shape[0] // batch_size) 124 | metrics[i, 0] = reconstruction_error_train 125 | print ("\nRBM-FixProb Epoch ",i) 126 | print ("Training time: ",t2-t1,"; Reconstruction error train: ",reconstruction_error_train) 127 | 128 | # test model performance on the test data at each epoch 129 | # this part is useful to understand model performance and can be commented for production settings 130 | if (testing): 131 | t3 = datetime.datetime.now() 132 | reconstruction_error_test=self.reconstruct(X_test) 133 | t4 = datetime.datetime.now() 134 | metrics[i, 1] = reconstruction_error_test 135 | minimum_reconstructin_error = min(minimum_reconstructin_error, reconstruction_error_test) 136 | print("Testing time: ", t4 - t3, "; Reconstruction error test: ", reconstruction_error_test,"; 
Minimum reconstruction error: ", reconstruction_error_test) 137 | 138 | #save performance metrics values in a file 139 | if (save_filename!=""): 140 | np.savetxt(save_filename,metrics) 141 | 142 | def runMarkovChain(self,x): 143 | self.DV=x 144 | self.DH=self.DV@self.W + self.bH 145 | self.DH=Sigmoid.activationStochastic(self.DH) 146 | 147 | for i in range(1,self.lengthMarkovChain): 148 | if (i==1): 149 | self.MV = self.DH @ self.W.transpose() + self.bV 150 | else: 151 | self.MV = self.MH @ self.W.transpose() + self.bV 152 | self.MV = Sigmoid.activation(self.MV) 153 | self.MH=self.MV@self.W + self.bH 154 | self.MH = Sigmoid.activationStochastic(self.MH) 155 | 156 | def reconstruct(self,x): 157 | self.runMarkovChain(x) 158 | return (np.mean((self.DV-self.MV)*(self.DV-self.MV))) 159 | 160 | def learn(self,x): 161 | self.runMarkovChain(x) 162 | self.update() 163 | return (np.mean((self.DV - self.MV) * (self.DV - self.MV))) 164 | 165 | def getRecontructedVisibleNeurons(self,x): 166 | #return recontructions of the visible neurons 167 | self.reconstruct(x) 168 | return self.MV 169 | 170 | def getHiddenNeurons(self,x): 171 | # return hidden neuron values 172 | self.reconstruct(x) 173 | return self.MH 174 | 175 | 176 | def update(self): 177 | #computer Contrastive Divergence updates 178 | self.W=self.W.tocoo() 179 | sparseoperations.contrastive_divergence_updates_Cython(self.weight_decay, self.learning_rate, self.DV, self.DH, self.MV, self.MH, self.W.row, self.W.col, self.W.data) 180 | # If you have problems with Cython please use the backpropagation_updates_Numpy method by uncommenting the line below and commenting the one above. Please note that the running time will be much higher 181 | #contrastive_divergence_updates_Numpy(self.weight_decay, self.learning_rate, self.DV, self.DH, self.MV, self.MH, self.W.row, self.W.col, self.W.data) 182 | 183 | # perform the weights update 184 | # TODO: adding momentum would make learning faster 185 | self.W=self.W.tocsr() 186 | self.bV=self.bV+self.learning_rate*(np.mean(self.DV,axis=0)-np.mean(self.MV,axis=0))-self.weight_decay*self.bV 187 | self.bH = self.bH + self.learning_rate * (np.mean(self.DH, axis=0) - np.mean(self.MH, axis=0)) - self.weight_decay * self.bH 188 | 189 | if __name__ == "__main__": 190 | # Comment this if you would like to use the full power of randomization. I use it to have repeatable results. 191 | np.random.seed(0) 192 | 193 | # load data 194 | mat = sio.loadmat('data/COIL20.mat') #COIL20 dataset was downloaded from http://featureselection.asu.edu/ 195 | X = mat['X'] 196 | Y=mat['Y'] # the labels are, in fact, not used in this demo 197 | 198 | #split data in training and testing 199 | indices=np.arange(X.shape[0]) 200 | np.random.shuffle(indices) 201 | X_train=X[indices[0:int(X.shape[0]*2/3)]] 202 | Y_train=Y[indices[0:int(X.shape[0]*2/3)]] 203 | X_test=X[indices[int(X.shape[0]*2/3):]] 204 | Y_test=Y[indices[int(X.shape[0]*2/3):]] 205 | 206 | #these data are already normalized in the [0,1] interval. 
If you use other data you would have to normalize them 207 | X_train = X_train.astype('float64') 208 | X_test = X_test.astype('float64') 209 | 210 | # create RBM-FixProb 211 | rbm_fixprob=RBM_FixProb(X_train.shape[1],noHiddens=200,epsilon=10) 212 | 213 | # train RBM-FixProb 214 | rbm_fixprob.fit(X_train, X_test, batch_size=10,epochs=1000,lengthMarkovChain = 2, weight_decay = 0.0000002, learning_rate = 0.1, testing = True, save_filename = "Results/rbm_fixprob.txt") 215 | 216 | # get reconstructed data 217 | reconstructions=rbm_fixprob.getRecontructedVisibleNeurons(X_test) 218 | print ("\nReconstruction error of the last epoch on the testing data: ",np.mean((reconstructions-X_test)*(reconstructions-X_test))) 219 | 220 | # get hidden neurons values to be used, for instance, with a classifier 221 | hiddens=rbm_fixprob.getHiddenNeurons(X_test) 222 | -------------------------------------------------------------------------------- /SET-RBM-Sparse-Python-Data-Structures/set_rbm_sparse_data_structures.py: -------------------------------------------------------------------------------- 1 | # Author: Decebal Constantin Mocanu et al.; 2 | # Proof of concept implementation of Sparse Evolutionary Training (SET) of Restricted Boltzmann Machine (RBM) on COIL20 dataset using Python, SciPy sparse data structures, and (optionally) Cython. 3 | # This implementation can be used to create SET-RBM with hundred of thousands of neurons. 4 | 5 | # This is a pre-alpha free software and was tested in Ubuntu 16.04 with Python 3.5.2, Numpy 1.14, SciPy 0.19.1, and (optionally) Cython 0.27.3; 6 | # The code is distributed in the hope that it may be useful, but WITHOUT ANY WARRANTIES; The use of this software is entirely at the user's own risk; 7 | # For an easy understanding of the code functionality please read the following articles. 8 | 9 | # If you use parts of this code please cite the following articles: 10 | #@article{Mocanu2018SET, 11 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 12 | # journal = {Nature Communications}, 13 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 14 | # year = {2018}, 15 | # doi = {10.1038/s41467-018-04316-3} 16 | #} 17 | 18 | #@Article{Mocanu2016XBM, 19 | #author="Mocanu, Decebal Constantin and Mocanu, Elena and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio", 20 | #title="A topological insight into restricted Boltzmann machines", 21 | #journal="Machine Learning", 22 | #year="2016", 23 | #volume="104", 24 | #number="2", 25 | #pages="243--270", 26 | #doi="10.1007/s10994-016-5570-z", 27 | #url="https://doi.org/10.1007/s10994-016-5570-z" 28 | #} 29 | 30 | #@phdthesis{Mocanu2017PhDthesis, 31 | #title = "Network computations in artificial intelligence", 32 | #author = "D.C. Mocanu", 33 | #year = "2017", 34 | #isbn = "978-90-386-4305-2", 35 | #publisher = "Eindhoven University of Technology", 36 | #} 37 | 38 | # We thank to: 39 | # Thomas Hagebols: for performing a thorough analyze on the performance of SciPy sparse matrix operations 40 | 41 | 42 | import numpy as np 43 | from scipy.sparse import lil_matrix 44 | from scipy.sparse import dok_matrix 45 | #the "sparseoperations" Cython library was tested in Ubuntu 16.04. Please note that you may encounter some "solvable" issues if you compile it in Windows. 
46 | import sparseoperations 47 | import datetime 48 | import scipy.io as sio 49 | import matplotlib.pyplot as plt 50 | 51 | def contrastive_divergence_updates_Numpy(wDecay, lr, DV, DH, MV, MH, rows, cols, out): 52 | for i in range (out.shape[0]): 53 | s1=0 54 | s2=0 55 | for j in range(DV.shape[0]): 56 | s1+=DV[j,rows[i]]*DH[j, cols[i]] 57 | s2+=MV[j,rows[i]]*MH[j, cols[i]] 58 | out[i]+=lr*(s1/DV.shape[0]-s2/DV.shape[0])-wDecay*out[i] 59 | #return out 60 | 61 | def find_first_pos(array, value): 62 | idx = (np.abs(array - value)).argmin() 63 | return idx 64 | 65 | 66 | def find_last_pos(array, value): 67 | idx = (np.abs(array - value))[::-1].argmin() 68 | return array.shape[0] - idx 69 | 70 | def createSparseWeights(epsilon,noRows,noCols): 71 | # generate an Erdos Renyi sparse weights mask 72 | weights=lil_matrix((noRows, noCols)) 73 | for i in range(epsilon * (noRows + noCols)): 74 | weights[np.random.randint(0,noRows),np.random.randint(0,noCols)]=np.float64(np.random.randn()/20) 75 | print ("Create sparse matrix with ",weights.getnnz()," connections and ",(weights.getnnz()/(noRows * noCols))*100,"% density level") 76 | weights=weights.tocsr() 77 | return weights 78 | 79 | class Sigmoid: 80 | @staticmethod 81 | def activation(z): 82 | 83 | return 1 / (1 + np.exp(-z)) 84 | 85 | def activationStochastic(z): 86 | z=Sigmoid.activation(z) 87 | za=z.copy() 88 | prob=np.random.uniform(0,1,(z.shape[0],z.shape[1])) 89 | za[za>prob]=1 90 | za[za<=prob]=0 91 | return za 92 | 93 | 94 | class SET_RBM: 95 | def __init__(self, noVisible, noHiddens,epsilon=10): 96 | self.noVisible = noVisible #number of visible neurons 97 | self.noHiddens=noHiddens # number of hidden neurons 98 | self.epsilon = epsilon # control the sparsity level as discussed in the paper 99 | 100 | self.learning_rate = None #learning rate 101 | self.weight_decay = None #weight decay 102 | self.zeta = None # the fraction of the weights removed 103 | 104 | self.W=createSparseWeights(self.epsilon,self.noVisible,self.noHiddens) # create weights sparse matrix 105 | self.bV=np.zeros(self.noVisible) #biases of the visible neurons 106 | self.bH = np.zeros(self.noHiddens) #biases of the hidden neurons 107 | 108 | def fit(self, X_train, X_test, batch_size,epochs,lengthMarkovChain=2,weight_decay=0.0000002,learning_rate=0.1,zeta=0.3, testing=True, save_filename=""): 109 | 110 | # set learning parameters 111 | self.lengthMarkovChain=lengthMarkovChain #length of Markov chain for Contrastive Divergence 112 | self.weight_decay=weight_decay #weight decay 113 | self.learning_rate=learning_rate #learning rate 114 | self.zeta=zeta #control the fraction of weights removed 115 | 116 | 117 | minimum_reconstructin_error=100000 118 | metrics=np.zeros((epochs,2)) 119 | reconstruction_error_train=0 120 | 121 | for i in range (epochs): 122 | # Shuffle the data 123 | seed = np.arange(X_train.shape[0]) 124 | np.random.shuffle(seed) 125 | x_ = X_train[seed] 126 | 127 | # training 128 | t1 = datetime.datetime.now() 129 | for j in range(x_.shape[0] // batch_size): 130 | k = j * batch_size 131 | l = (j + 1) * batch_size 132 | reconstruction_error_train+=self.learn(x_[k:l]) 133 | t2 = datetime.datetime.now() 134 | 135 | reconstruction_error_train=reconstruction_error_train/(x_.shape[0] // batch_size) 136 | metrics[i, 0] = reconstruction_error_train 137 | print ("\nSET-RBM Epoch ",i) 138 | print ("Training time: ",t2-t1,"; Reconstruction error train: ",reconstruction_error_train) 139 | 140 | # test model performance on the test data at each epoch 141 | # this part is 
useful to understand model performance and can be commented for production settings 142 | if (testing): 143 | t3 = datetime.datetime.now() 144 | reconstruction_error_test=self.reconstruct(X_test) 145 | t4 = datetime.datetime.now() 146 | metrics[i, 1] = reconstruction_error_test 147 | minimum_reconstructin_error = min(minimum_reconstructin_error, reconstruction_error_test) 148 | print("Testing time: ", t4 - t3, "; Reconstruction error test: ", reconstruction_error_test,"; Minimum reconstruction error: ", minimum_reconstructin_error) 149 | 150 | # change connectivity pattern 151 | t5 = datetime.datetime.now() 152 | if (i < epochs - 1): 153 | self.weightsEvolution(addition=True) 154 | else: 155 | if (i == epochs - 1): #during the last epoch only connection removal is performed. We did not add new random weights, to favour statistics on the remaining connections 156 | self.weightsEvolution(addition=False) 157 | t6 = datetime.datetime.now() 158 | print("Weights evolution time ", t6 - t5) 159 | 160 | #save performance metrics values in a file 161 | if (save_filename!=""): 162 | np.savetxt(save_filename,metrics) 163 | 164 | def runMarkovChain(self,x): 165 | self.DV=x 166 | self.DH=self.DV@self.W + self.bH 167 | self.DH=Sigmoid.activationStochastic(self.DH) 168 | 169 | for i in range(1,self.lengthMarkovChain): 170 | if (i==1): 171 | self.MV = self.DH @ self.W.transpose() + self.bV 172 | else: 173 | self.MV = self.MH @ self.W.transpose() + self.bV 174 | self.MV = Sigmoid.activation(self.MV) 175 | self.MH=self.MV@self.W + self.bH 176 | self.MH = Sigmoid.activationStochastic(self.MH) 177 | 178 | def reconstruct(self,x): 179 | self.runMarkovChain(x) 180 | return (np.mean((self.DV-self.MV)*(self.DV-self.MV))) 181 | 182 | def learn(self,x): 183 | self.runMarkovChain(x) 184 | self.update() 185 | return (np.mean((self.DV - self.MV) * (self.DV - self.MV))) 186 | 187 | def getRecontructedVisibleNeurons(self,x): 188 | #return reconstructions of the visible neurons 189 | self.reconstruct(x) 190 | return self.MV 191 | 192 | def getHiddenNeurons(self,x): 193 | # return hidden neuron values 194 | self.reconstruct(x) 195 | return self.MH 196 | 197 | 198 | def weightsEvolution(self,addition): 199 | # this represents the core of the SET procedure.
It removes the weights closest to zero in each layer and add new random weights 200 | # TODO: this method could be seriously improved in terms of running time using Cython 201 | values=np.sort(self.W.data) 202 | firstZeroPos = find_first_pos(values, 0) 203 | lastZeroPos = find_last_pos(values, 0) 204 | 205 | largestNegative = values[int((1-self.zeta) * firstZeroPos)] 206 | smallestPositive = values[int(min(values.shape[0] - 1, lastZeroPos + self.zeta * (values.shape[0] - lastZeroPos)))] 207 | 208 | wlil = self.W.tolil() 209 | wdok = dok_matrix((self.noVisible,self.noHiddens),dtype="float64") 210 | 211 | # remove the weights closest to zero 212 | keepConnections=0 213 | for ik, (row, data) in enumerate(zip(wlil.rows, wlil.data)): 214 | for jk, val in zip(row, data): 215 | if (((val < largestNegative) or (val > smallestPositive))): 216 | wdok[ik,jk]=val 217 | keepConnections+=1 218 | 219 | # add new random connections 220 | if (addition): 221 | for kk in range(self.W.data.shape[0]-keepConnections): 222 | ik = np.random.randint(0, self.noVisible) 223 | jk = np.random.randint(0, self.noHiddens) 224 | while ((wdok[ik,jk]!=0)): 225 | ik = np.random.randint(0, self.noVisible) 226 | jk = np.random.randint(0, self.noHiddens) 227 | wdok[ik, jk]=np.random.randn() / 20 228 | 229 | self.W=wdok.tocsr() 230 | 231 | def update(self): 232 | #compute Contrastive Divergence updates 233 | self.W=self.W.tocoo() 234 | sparseoperations.contrastive_divergence_updates_Cython(self.weight_decay, self.learning_rate, self.DV, self.DH, self.MV, self.MH, self.W.row, self.W.col, self.W.data) 235 | # If you have problems with Cython please use the contrastive_divergence_updates_Numpy method by uncommenting the line below and commenting the one above. Please note that the running time will be much higher 236 | #contrastive_divergence_updates_Numpy(self.weight_decay, self.learning_rate, self.DV, self.DH, self.MV, self.MH, self.W.row, self.W.col, self.W.data) 237 | 238 | # perform the weights update 239 | # TODO: adding momentum would make learning faster 240 | self.W=self.W.tocsr() 241 | self.bV=self.bV+self.learning_rate*(np.mean(self.DV,axis=0)-np.mean(self.MV,axis=0))-self.weight_decay*self.bV 242 | self.bH = self.bH + self.learning_rate * (np.mean(self.DH, axis=0) - np.mean(self.MH, axis=0)) - self.weight_decay * self.bH 243 | 244 | if __name__ == "__main__": 245 | # Comment this if you would like to use the full power of randomization. I use it to have repeatable results. 246 | np.random.seed(0) 247 | 248 | # load data 249 | mat = sio.loadmat('data/COIL20.mat') #COIL20 dataset was downloaded from http://featureselection.asu.edu/ 250 | X = mat['X'] 251 | Y=mat['Y'] # the labels are, in fact, not used in this demo 252 | 253 | #split data in training and testing 254 | indices=np.arange(X.shape[0]) 255 | np.random.shuffle(indices) 256 | X_train=X[indices[0:int(X.shape[0]*2/3)]] 257 | Y_train=Y[indices[0:int(X.shape[0]*2/3)]] 258 | X_test=X[indices[int(X.shape[0]*2/3):]] 259 | Y_test=Y[indices[int(X.shape[0]*2/3):]] 260 | 261 | #these data are already normalized in the [0,1] interval. 
If you use other data you would have to normalize them 262 | X_train = X_train.astype('float64') 263 | X_test = X_test.astype('float64') 264 | 265 | # create SET-RBM 266 | setrbm=SET_RBM(X_train.shape[1],noHiddens=200,epsilon=10) 267 | 268 | # train SET-RBM 269 | setrbm.fit(X_train, X_test, batch_size=10,epochs=1000,lengthMarkovChain = 2, weight_decay = 0.0000002, learning_rate = 0.1, zeta = 0.3, testing = True, save_filename = "Results/set_rbm.txt") 270 | 271 | # get reconstructed data 272 | # please note the very small difference in error between this one and the one computed during training. This is the (insignificant) effect of removing the weights which are closest to zero 273 | reconstructions=setrbm.getRecontructedVisibleNeurons(X_test) 274 | print ("\nReconstruction error of the last epoch on the testing data: ",np.mean((reconstructions-X_test)*(reconstructions-X_test))) 275 | 276 | # get hidden neurons values to be used, for instance, with a classifier 277 | hiddens=setrbm.getHiddenNeurons(X_test) 278 | 279 | -------------------------------------------------------------------------------- /SET-RBM-Sparse-Python-Data-Structures/sparseoperations.cpython-35m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/SET-RBM-Sparse-Python-Data-Structures/sparseoperations.cpython-35m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /SET-RBM-Sparse-Python-Data-Structures/sparseoperations.pyx: -------------------------------------------------------------------------------- 1 | # compile this file with: "cythonize -a -i sparseoperations.pyx" 2 | # I have tested this method in Linux (Ubuntu). If you compile it in Windows you may need some workarounds.
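(A hedged aside: as an alternative to the cythonize command mentioned above, the extension can also be built with a minimal setup.py; such a file is not shipped with the repository and the sketch below only assumes that setuptools, Cython, and NumPy are installed. numpy.get_include() is needed because the .pyx file cimports numpy.)

from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy

# declare the sparseoperations extension and let Cython translate the .pyx to C
extensions = [Extension("sparseoperations", ["sparseoperations.pyx"], include_dirs=[numpy.get_include()])]
setup(ext_modules=cythonize(extensions))

Running "python setup.py build_ext --inplace" should then place the compiled module next to the sources, similar to what "cythonize -i" does.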
3 | 4 | cimport numpy as np 5 | 6 | def contrastive_divergence_updates_Cython(double wDecay, double lr, np.ndarray[np.float64_t,ndim=2] DV, np.ndarray[np.float64_t,ndim=2] DH, np.ndarray[np.float64_t,ndim=2] MV, np.ndarray[np.float64_t,ndim=2] MH, np.ndarray[int,ndim=1] rows, np.ndarray[int,ndim=1] cols,np.ndarray[np.float64_t,ndim=1] out): 7 | cdef: 8 | size_t i,j 9 | double s1,s2 10 | for i in range (out.shape[0]): 11 | s1=0 12 | s2=0 13 | for j in range(DV.shape[0]): 14 | s1+=DV[j,rows[i]]*DH[j, cols[i]] 15 | s2+=MV[j,rows[i]]*MH[j, cols[i]] 16 | out[i]+=lr*(s1/DV.shape[0]-s2/DV.shape[0])-wDecay*out[i] 17 | #return out 18 | -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/fashion_mnist_connections_evolution_per_input_pixel_rand0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/fashion_mnist_connections_evolution_per_input_pixel_rand0.gif -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/mnist_learning_curves_samples2000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/mnist_learning_curves_samples2000.pdf -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand0_input_connections.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand0_input_connections.npz -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand1_input_connections.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand1_input_connections.npz -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand2_input_connections.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand2_input_connections.npz -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand3_input_connections.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand3_input_connections.npz -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand4_input_connections.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand4_input_connections.npz -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Results/mnist_learning_curves_samples2000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/Results/mnist_learning_curves_samples2000.pdf -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/fc_mlp.py: -------------------------------------------------------------------------------- 1 | # Authors: Decebal Constantin Mocanu et al.; 2 | # Code associated with ECMLPKDD 2019 tutorial "Scalable Deep Learning: from theory to practice"; https://sites.google.com/view/sdl-ecmlpkdd-2019-tutorial 3 | # This is a pre-alpha free software and was tested in Ubuntu 16.04 with Python 3.5.2, Numpy 1.14, SciPy 0.19.1, and (optionally) Cython 0.27.3; 4 | 5 | # If you use parts of this code please cite the following article: 6 | #@article{Mocanu2018SET, 7 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 8 | # journal = {Nature Communications}, 9 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 10 | # year = {2018}, 11 | # doi = {10.1038/s41467-018-04316-3} 12 | #} 13 | 14 | #If you have space please consider citing also these articles 15 | 16 | #@phdthesis{Mocanu2017PhDthesis, 17 | #title = "Network computations in artificial intelligence", 18 | #author = "D.C. Mocanu", 19 | #year = "2017", 20 | #isbn = "978-90-386-4305-2", 21 | #publisher = "Eindhoven University of Technology", 22 | #} 23 | 24 | #@article{Liu2019onemillion, 25 | # author = {Liu, Shiwei and Mocanu, Decebal Constantin and Mocanu and Ramapuram Matavalam, Amarsagar Reddy and Pei, Yulong Pei and Pechenizkiy, Mykola}, 26 | # journal = {arXiv:1901.09181}, 27 | # title = {Sparse evolutionary Deep Learning with over one million artificial neurons on commodity hardware}, 28 | # year = {2019}, 29 | #} 30 | 31 | # We thank to: 32 | # Thomas Hagebols: for performing a thorough analyze on the performance of SciPy sparse matrix operations 33 | # Ritchie Vink (https://www.ritchievink.com): for making available on Github a nice Python implementation of fully connected MLPs. 
This SET-MLP implementation was built on top of his MLP code: 34 | # https://github.com/ritchie46/vanilla-machine-learning/blob/master/vanilla_mlp.py 35 | 36 | 37 | import numpy as np 38 | import datetime 39 | 40 | 41 | def backpropagation_updates_Numpy(a, delta, rows, cols, out): 42 | for i in range(out.shape[0]): 43 | s = 0 44 | for j in range(a.shape[0]): 45 | s += a[j, rows[i]] * delta[j, cols[i]] 46 | out[i] = s / a.shape[0] 47 | 48 | 49 | def find_first_pos(array, value): 50 | idx = (np.abs(array - value)).argmin() 51 | return idx 52 | 53 | 54 | def find_last_pos(array, value): 55 | idx = (np.abs(array - value))[::-1].argmin() 56 | return array.shape[0] - idx 57 | 58 | 59 | def array_intersect(A, B): 60 | # this are for array intersection 61 | nrows, ncols = A.shape 62 | dtype = {'names': ['f{}'.format(i) for i in range(ncols)], 'formats': ncols * [A.dtype]} 63 | return np.in1d(A.view(dtype), B.view(dtype)) # boolean return 64 | 65 | 66 | class Relu: 67 | @staticmethod 68 | def activation(z): 69 | z[z < 0] = 0 70 | return z 71 | 72 | @staticmethod 73 | def prime(z): 74 | z[z < 0] = 0 75 | z[z > 0] = 1 76 | return z 77 | 78 | 79 | class Sigmoid: 80 | @staticmethod 81 | def activation(z): 82 | return 1 / (1 + np.exp(-z)) 83 | 84 | @staticmethod 85 | def prime(z): 86 | return Sigmoid.activation(z) * (1 - Sigmoid.activation(z)) 87 | 88 | 89 | class MSE: 90 | def __init__(self, activation_fn=None): 91 | """ 92 | 93 | :param activation_fn: Class object of the activation function. 94 | """ 95 | if activation_fn: 96 | self.activation_fn = activation_fn 97 | else: 98 | self.activation_fn = NoActivation 99 | 100 | def activation(self, z): 101 | return self.activation_fn.activation(z) 102 | 103 | @staticmethod 104 | def loss(y_true, y_pred): 105 | """ 106 | :param y_true: (array) One hot encoded truth vector. 107 | :param y_pred: (array) Prediction vector 108 | :return: (flt) 109 | """ 110 | return np.mean((y_pred - y_true) ** 2) 111 | 112 | @staticmethod 113 | def prime(y_true, y_pred): 114 | return y_pred - y_true 115 | 116 | def delta(self, y_true, y_pred): 117 | """ 118 | Back propagation error delta 119 | :return: (array) 120 | """ 121 | return self.prime(y_true, y_pred) * self.activation_fn.prime(y_pred) 122 | 123 | 124 | class NoActivation: 125 | """ 126 | This is a plugin function for no activation. 127 | 128 | f(x) = x * 1 129 | """ 130 | 131 | @staticmethod 132 | def activation(z): 133 | """ 134 | :param z: (array) w(x) + b 135 | :return: z (array) 136 | """ 137 | return z 138 | 139 | @staticmethod 140 | def prime(z): 141 | """ 142 | The prime of z * 1 = 1 143 | :param z: (array) 144 | :return: z': (array) 145 | """ 146 | return np.ones_like(z) 147 | 148 | 149 | class FC_MLP: 150 | def __init__(self, dimensions, activations): 151 | """ 152 | :param dimensions: (tpl/ list) Dimensions of the neural net. (input, hidden layer, output) 153 | :param activations: (tpl/ list) Activations functions. 
154 | 155 | Example of three hidden layer with 156 | - 3312 input features 157 | - 3000 hidden neurons 158 | - 3000 hidden neurons 159 | - 3000 hidden neurons 160 | - 5 output classes 161 | 162 | 163 | layers --> [1, 2, 3, 4, 5] 164 | ---------------------------------------- 165 | 166 | dimensions = (3312, 3000, 3000, 3000, 5) 167 | activations = ( Relu, Relu, Relu, Sigmoid) 168 | """ 169 | self.n_layers = len(dimensions) 170 | self.loss = None 171 | self.learning_rate = None 172 | self.momentum = None 173 | self.weight_decay = None 174 | self.droprate = 0 # dropout rate 175 | self.dimensions = dimensions 176 | 177 | # Weights and biases are initiated by index. For a one hidden layer net you will have a w[1] and w[2] 178 | self.w = {} 179 | self.b = {} 180 | self.pdw = {} 181 | self.pdd = {} 182 | 183 | # Activations are also initiated by index. For the example we will have activations[2] and activations[3] 184 | self.activations = {} 185 | for i in range(len(dimensions) - 1): 186 | self.w[i + 1] = np.random.randn( dimensions[i],dimensions[i + 1])/ 10 187 | self.b[i + 1] = np.zeros(dimensions[i + 1]) 188 | self.activations[i + 2] = activations[i] 189 | 190 | def _feed_forward(self, x, drop=False): 191 | """ 192 | Execute a forward feed through the network. 193 | :param x: (array) Batch of input data vectors. 194 | :return: (tpl) Node outputs and activations per layer. The numbering of the output is equivalent to the layer numbers. 195 | """ 196 | 197 | # w(x) + b 198 | z = {} 199 | 200 | # activations: f(z) 201 | a = {1: x} # First layer has no activations as input. The input x is the input. 202 | 203 | for i in range(1, self.n_layers): 204 | z[i + 1] = a[i] @ self.w[i] + self.b[i] 205 | if (drop == False): 206 | if (i > 1): 207 | z[i + 1] = z[i + 1] * (1 - self.droprate) 208 | a[i + 1] = self.activations[i + 1].activation(z[i + 1]) 209 | if (drop): 210 | if (i < self.n_layers - 1): 211 | dropMask = np.random.rand(a[i + 1].shape[0], a[i + 1].shape[1]) 212 | dropMask[dropMask >= self.droprate] = 1 213 | dropMask[dropMask < self.droprate] = 0 214 | a[i + 1] = dropMask * a[i + 1] 215 | 216 | return z, a 217 | 218 | def _back_prop(self, z, a, y_true): 219 | """ 220 | The input dicts keys represent the layers of the net. 221 | 222 | a = { 1: x, 223 | 2: f(w1(x) + b1) 224 | 3: f(w2(a2) + b2) 225 | 4: f(w3(a3) + b3) 226 | 5: f(w4(a4) + b4) 227 | } 228 | 229 | :param z: (dict) w(x) + b 230 | :param a: (dict) f(z) 231 | :param y_true: (array) One hot encoded truth vector. 232 | :return: 233 | """ 234 | 235 | # Determine partial derivative and delta for the output layer. 236 | # delta output layer 237 | delta = self.loss.delta(y_true, a[self.n_layers]) 238 | dw = np.dot(a[self.n_layers - 1].T, delta) 239 | 240 | 241 | update_params = { 242 | self.n_layers - 1: (dw, delta) 243 | } 244 | 245 | # In case of three layer net will iterate over i = 2 and i = 1 246 | # Determine partial derivative and delta for the rest of the layers. 247 | # Each iteration requires the delta from the previous layer, propagating backwards. 248 | for i in reversed(range(2, self.n_layers)): 249 | delta = (delta @ self.w[i].transpose()) * self.activations[i].prime(z[i]) 250 | dw = np.dot(a[i - 1].T, delta) 251 | 252 | update_params[i - 1] = (dw, delta) 253 | for k, v in update_params.items(): 254 | self._update_w_b(k, v[0], v[1]) 255 | 256 | def _update_w_b(self, index, dw, delta): 257 | """ 258 | Update weights and biases. 
259 | 260 | :param index: (int) Number of the layer 261 | :param dw: (array) Partial derivatives 262 | :param delta: (array) Delta error. 263 | """ 264 | 265 | # perform the update with momentum 266 | if (index not in self.pdw): 267 | self.pdw[index] = -self.learning_rate * dw 268 | self.pdd[index] = -self.learning_rate * np.mean(delta, 0) 269 | else: 270 | self.pdw[index] = self.momentum * self.pdw[index] - self.learning_rate * dw 271 | self.pdd[index] = self.momentum * self.pdd[index] - self.learning_rate * np.mean(delta, 0) 272 | 273 | self.w[index] += self.pdw[index] - self.weight_decay * self.w[index] 274 | self.b[index] += self.pdd[index] - self.weight_decay * self.b[index] 275 | 276 | def fit(self, x, y_true, x_test, y_test, loss, epochs, batch_size, learning_rate=1e-3, momentum=0.9, 277 | weight_decay=0.0002, dropoutrate=0, testing=True, save_filename=""): 278 | """ 279 | :param x: (array) Containing parameters 280 | :param y_true: (array) Containing one hot encoded labels. 281 | :param loss: Loss class (MSE, CrossEntropy etc.) 282 | :param epochs: (int) Number of epochs. 283 | :param batch_size: (int) 284 | :param learning_rate: (flt) 285 | :param momentum: (flt) 286 | :param weight_decay: (flt) 287 | :param zeta: (flt) #control the fraction of weights removed 288 | :param droprate: (flt) 289 | :return (array) A 2D array of metrics (epochs, 3). 290 | """ 291 | if not x.shape[0] == y_true.shape[0]: 292 | raise ValueError("Length of x and y arrays don't match") 293 | # Initiate the loss object with the final activation function 294 | self.loss = loss(self.activations[self.n_layers]) 295 | self.learning_rate = learning_rate 296 | self.momentum = momentum 297 | self.weight_decay = weight_decay 298 | self.droprate = dropoutrate 299 | 300 | maximum_accuracy = 0 301 | 302 | metrics = np.zeros((epochs, 4)) 303 | 304 | for i in range(epochs): 305 | # Shuffle the data 306 | seed = np.arange(x.shape[0]) 307 | np.random.shuffle(seed) 308 | x_ = x[seed] 309 | y_ = y_true[seed] 310 | 311 | # training 312 | t1 = datetime.datetime.now() 313 | 314 | for j in range(x.shape[0] // batch_size): 315 | k = j * batch_size 316 | l = (j + 1) * batch_size 317 | z, a = self._feed_forward(x_[k:l], True) 318 | 319 | 320 | self._back_prop(z, a, y_[k:l]) 321 | 322 | t2 = datetime.datetime.now() 323 | 324 | print("\nFC-MLP Epoch ", i) 325 | print("Training time: ", t2 - t1) 326 | 327 | # test model performance on the test data at each epoch 328 | # this part is useful to understand model performance and can be commented for production settings 329 | if (testing): 330 | t3 = datetime.datetime.now() 331 | accuracy_test, activations_test = self.predict(x_test, y_test, batch_size) 332 | accuracy_train, activations_train = self.predict(x, y_true, batch_size) 333 | t4 = datetime.datetime.now() 334 | maximum_accuracy = max(maximum_accuracy, accuracy_test) 335 | loss_test = self.loss.loss(y_test, activations_test) 336 | loss_train = self.loss.loss(y_true, activations_train) 337 | metrics[i, 0] = loss_train 338 | metrics[i, 1] = loss_test 339 | metrics[i, 2] = accuracy_train 340 | metrics[i, 3] = accuracy_test 341 | print("Testing time: ", t4 - t3,"; Loss train: ", loss_train, "; Loss test: ", loss_test, "; Accuracy train: ", accuracy_train,"; Accuracy test: ", accuracy_test, 342 | "; Maximum accuracy test: ", maximum_accuracy) 343 | 344 | # save performance metrics values in a file 345 | if (save_filename != ""): 346 | np.savetxt(save_filename, metrics) 347 | 348 | return metrics 349 | 350 | 351 | def predict(self, 
x_test, y_test, batch_size=1): 352 | """ 353 | :param x_test: (array) Test input 354 | :param y_test: (array) Correct test output 355 | :param batch_size: 356 | :return: (flt) Classification accuracy 357 | :return: (array) A 2D array of shape (n_cases, n_classes). 358 | """ 359 | activations = np.zeros((y_test.shape[0], y_test.shape[1])) 360 | for j in range(x_test.shape[0] // batch_size): 361 | k = j * batch_size 362 | l = (j + 1) * batch_size 363 | _, a_test = self._feed_forward(x_test[k:l]) 364 | activations[k:l] = a_test[self.n_layers] 365 | correctClassification = 0 366 | for j in range(y_test.shape[0]): 367 | if (np.argmax(activations[j]) == np.argmax(y_test[j])): 368 | correctClassification += 1 369 | accuracy = correctClassification / y_test.shape[0] 370 | return accuracy, activations 371 | 372 | def load_fashion_mnist_data(noTrainingSamples,noTestingSamples): 373 | np.random.seed(0) 374 | 375 | data=np.load("../Tutorial-IJCAI-2019-Scalable-Deep-Learning/data/fashion_mnist.npz") 376 | 377 | indexTrain=np.arange(data["X_train"].shape[0]) 378 | np.random.shuffle(indexTrain) 379 | 380 | indexTest=np.arange(data["X_test"].shape[0]) 381 | np.random.shuffle(indexTest) 382 | 383 | X_train=data["X_train"][indexTrain[0:noTrainingSamples],:] 384 | Y_train=data["Y_train"][indexTrain[0:noTrainingSamples],:] 385 | X_test=data["X_test"][indexTest[0:noTestingSamples],:] 386 | Y_test=data["Y_test"][indexTest[0:noTestingSamples],:] 387 | 388 | #normalize in 0..1 389 | X_train = X_train.astype('float64') / 255. 390 | X_test = X_test.astype('float64') / 255. 391 | 392 | return X_train,Y_train,X_test,Y_test 393 | 394 | if __name__ == "__main__": 395 | 396 | for i in range(1): 397 | #load data 398 | noTrainingSamples=2000 #max 60000 for Fashion MNIST 399 | noTestingSamples = 1000 # max 10000 for Fshion MNIST 400 | X_train, Y_train, X_test, Y_test = load_fashion_mnist_data(noTrainingSamples,noTestingSamples) 401 | 402 | #set model parameters 403 | noHiddenNeuronsLayer=1000 404 | noTrainingEpochs=400 405 | batchSize=40 406 | dropoutRate=0.2 407 | learningRate=0.001 408 | momentum=0.9 409 | weightDecay=0.0002 410 | 411 | np.random.seed(i) 412 | 413 | # create FC-MLP ( fully-connected MLP) 414 | fc_mlp = FC_MLP((X_train.shape[1], noHiddenNeuronsLayer, noHiddenNeuronsLayer,noHiddenNeuronsLayer, Y_train.shape[1]), (Relu, Relu,Relu, Sigmoid)) 415 | 416 | # train FC-MLP 417 | fc_mlp.fit(X_train, Y_train, X_test, Y_test, loss=MSE, epochs=noTrainingEpochs, batch_size=batchSize, learning_rate=learningRate, 418 | momentum=momentum, weight_decay=weightDecay, dropoutrate=dropoutRate, testing=True, 419 | save_filename="Results/fc_mlp_"+str(noTrainingSamples)+"_training_samples_rand"+str(i)+".txt") 420 | 421 | # test FC-MLP 422 | accuracy, _ = fc_mlp.predict(X_test, Y_test, batch_size=1) 423 | 424 | print("\nAccuracy of the last epoch on the testing data: ", accuracy) 425 | -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/fixprob_mlp.py: -------------------------------------------------------------------------------- 1 | # Authors: Decebal Constantin Mocanu et al.; 2 | # Code associated with ECMLPKDD 2019 tutorial "Scalable Deep Learning: from theory to practice"; https://sites.google.com/view/sdl-ecmlpkdd-2019-tutorial 3 | # This is a pre-alpha free software and was tested in Ubuntu 16.04 with Python 3.5.2, Numpy 1.14, SciPy 0.19.1, and (optionally) Cython 0.27.3; 4 | 5 | # If you use parts of this code please cite the following article: 6 
| #@article{Mocanu2018SET, 7 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 8 | # journal = {Nature Communications}, 9 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 10 | # year = {2018}, 11 | # doi = {10.1038/s41467-018-04316-3} 12 | #} 13 | 14 | #If you have space please consider also citing these articles 15 | 16 | #@phdthesis{Mocanu2017PhDthesis, 17 | #title = "Network computations in artificial intelligence", 18 | #author = "D.C. Mocanu", 19 | #year = "2017", 20 | #isbn = "978-90-386-4305-2", 21 | #publisher = "Eindhoven University of Technology", 22 | #} 23 | 24 | #@article{Liu2019onemillion, 25 | # author = {Liu, Shiwei and Mocanu, Decebal Constantin and Ramapuram Matavalam, Amarsagar Reddy and Pei, Yulong and Pechenizkiy, Mykola}, 26 | # journal = {arXiv:1901.09181}, 27 | # title = {Sparse evolutionary Deep Learning with over one million artificial neurons on commodity hardware}, 28 | # year = {2019}, 29 | #} 30 | 31 | # We thank: 32 | # Thomas Hagebols: for performing a thorough analysis of the performance of SciPy sparse matrix operations 33 | # Ritchie Vink (https://www.ritchievink.com): for making available on Github a nice Python implementation of fully connected MLPs. This SET-MLP implementation was built on top of his MLP code: 34 | # https://github.com/ritchie46/vanilla-machine-learning/blob/master/vanilla_mlp.py 35 | 36 | 37 | import numpy as np 38 | from scipy.sparse import lil_matrix 39 | from scipy.sparse import coo_matrix 40 | #the "sparseoperations" Cython library was tested in Ubuntu 16.04. Please note that you may encounter some "solvable" issues if you compile it in Windows.
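(A hedged aside on the sparsity level: createSparseWeights, defined a few lines below, draws epsilon * (noRows + noCols) random positions in the weight matrix, so that product is an upper bound on the number of connections per layer; collisions make the stored count slightly smaller. A small back-of-the-envelope sketch, using the Fashion-MNIST settings from the bottom of this file, 784 inputs, 1000 hidden neurons, epsilon=13; the helper name expected_density is only illustrative.)

def expected_density(epsilon, no_rows, no_cols):
    # upper bound on the number of stored connections and the corresponding density
    n_connections = epsilon * (no_rows + no_cols)
    return n_connections, n_connections / (no_rows * no_cols)

n_conn, density = expected_density(13, 784, 1000)
print(n_conn, density)  # 23192 connections, ~0.0296, i.e. about 3% of a dense 784x1000 layer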
41 | import sparseoperations 42 | import datetime 43 | 44 | 45 | def backpropagation_updates_Numpy(a, delta, rows, cols, out): 46 | for i in range(out.shape[0]): 47 | s = 0 48 | for j in range(a.shape[0]): 49 | s += a[j, rows[i]] * delta[j, cols[i]] 50 | out[i] = s / a.shape[0] 51 | 52 | 53 | def find_first_pos(array, value): 54 | idx = (np.abs(array - value)).argmin() 55 | return idx 56 | 57 | 58 | def find_last_pos(array, value): 59 | idx = (np.abs(array - value))[::-1].argmin() 60 | return array.shape[0] - idx 61 | 62 | 63 | def createSparseWeights(epsilon, noRows, noCols): 64 | # generate an Erdos Renyi sparse weights mask 65 | weights = lil_matrix((noRows, noCols)) 66 | for i in range(epsilon * (noRows + noCols)): 67 | weights[np.random.randint(0, noRows), np.random.randint(0, noCols)] = np.float64(np.random.randn() / 10) 68 | print("Create sparse matrix with ", weights.getnnz(), " connections and ", 69 | (weights.getnnz() / (noRows * noCols)) * 100, "% density level") 70 | weights = weights.tocsr() 71 | return weights 72 | 73 | 74 | def array_intersect(A, B): 75 | # this are for array intersection 76 | nrows, ncols = A.shape 77 | dtype = {'names': ['f{}'.format(i) for i in range(ncols)], 'formats': ncols * [A.dtype]} 78 | return np.in1d(A.view(dtype), B.view(dtype)) # boolean return 79 | 80 | 81 | class Relu: 82 | @staticmethod 83 | def activation(z): 84 | z[z < 0] = 0 85 | return z 86 | 87 | @staticmethod 88 | def prime(z): 89 | z[z < 0] = 0 90 | z[z > 0] = 1 91 | return z 92 | 93 | 94 | class Sigmoid: 95 | @staticmethod 96 | def activation(z): 97 | return 1 / (1 + np.exp(-z)) 98 | 99 | @staticmethod 100 | def prime(z): 101 | return Sigmoid.activation(z) * (1 - Sigmoid.activation(z)) 102 | 103 | 104 | class MSE: 105 | def __init__(self, activation_fn=None): 106 | """ 107 | 108 | :param activation_fn: Class object of the activation function. 109 | """ 110 | if activation_fn: 111 | self.activation_fn = activation_fn 112 | else: 113 | self.activation_fn = NoActivation 114 | 115 | def activation(self, z): 116 | return self.activation_fn.activation(z) 117 | 118 | @staticmethod 119 | def loss(y_true, y_pred): 120 | """ 121 | :param y_true: (array) One hot encoded truth vector. 122 | :param y_pred: (array) Prediction vector 123 | :return: (flt) 124 | """ 125 | return np.mean((y_pred - y_true) ** 2) 126 | 127 | @staticmethod 128 | def prime(y_true, y_pred): 129 | return y_pred - y_true 130 | 131 | def delta(self, y_true, y_pred): 132 | """ 133 | Back propagation error delta 134 | :return: (array) 135 | """ 136 | return self.prime(y_true, y_pred) * self.activation_fn.prime(y_pred) 137 | 138 | 139 | class NoActivation: 140 | """ 141 | This is a plugin function for no activation. 142 | 143 | f(x) = x * 1 144 | """ 145 | 146 | @staticmethod 147 | def activation(z): 148 | """ 149 | :param z: (array) w(x) + b 150 | :return: z (array) 151 | """ 152 | return z 153 | 154 | @staticmethod 155 | def prime(z): 156 | """ 157 | The prime of z * 1 = 1 158 | :param z: (array) 159 | :return: z': (array) 160 | """ 161 | return np.ones_like(z) 162 | 163 | 164 | class FixProb_MLP: 165 | def __init__(self, dimensions, activations, epsilon=20): 166 | """ 167 | :param dimensions: (tpl/ list) Dimensions of the neural net. (input, hidden layer, output) 168 | :param activations: (tpl/ list) Activations functions. 
169 | 170 | Example of three hidden layer with 171 | - 3312 input features 172 | - 3000 hidden neurons 173 | - 3000 hidden neurons 174 | - 3000 hidden neurons 175 | - 5 output classes 176 | 177 | 178 | layers --> [1, 2, 3, 4, 5] 179 | ---------------------------------------- 180 | 181 | dimensions = (3312, 3000, 3000, 3000, 5) 182 | activations = ( Relu, Relu, Relu, Sigmoid) 183 | """ 184 | self.n_layers = len(dimensions) 185 | self.loss = None 186 | self.learning_rate = None 187 | self.momentum = None 188 | self.weight_decay = None 189 | self.epsilon = epsilon # control the sparsity level as discussed in the paper 190 | self.zeta = None # the fraction of the weights removed 191 | self.droprate = 0 # dropout rate 192 | self.dimensions = dimensions 193 | 194 | # Weights and biases are initiated by index. For a one hidden layer net you will have a w[1] and w[2] 195 | self.w = {} 196 | self.b = {} 197 | self.pdw = {} 198 | self.pdd = {} 199 | 200 | # Activations are also initiated by index. For the example we will have activations[2] and activations[3] 201 | self.activations = {} 202 | for i in range(len(dimensions) - 1): 203 | self.w[i + 1] = createSparseWeights(self.epsilon, dimensions[i], 204 | dimensions[i + 1]) # create sparse weight matrices 205 | self.b[i + 1] = np.zeros(dimensions[i + 1]) 206 | self.activations[i + 2] = activations[i] 207 | 208 | def _feed_forward(self, x, drop=False): 209 | """ 210 | Execute a forward feed through the network. 211 | :param x: (array) Batch of input data vectors. 212 | :return: (tpl) Node outputs and activations per layer. The numbering of the output is equivalent to the layer numbers. 213 | """ 214 | 215 | # w(x) + b 216 | z = {} 217 | 218 | # activations: f(z) 219 | a = {1: x} # First layer has no activations as input. The input x is the input. 220 | 221 | for i in range(1, self.n_layers): 222 | z[i + 1] = a[i] @ self.w[i] + self.b[i] 223 | if (drop == False): 224 | if (i > 1): 225 | z[i + 1] = z[i + 1] * (1 - self.droprate) 226 | a[i + 1] = self.activations[i + 1].activation(z[i + 1]) 227 | if (drop): 228 | if (i < self.n_layers - 1): 229 | dropMask = np.random.rand(a[i + 1].shape[0], a[i + 1].shape[1]) 230 | dropMask[dropMask >= self.droprate] = 1 231 | dropMask[dropMask < self.droprate] = 0 232 | a[i + 1] = dropMask * a[i + 1] 233 | 234 | return z, a 235 | 236 | def _back_prop(self, z, a, y_true): 237 | """ 238 | The input dicts keys represent the layers of the net. 239 | 240 | a = { 1: x, 241 | 2: f(w1(x) + b1) 242 | 3: f(w2(a2) + b2) 243 | 4: f(w3(a3) + b3) 244 | 5: f(w4(a4) + b4) 245 | } 246 | 247 | :param z: (dict) w(x) + b 248 | :param a: (dict) f(z) 249 | :param y_true: (array) One hot encoded truth vector. 250 | :return: 251 | """ 252 | 253 | # Determine partial derivative and delta for the output layer. 254 | # delta output layer 255 | delta = self.loss.delta(y_true, a[self.n_layers]) 256 | dw = coo_matrix(self.w[self.n_layers - 1]) 257 | 258 | # compute backpropagation updates 259 | sparseoperations.backpropagation_updates_Cython(a[self.n_layers - 1], delta, dw.row, dw.col, dw.data) 260 | # If you have problems with Cython please use the backpropagation_updates_Numpy method by uncommenting the line below and commenting the one above. 
Please note that the running time will be much higher 261 | # backpropagation_updates_Numpy(a[self.n_layers - 1], delta, dw.row, dw.col, dw.data) 262 | 263 | update_params = { 264 | self.n_layers - 1: (dw.tocsr(), delta) 265 | } 266 | 267 | # In case of three layer net will iterate over i = 2 and i = 1 268 | # Determine partial derivative and delta for the rest of the layers. 269 | # Each iteration requires the delta from the previous layer, propagating backwards. 270 | for i in reversed(range(2, self.n_layers)): 271 | delta = (delta @ self.w[i].transpose()) * self.activations[i].prime(z[i]) 272 | dw = coo_matrix(self.w[i - 1]) 273 | 274 | # compute backpropagation updates 275 | sparseoperations.backpropagation_updates_Cython(a[i - 1], delta, dw.row, dw.col, dw.data) 276 | # If you have problems with Cython please use the backpropagation_updates_Numpy method by uncommenting the line below and commenting the one above. Please note that the running time will be much higher 277 | # backpropagation_updates_Numpy(a[i - 1], delta, dw.row, dw.col, dw.data) 278 | 279 | update_params[i - 1] = (dw.tocsr(), delta) 280 | for k, v in update_params.items(): 281 | self._update_w_b(k, v[0], v[1]) 282 | 283 | def _update_w_b(self, index, dw, delta): 284 | """ 285 | Update weights and biases. 286 | 287 | :param index: (int) Number of the layer 288 | :param dw: (array) Partial derivatives 289 | :param delta: (array) Delta error. 290 | """ 291 | 292 | # perform the update with momentum 293 | if (index not in self.pdw): 294 | self.pdw[index] = -self.learning_rate * dw 295 | self.pdd[index] = - self.learning_rate * np.mean(delta, 0) 296 | else: 297 | self.pdw[index] = self.momentum * self.pdw[index] - self.learning_rate * dw 298 | self.pdd[index] = self.momentum * self.pdd[index] - self.learning_rate * np.mean(delta, 0) 299 | 300 | self.w[index] += self.pdw[index] - self.weight_decay * self.w[index] 301 | self.b[index] += self.pdd[index] - self.weight_decay * self.b[index] 302 | 303 | def fit(self, x, y_true, x_test, y_test, loss, epochs, batch_size, learning_rate=1e-3, momentum=0.9, 304 | weight_decay=0.0002, dropoutrate=0, testing=True, save_filename=""): 305 | """ 306 | :param x: (array) Containing parameters 307 | :param y_true: (array) Containing one hot encoded labels. 308 | :param loss: Loss class (MSE, CrossEntropy etc.) 309 | :param epochs: (int) Number of epochs. 310 | :param batch_size: (int) 311 | :param learning_rate: (flt) 312 | :param momentum: (flt) 313 | :param weight_decay: (flt) 314 | :param zeta: (flt) #control the fraction of weights removed 315 | :param droprate: (flt) 316 | :return (array) A 2D array of metrics (epochs, 3). 
317 | """ 318 | if not x.shape[0] == y_true.shape[0]: 319 | raise ValueError("Length of x and y arrays don't match") 320 | # Initiate the loss object with the final activation function 321 | self.loss = loss(self.activations[self.n_layers]) 322 | self.learning_rate = learning_rate 323 | self.momentum = momentum 324 | self.weight_decay = weight_decay 325 | self.droprate = dropoutrate 326 | 327 | maximum_accuracy = 0 328 | 329 | metrics = np.zeros((epochs, 4)) 330 | 331 | for i in range(epochs): 332 | # Shuffle the data 333 | seed = np.arange(x.shape[0]) 334 | np.random.shuffle(seed) 335 | x_ = x[seed] 336 | y_ = y_true[seed] 337 | 338 | # training 339 | t1 = datetime.datetime.now() 340 | 341 | for j in range(x.shape[0] // batch_size): 342 | k = j * batch_size 343 | l = (j + 1) * batch_size 344 | z, a = self._feed_forward(x_[k:l], True) 345 | 346 | 347 | self._back_prop(z, a, y_[k:l]) 348 | 349 | t2 = datetime.datetime.now() 350 | 351 | print("\nFixProb-MLP Epoch ", i) 352 | print("Training time: ", t2 - t1) 353 | 354 | # test model performance on the test data at each epoch 355 | # this part is useful to understand model performance and can be commented for production settings 356 | if (testing): 357 | t3 = datetime.datetime.now() 358 | accuracy_test, activations_test = self.predict(x_test, y_test, batch_size) 359 | accuracy_train, activations_train = self.predict(x, y_true, batch_size) 360 | t4 = datetime.datetime.now() 361 | maximum_accuracy = max(maximum_accuracy, accuracy_test) 362 | loss_test = self.loss.loss(y_test, activations_test) 363 | loss_train = self.loss.loss(y_true, activations_train) 364 | metrics[i, 0] = loss_train 365 | metrics[i, 1] = loss_test 366 | metrics[i, 2] = accuracy_train 367 | metrics[i, 3] = accuracy_test 368 | print("Testing time: ", t4 - t3,"; Loss train: ", loss_train, "; Loss test: ", loss_test, "; Accuracy train: ", accuracy_train,"; Accuracy test: ", accuracy_test, 369 | "; Maximum accuracy test: ", maximum_accuracy) 370 | 371 | # save performance metrics values in a file 372 | if (save_filename != ""): 373 | np.savetxt(save_filename, metrics) 374 | 375 | return metrics 376 | 377 | 378 | def predict(self, x_test, y_test, batch_size=1): 379 | """ 380 | :param x_test: (array) Test input 381 | :param y_test: (array) Correct test output 382 | :param batch_size: 383 | :return: (flt) Classification accuracy 384 | :return: (array) A 2D array of shape (n_cases, n_classes). 
385 | """ 386 | activations = np.zeros((y_test.shape[0], y_test.shape[1])) 387 | for j in range(x_test.shape[0] // batch_size): 388 | k = j * batch_size 389 | l = (j + 1) * batch_size 390 | _, a_test = self._feed_forward(x_test[k:l]) 391 | activations[k:l] = a_test[self.n_layers] 392 | correctClassification = 0 393 | for j in range(y_test.shape[0]): 394 | if (np.argmax(activations[j]) == np.argmax(y_test[j])): 395 | correctClassification += 1 396 | accuracy = correctClassification / y_test.shape[0] 397 | return accuracy, activations 398 | 399 | def load_fashion_mnist_data(noTrainingSamples,noTestingSamples): 400 | np.random.seed(0) 401 | 402 | data=np.load("../Tutorial-IJCAI-2019-Scalable-Deep-Learning/data/fashion_mnist.npz") 403 | 404 | indexTrain=np.arange(data["X_train"].shape[0]) 405 | np.random.shuffle(indexTrain) 406 | 407 | indexTest=np.arange(data["X_test"].shape[0]) 408 | np.random.shuffle(indexTest) 409 | 410 | X_train=data["X_train"][indexTrain[0:noTrainingSamples],:] 411 | Y_train=data["Y_train"][indexTrain[0:noTrainingSamples],:] 412 | X_test=data["X_test"][indexTest[0:noTestingSamples],:] 413 | Y_test=data["Y_test"][indexTest[0:noTestingSamples],:] 414 | 415 | #normalize in 0..1 416 | X_train = X_train.astype('float64') / 255. 417 | X_test = X_test.astype('float64') / 255. 418 | 419 | return X_train,Y_train,X_test,Y_test 420 | 421 | if __name__ == "__main__": 422 | 423 | for i in range(1): 424 | #load data 425 | noTrainingSamples=2000 #max 60000 for Fashion MNIST 426 | noTestingSamples = 1000 # max 10000 for Fashion MNIST 427 | X_train, Y_train, X_test, Y_test = load_fashion_mnist_data(noTrainingSamples,noTestingSamples) 428 | 429 | #set model parameters 430 | noHiddenNeuronsLayer=1000 431 | epsilon=13 #set the sparsity level 432 | noTrainingEpochs=400 433 | batchSize=40 434 | dropoutRate=0.2 435 | learningRate=0.05 436 | momentum=0.9 437 | weightDecay=0.0002 438 | 439 | np.random.seed(i) 440 | 441 | # create FixProb-MLP (MLP with static sparse connectivity) 442 | fixprob_mlp = FixProb_MLP((X_train.shape[1], noHiddenNeuronsLayer, noHiddenNeuronsLayer,noHiddenNeuronsLayer, Y_train.shape[1]), (Relu, Relu,Relu, Sigmoid), epsilon=epsilon) 443 | 444 | # train FixProb-MLP 445 | fixprob_mlp.fit(X_train, Y_train, X_test, Y_test, loss=MSE, epochs=noTrainingEpochs, batch_size=batchSize, learning_rate=learningRate, 446 | momentum=momentum, weight_decay=weightDecay, dropoutrate=dropoutRate, testing=True, 447 | save_filename="Results/fixprob_mlp_"+str(noTrainingSamples)+"_training_samples_e"+str(epsilon)+"_rand"+str(i)+".txt") 448 | 449 | # test FixProb-MLP 450 | accuracy, _ = fixprob_mlp.predict(X_test, Y_test, batch_size=1) 451 | 452 | print("\nAccuracy of the last epoch on the testing data: ", accuracy) 453 | -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/plot_input_layer_connectivity.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.animation 3 | import matplotlib.colors as mcolors 4 | import numpy as np 5 | from mpl_toolkits.axes_grid1 import make_axes_locatable 6 | 7 | # update the data for each frame 8 | def anim(n): 9 | global data 10 | global allConnections 11 | data = allConnections[:,:,n] 12 | imobj.set_array(data) 13 | return imobj, 14 | 15 | 16 | 17 | for i in range(1): 18 | data = np.load("../Tutorial-IJCAI-2019-Scalable-Deep-Learning/data/fashion_mnist.npz") 19 | 
connections=np.load("Pretrained_results/set_mlp_2000_training_samples_e13_rand"+str(i)+"_input_connections.npz")["inputLayerConnections"] 20 | 21 | allConnections=np.zeros((28,28,len(connections))) 22 | for j in range(len(connections)): 23 | connectionsEpoch=np.reshape(connections[j],(28,28)) 24 | allConnections[:,:,j]=connectionsEpoch 25 | 26 | fig = plt.figure() 27 | fig.suptitle('ECMLPKDD 2019 tutorials\nScalable Deep Learning: from theory to practice', fontsize=14) 28 | 29 | ax1 = fig.add_subplot(121) 30 | ax1.imshow(np.reshape(data["X_train"][1,:],(28,28)),vmin=0,vmax=255,cmap="gray_r",interpolation=None) 31 | ax1.set_title("Fashion-MNIST example") 32 | 33 | ax2 = fig.add_subplot(122) 34 | data=allConnections[:,:,0] 35 | imobj = ax2.imshow(data,vmin=0,vmax=np.max(allConnections),cmap="jet",interpolation=None) 36 | ax2.set_title("Input connectivity pattern evolution\nwith SET-MLP") 37 | 38 | divider = make_axes_locatable(ax2) 39 | cax = divider.append_axes("right", size="5%", pad=0.05) 40 | 41 | cbar=fig.colorbar(imobj,cax=cax) 42 | cbar.set_label('Connections per input neuron (pixel)',size=8) 43 | 44 | fig.tight_layout() 45 | 46 | # create the animation 47 | ani = matplotlib.animation.FuncAnimation(fig, anim, frames=len(connections)) 48 | ani.save("Pretrained_results/fashion_mnist_connections_evolution_per_input_pixel_rand"+str(i)+".gif", writer='imagemagick',fps=24,codec=None) -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/plot_learning_curve.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | epsilon=13 7 | samples=2000 8 | 9 | set_mlp=np.loadtxt("Pretrained_results/set_mlp_"+str(samples)+"_training_samples_e"+str(epsilon)+"_rand0.txt") 10 | fixprob_mlp=np.loadtxt("Pretrained_results/fixprob_mlp_"+str(samples)+"_training_samples_e"+str(epsilon)+"_rand0.txt") 11 | fc_mlp=np.loadtxt("Pretrained_results/fc_mlp_"+str(samples)+"_training_samples_rand0.txt") 12 | 13 | """ 14 | for i in range(1,5): 15 | set_mlp = set_mlp + np.loadtxt( 16 | "Results/set_mlp_" + str(samples) + "_training_samples_e" + str(epsilon) + "_rand" + str(i) + ".txt") 17 | fixprob_mlp = fixprob_mlp + np.loadtxt( 18 | "Results/fixprob_mlp_" + str(samples) + "_training_samples_e" + str(epsilon) + "_rand" + str(i) + "") 19 | fc_mlp = fc_mlp + np.loadtxt("Pretrained_results/fc_mlp_" + str(samples) + "_training_samples_rand" + str(i) + ".txt") 20 | 21 | set_mlp/=5 22 | fixprob_mlp/=5 23 | fc_mlp/=5 24 | """ 25 | font = { 'size' : 9} 26 | fig = plt.figure(figsize=(10,5)) 27 | matplotlib.rc('font', **font) 28 | fig.subplots_adjust(wspace=0.2,hspace=0.05) 29 | 30 | ax1=fig.add_subplot(1,2,1) 31 | ax1.plot(set_mlp[:,2]*100, label="SET-MLP train accuracy", color="r") 32 | ax1.plot(set_mlp[:,3]*100, label="SET-MLP test accuracy", color="b") 33 | ax1.plot(fixprob_mlp[:,2]*100, label="MLP$_{FixProb}$ train accuracy", color="g") 34 | ax1.plot(fixprob_mlp[:,3]*100, label="MLP$_{FixProb}$ test accuracy", color="m") 35 | ax1.plot(fc_mlp[:,2]*100, label="FC-MLP train accuracy", color="y") 36 | ax1.plot(fc_mlp[:,3]*100, label="FC-MLP test accuracy", color="k") 37 | ax1.grid(True) 38 | ax1.set_ylabel("Fashion MNIST\nAccuracy [%]") 39 | ax1.set_xlabel("Epochs [#]") 40 | ax1.legend(loc=4,fontsize=8) 41 | 42 | ax2=fig.add_subplot(1,2,2) 43 | ax2.plot(set_mlp[:,0], label="SET-MLP train loss", color="r") 44 | 
ax2.plot(set_mlp[:,1], label="SET-MLP test loss", color="b") 45 | ax2.plot(fixprob_mlp[:,0], label="MLP$_{FixProb}$ train loss", color="g") 46 | ax2.plot(fixprob_mlp[:,1], label="MLP$_{FixProb}$ test loss", color="m") 47 | ax2.plot(fc_mlp[:,0], label="FC-MLP train loss", color="y") 48 | ax2.plot(fc_mlp[:,1], label="FC-MLP test loss", color="k") 49 | ax2.grid(True) 50 | ax2.set_ylabel("Loss (MSE)") 51 | ax2.set_xlabel("Epochs [#]") 52 | ax2.legend(loc=1,fontsize=8) 53 | 54 | 55 | plt.savefig("Pretrained_results/mnist_learning_curves_samples"+str(samples)+".pdf", bbox_inches='tight') 56 | 57 | plt.close() -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/sparseoperations.pyx: -------------------------------------------------------------------------------- 1 | # compile this file with: "cythonize -a -i sparseoperations.pyx" 2 | # I have tested this method in Linux (Ubuntu). If you compile it in Windows you may need some work around. 3 | 4 | cimport numpy as np 5 | 6 | def backpropagation_updates_Cython(np.ndarray[np.float64_t,ndim=2] a, np.ndarray[np.float64_t,ndim=2] delta, np.ndarray[int,ndim=1] rows, np.ndarray[int,ndim=1] cols,np.ndarray[np.float64_t,ndim=1] out): 7 | cdef: 8 | size_t i,j 9 | double s 10 | for i in range (out.shape[0]): 11 | s=0 12 | for j in range(a.shape[0]): 13 | s+=a[j,rows[i]]*delta[j, cols[i]] 14 | out[i]=s/a.shape[0] 15 | #return out 16 | 17 | -------------------------------------------------------------------------------- /Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/sparseoperations.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-ECMLPKDD-2019-Scalable-Deep-Learning/sparseoperations.so -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/fashion_mnist_connections_evolution_per_input_pixel_rand0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/fashion_mnist_connections_evolution_per_input_pixel_rand0.gif -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/mnist_learning_curves_samples2000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/mnist_learning_curves_samples2000.pdf -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand0_input_connections.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand0_input_connections.npz 
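(A hedged aside on the Cython kernel above: backpropagation_updates_Cython computes, for every stored connection (rows[i], cols[i]), the batch-averaged product of the corresponding activation and delta columns. The vectorized NumPy sketch below computes the same quantity and can be used to sanity-check the compiled extension on small batches; it materializes two (batch_size x nnz) temporaries, so it is a verification aid rather than a drop-in replacement. The function name is only illustrative.)

import numpy as np

def backpropagation_updates_numpy_vectorized(a, delta, rows, cols, out):
    # a: (batch, n_in) activations, delta: (batch, n_out) backpropagated errors,
    # rows/cols: indices of the stored sparse connections, out: (nnz,) gradient buffer
    out[:] = np.mean(a[:, rows] * delta[:, cols], axis=0)

On a small random batch, np.allclose between this result and the output of sparseoperations.backpropagation_updates_Cython should hold.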
-------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand1_input_connections.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand1_input_connections.npz -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand2_input_connections.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand2_input_connections.npz -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand3_input_connections.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand3_input_connections.npz -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand4_input_connections.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/Pretrained_results/set_mlp_2000_training_samples_e13_rand4_input_connections.npz -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/Results/mnist_learning_curves_samples2000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/Results/mnist_learning_curves_samples2000.pdf -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/data/fashion_mnist.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/data/fashion_mnist.npz -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/fc_mlp.py: -------------------------------------------------------------------------------- 1 | # Authors: Decebal Constantin Mocanu et al.; 2 | # Code associated with IJCAI 2019 tutorial "Scalable Deep Learning: from theory to practice"; https://sites.google.com/view/scalable-deep-learning-ijcai19 3 | # This is a pre-alpha free software and was tested in 
Ubuntu 16.04 with Python 3.5.2, Numpy 1.14, SciPy 0.19.1, and (optionally) Cython 0.27.3; 4 | 5 | # If you use parts of this code please cite the following article: 6 | #@article{Mocanu2018SET, 7 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. and Gibescu, Madeleine and Liotta, Antonio}, 8 | # journal = {Nature Communications}, 9 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 10 | # year = {2018}, 11 | # doi = {10.1038/s41467-018-04316-3} 12 | #} 13 | 14 | #If you have space please consider citing also these articles 15 | 16 | #@phdthesis{Mocanu2017PhDthesis, 17 | #title = "Network computations in artificial intelligence", 18 | #author = "D.C. Mocanu", 19 | #year = "2017", 20 | #isbn = "978-90-386-4305-2", 21 | #publisher = "Eindhoven University of Technology", 22 | #} 23 | 24 | #@article{Liu2019onemillion, 25 | # author = {Liu, Shiwei and Mocanu, Decebal Constantin and Mocanu and Ramapuram Matavalam, Amarsagar Reddy and Pei, Yulong Pei and Pechenizkiy, Mykola}, 26 | # journal = {arXiv:1901.09181}, 27 | # title = {Sparse evolutionary Deep Learning with over one million artificial neurons on commodity hardware}, 28 | # year = {2019}, 29 | #} 30 | 31 | # We thank to: 32 | # Thomas Hagebols: for performing a thorough analyze on the performance of SciPy sparse matrix operations 33 | # Ritchie Vink (https://www.ritchievink.com): for making available on Github a nice Python implementation of fully connected MLPs. This SET-MLP implementation was built on top of his MLP code: 34 | # https://github.com/ritchie46/vanilla-machine-learning/blob/master/vanilla_mlp.py 35 | 36 | 37 | import numpy as np 38 | import datetime 39 | 40 | 41 | def backpropagation_updates_Numpy(a, delta, rows, cols, out): 42 | for i in range(out.shape[0]): 43 | s = 0 44 | for j in range(a.shape[0]): 45 | s += a[j, rows[i]] * delta[j, cols[i]] 46 | out[i] = s / a.shape[0] 47 | 48 | 49 | def find_first_pos(array, value): 50 | idx = (np.abs(array - value)).argmin() 51 | return idx 52 | 53 | 54 | def find_last_pos(array, value): 55 | idx = (np.abs(array - value))[::-1].argmin() 56 | return array.shape[0] - idx 57 | 58 | 59 | def array_intersect(A, B): 60 | # this are for array intersection 61 | nrows, ncols = A.shape 62 | dtype = {'names': ['f{}'.format(i) for i in range(ncols)], 'formats': ncols * [A.dtype]} 63 | return np.in1d(A.view(dtype), B.view(dtype)) # boolean return 64 | 65 | 66 | class Relu: 67 | @staticmethod 68 | def activation(z): 69 | z[z < 0] = 0 70 | return z 71 | 72 | @staticmethod 73 | def prime(z): 74 | z[z < 0] = 0 75 | z[z > 0] = 1 76 | return z 77 | 78 | 79 | class Sigmoid: 80 | @staticmethod 81 | def activation(z): 82 | return 1 / (1 + np.exp(-z)) 83 | 84 | @staticmethod 85 | def prime(z): 86 | return Sigmoid.activation(z) * (1 - Sigmoid.activation(z)) 87 | 88 | 89 | class MSE: 90 | def __init__(self, activation_fn=None): 91 | """ 92 | 93 | :param activation_fn: Class object of the activation function. 94 | """ 95 | if activation_fn: 96 | self.activation_fn = activation_fn 97 | else: 98 | self.activation_fn = NoActivation 99 | 100 | def activation(self, z): 101 | return self.activation_fn.activation(z) 102 | 103 | @staticmethod 104 | def loss(y_true, y_pred): 105 | """ 106 | :param y_true: (array) One hot encoded truth vector. 
107 | :param y_pred: (array) Prediction vector 108 | :return: (flt) 109 | """ 110 | return np.mean((y_pred - y_true) ** 2) 111 | 112 | @staticmethod 113 | def prime(y_true, y_pred): 114 | return y_pred - y_true 115 | 116 | def delta(self, y_true, y_pred): 117 | """ 118 | Back propagation error delta 119 | :return: (array) 120 | """ 121 | return self.prime(y_true, y_pred) * self.activation_fn.prime(y_pred) 122 | 123 | 124 | class NoActivation: 125 | """ 126 | This is a plugin function for no activation. 127 | 128 | f(x) = x * 1 129 | """ 130 | 131 | @staticmethod 132 | def activation(z): 133 | """ 134 | :param z: (array) w(x) + b 135 | :return: z (array) 136 | """ 137 | return z 138 | 139 | @staticmethod 140 | def prime(z): 141 | """ 142 | The prime of z * 1 = 1 143 | :param z: (array) 144 | :return: z': (array) 145 | """ 146 | return np.ones_like(z) 147 | 148 | 149 | class FC_MLP: 150 | def __init__(self, dimensions, activations): 151 | """ 152 | :param dimensions: (tpl/ list) Dimensions of the neural net. (input, hidden layer, output) 153 | :param activations: (tpl/ list) Activations functions. 154 | 155 | Example of three hidden layer with 156 | - 3312 input features 157 | - 3000 hidden neurons 158 | - 3000 hidden neurons 159 | - 3000 hidden neurons 160 | - 5 output classes 161 | 162 | 163 | layers --> [1, 2, 3, 4, 5] 164 | ---------------------------------------- 165 | 166 | dimensions = (3312, 3000, 3000, 3000, 5) 167 | activations = ( Relu, Relu, Relu, Sigmoid) 168 | """ 169 | self.n_layers = len(dimensions) 170 | self.loss = None 171 | self.learning_rate = None 172 | self.momentum = None 173 | self.weight_decay = None 174 | self.droprate = 0 # dropout rate 175 | self.dimensions = dimensions 176 | 177 | # Weights and biases are initiated by index. For a one hidden layer net you will have a w[1] and w[2] 178 | self.w = {} 179 | self.b = {} 180 | self.pdw = {} 181 | self.pdd = {} 182 | 183 | # Activations are also initiated by index. For the example we will have activations[2] and activations[3] 184 | self.activations = {} 185 | for i in range(len(dimensions) - 1): 186 | self.w[i + 1] = np.random.randn( dimensions[i],dimensions[i + 1])/ 10 187 | self.b[i + 1] = np.zeros(dimensions[i + 1]) 188 | self.activations[i + 2] = activations[i] 189 | 190 | def _feed_forward(self, x, drop=False): 191 | """ 192 | Execute a forward feed through the network. 193 | :param x: (array) Batch of input data vectors. 194 | :return: (tpl) Node outputs and activations per layer. The numbering of the output is equivalent to the layer numbers. 195 | """ 196 | 197 | # w(x) + b 198 | z = {} 199 | 200 | # activations: f(z) 201 | a = {1: x} # First layer has no activations as input. The input x is the input. 202 | 203 | for i in range(1, self.n_layers): 204 | z[i + 1] = a[i] @ self.w[i] + self.b[i] 205 | if (drop == False): 206 | if (i > 1): 207 | z[i + 1] = z[i + 1] * (1 - self.droprate) 208 | a[i + 1] = self.activations[i + 1].activation(z[i + 1]) 209 | if (drop): 210 | if (i < self.n_layers - 1): 211 | dropMask = np.random.rand(a[i + 1].shape[0], a[i + 1].shape[1]) 212 | dropMask[dropMask >= self.droprate] = 1 213 | dropMask[dropMask < self.droprate] = 0 214 | a[i + 1] = dropMask * a[i + 1] 215 | 216 | return z, a 217 | 218 | def _back_prop(self, z, a, y_true): 219 | """ 220 | The input dicts keys represent the layers of the net. 
221 | 222 | a = { 1: x, 223 | 2: f(w1(x) + b1) 224 | 3: f(w2(a2) + b2) 225 | 4: f(w3(a3) + b3) 226 | 5: f(w4(a4) + b4) 227 | } 228 | 229 | :param z: (dict) w(x) + b 230 | :param a: (dict) f(z) 231 | :param y_true: (array) One hot encoded truth vector. 232 | :return: 233 | """ 234 | 235 | # Determine partial derivative and delta for the output layer. 236 | # delta output layer 237 | delta = self.loss.delta(y_true, a[self.n_layers]) 238 | dw = np.dot(a[self.n_layers - 1].T, delta) 239 | 240 | 241 | update_params = { 242 | self.n_layers - 1: (dw, delta) 243 | } 244 | 245 | # In case of three layer net will iterate over i = 2 and i = 1 246 | # Determine partial derivative and delta for the rest of the layers. 247 | # Each iteration requires the delta from the previous layer, propagating backwards. 248 | for i in reversed(range(2, self.n_layers)): 249 | delta = (delta @ self.w[i].transpose()) * self.activations[i].prime(z[i]) 250 | dw = np.dot(a[i - 1].T, delta) 251 | 252 | update_params[i - 1] = (dw, delta) 253 | for k, v in update_params.items(): 254 | self._update_w_b(k, v[0], v[1]) 255 | 256 | def _update_w_b(self, index, dw, delta): 257 | """ 258 | Update weights and biases. 259 | 260 | :param index: (int) Number of the layer 261 | :param dw: (array) Partial derivatives 262 | :param delta: (array) Delta error. 263 | """ 264 | 265 | # perform the update with momentum 266 | if (index not in self.pdw): 267 | self.pdw[index] = -self.learning_rate * dw 268 | self.pdd[index] = -self.learning_rate * np.mean(delta, 0) 269 | else: 270 | self.pdw[index] = self.momentum * self.pdw[index] - self.learning_rate * dw 271 | self.pdd[index] = self.momentum * self.pdd[index] - self.learning_rate * np.mean(delta, 0) 272 | 273 | self.w[index] += self.pdw[index] - self.weight_decay * self.w[index] 274 | self.b[index] += self.pdd[index] - self.weight_decay * self.b[index] 275 | 276 | def fit(self, x, y_true, x_test, y_test, loss, epochs, batch_size, learning_rate=1e-3, momentum=0.9, 277 | weight_decay=0.0002, dropoutrate=0, testing=True, save_filename=""): 278 | """ 279 | :param x: (array) Containing parameters 280 | :param y_true: (array) Containing one hot encoded labels. 281 | :param loss: Loss class (MSE, CrossEntropy etc.) 282 | :param epochs: (int) Number of epochs. 283 | :param batch_size: (int) 284 | :param learning_rate: (flt) 285 | :param momentum: (flt) 286 | :param weight_decay: (flt) 287 | :param zeta: (flt) #control the fraction of weights removed 288 | :param droprate: (flt) 289 | :return (array) A 2D array of metrics (epochs, 3). 
290 | """ 291 | if not x.shape[0] == y_true.shape[0]: 292 | raise ValueError("Length of x and y arrays don't match") 293 | # Initiate the loss object with the final activation function 294 | self.loss = loss(self.activations[self.n_layers]) 295 | self.learning_rate = learning_rate 296 | self.momentum = momentum 297 | self.weight_decay = weight_decay 298 | self.droprate = dropoutrate 299 | 300 | maximum_accuracy = 0 301 | 302 | metrics = np.zeros((epochs, 4)) 303 | 304 | for i in range(epochs): 305 | # Shuffle the data 306 | seed = np.arange(x.shape[0]) 307 | np.random.shuffle(seed) 308 | x_ = x[seed] 309 | y_ = y_true[seed] 310 | 311 | # training 312 | t1 = datetime.datetime.now() 313 | 314 | for j in range(x.shape[0] // batch_size): 315 | k = j * batch_size 316 | l = (j + 1) * batch_size 317 | z, a = self._feed_forward(x_[k:l], True) 318 | 319 | 320 | self._back_prop(z, a, y_[k:l]) 321 | 322 | t2 = datetime.datetime.now() 323 | 324 | print("\nFC-MLP Epoch ", i) 325 | print("Training time: ", t2 - t1) 326 | 327 | # test model performance on the test data at each epoch 328 | # this part is useful to understand model performance and can be commented for production settings 329 | if (testing): 330 | t3 = datetime.datetime.now() 331 | accuracy_test, activations_test = self.predict(x_test, y_test, batch_size) 332 | accuracy_train, activations_train = self.predict(x, y_true, batch_size) 333 | t4 = datetime.datetime.now() 334 | maximum_accuracy = max(maximum_accuracy, accuracy_test) 335 | loss_test = self.loss.loss(y_test, activations_test) 336 | loss_train = self.loss.loss(y_true, activations_train) 337 | metrics[i, 0] = loss_train 338 | metrics[i, 1] = loss_test 339 | metrics[i, 2] = accuracy_train 340 | metrics[i, 3] = accuracy_test 341 | print("Testing time: ", t4 - t3,"; Loss train: ", loss_train, "; Loss test: ", loss_test, "; Accuracy train: ", accuracy_train,"; Accuracy test: ", accuracy_test, 342 | "; Maximum accuracy test: ", maximum_accuracy) 343 | 344 | # save performance metrics values in a file 345 | if (save_filename != ""): 346 | np.savetxt(save_filename, metrics) 347 | 348 | return metrics 349 | 350 | 351 | def predict(self, x_test, y_test, batch_size=1): 352 | """ 353 | :param x_test: (array) Test input 354 | :param y_test: (array) Correct test output 355 | :param batch_size: 356 | :return: (flt) Classification accuracy 357 | :return: (array) A 2D array of shape (n_cases, n_classes). 
358 | """ 359 | activations = np.zeros((y_test.shape[0], y_test.shape[1])) 360 | for j in range(x_test.shape[0] // batch_size): 361 | k = j * batch_size 362 | l = (j + 1) * batch_size 363 | _, a_test = self._feed_forward(x_test[k:l]) 364 | activations[k:l] = a_test[self.n_layers] 365 | correctClassification = 0 366 | for j in range(y_test.shape[0]): 367 | if (np.argmax(activations[j]) == np.argmax(y_test[j])): 368 | correctClassification += 1 369 | accuracy = correctClassification / y_test.shape[0] 370 | return accuracy, activations 371 | 372 | def load_fashion_mnist_data(noTrainingSamples,noTestingSamples): 373 | np.random.seed(0) 374 | 375 | data=np.load("data/fashion_mnist.npz") 376 | 377 | indexTrain=np.arange(data["X_train"].shape[0]) 378 | np.random.shuffle(indexTrain) 379 | 380 | indexTest=np.arange(data["X_test"].shape[0]) 381 | np.random.shuffle(indexTest) 382 | 383 | X_train=data["X_train"][indexTrain[0:noTrainingSamples],:] 384 | Y_train=data["Y_train"][indexTrain[0:noTrainingSamples],:] 385 | X_test=data["X_test"][indexTest[0:noTestingSamples],:] 386 | Y_test=data["Y_test"][indexTest[0:noTestingSamples],:] 387 | 388 | #normalize in 0..1 389 | X_train = X_train.astype('float64') / 255. 390 | X_test = X_test.astype('float64') / 255. 391 | 392 | return X_train,Y_train,X_test,Y_test 393 | 394 | if __name__ == "__main__": 395 | 396 | for i in range(1): 397 | #load data 398 | noTrainingSamples=2000 #max 60000 for Fashion MNIST 399 | noTestingSamples = 1000 # max 10000 for Fshion MNIST 400 | X_train, Y_train, X_test, Y_test = load_fashion_mnist_data(noTrainingSamples,noTestingSamples) 401 | 402 | #set model parameters 403 | noHiddenNeuronsLayer=1000 404 | noTrainingEpochs=400 405 | batchSize=40 406 | dropoutRate=0.2 407 | learningRate=0.001 408 | momentum=0.9 409 | weightDecay=0.0002 410 | 411 | np.random.seed(i) 412 | 413 | # create FC-MLP ( fully-connected MLP) 414 | fc_mlp = FC_MLP((X_train.shape[1], noHiddenNeuronsLayer, noHiddenNeuronsLayer,noHiddenNeuronsLayer, Y_train.shape[1]), (Relu, Relu,Relu, Sigmoid)) 415 | 416 | # train FC-MLP 417 | fc_mlp.fit(X_train, Y_train, X_test, Y_test, loss=MSE, epochs=noTrainingEpochs, batch_size=batchSize, learning_rate=learningRate, 418 | momentum=momentum, weight_decay=weightDecay, dropoutrate=dropoutRate, testing=True, 419 | save_filename="Results/fc_mlp_"+str(noTrainingSamples)+"_training_samples_rand"+str(i)+".txt") 420 | 421 | # test FC-MLP 422 | accuracy, _ = fc_mlp.predict(X_test, Y_test, batch_size=1) 423 | 424 | print("\nAccuracy of the last epoch on the testing data: ", accuracy) 425 | -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/fixprob_mlp.py: -------------------------------------------------------------------------------- 1 | # Authors: Decebal Constantin Mocanu et al.; 2 | # Code associated with IJCAI 2019 tutorial "Scalable Deep Learning: from theory to practice"; https://sites.google.com/view/scalable-deep-learning-ijcai19 3 | # This is a pre-alpha free software and was tested in Ubuntu 16.04 with Python 3.5.2, Numpy 1.14, SciPy 0.19.1, and (optionally) Cython 0.27.3; 4 | 5 | # If you use parts of this code please cite the following article: 6 | #@article{Mocanu2018SET, 7 | # author = {Mocanu, Decebal Constantin and Mocanu, Elena and Stone, Peter and Nguyen, Phuong H. 
and Gibescu, Madeleine and Liotta, Antonio}, 8 | # journal = {Nature Communications}, 9 | # title = {Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity inspired by Network Science}, 10 | # year = {2018}, 11 | # doi = {10.1038/s41467-018-04316-3} 12 | #} 13 | 14 | #If you have space please consider citing also these articles 15 | 16 | #@phdthesis{Mocanu2017PhDthesis, 17 | #title = "Network computations in artificial intelligence", 18 | #author = "D.C. Mocanu", 19 | #year = "2017", 20 | #isbn = "978-90-386-4305-2", 21 | #publisher = "Eindhoven University of Technology", 22 | #} 23 | 24 | #@article{Liu2019onemillion, 25 | # author = {Liu, Shiwei and Mocanu, Decebal Constantin and Mocanu and Ramapuram Matavalam, Amarsagar Reddy and Pei, Yulong Pei and Pechenizkiy, Mykola}, 26 | # journal = {arXiv:1901.09181}, 27 | # title = {Sparse evolutionary Deep Learning with over one million artificial neurons on commodity hardware}, 28 | # year = {2019}, 29 | #} 30 | 31 | # We thank to: 32 | # Thomas Hagebols: for performing a thorough analyze on the performance of SciPy sparse matrix operations 33 | # Ritchie Vink (https://www.ritchievink.com): for making available on Github a nice Python implementation of fully connected MLPs. This SET-MLP implementation was built on top of his MLP code: 34 | # https://github.com/ritchie46/vanilla-machine-learning/blob/master/vanilla_mlp.py 35 | 36 | 37 | import numpy as np 38 | from scipy.sparse import lil_matrix 39 | from scipy.sparse import coo_matrix 40 | #the "sparseoperations" Cython library was tested in Ubuntu 16.04. Please note that you may encounter some "solvable" issues if you compile it in Windows. 41 | import sparseoperations 42 | import datetime 43 | 44 | 45 | def backpropagation_updates_Numpy(a, delta, rows, cols, out): 46 | for i in range(out.shape[0]): 47 | s = 0 48 | for j in range(a.shape[0]): 49 | s += a[j, rows[i]] * delta[j, cols[i]] 50 | out[i] = s / a.shape[0] 51 | 52 | 53 | def find_first_pos(array, value): 54 | idx = (np.abs(array - value)).argmin() 55 | return idx 56 | 57 | 58 | def find_last_pos(array, value): 59 | idx = (np.abs(array - value))[::-1].argmin() 60 | return array.shape[0] - idx 61 | 62 | 63 | def createSparseWeights(epsilon, noRows, noCols): 64 | # generate an Erdos Renyi sparse weights mask 65 | weights = lil_matrix((noRows, noCols)) 66 | for i in range(epsilon * (noRows + noCols)): 67 | weights[np.random.randint(0, noRows), np.random.randint(0, noCols)] = np.float64(np.random.randn() / 10) 68 | print("Create sparse matrix with ", weights.getnnz(), " connections and ", 69 | (weights.getnnz() / (noRows * noCols)) * 100, "% density level") 70 | weights = weights.tocsr() 71 | return weights 72 | 73 | 74 | def array_intersect(A, B): 75 | # this are for array intersection 76 | nrows, ncols = A.shape 77 | dtype = {'names': ['f{}'.format(i) for i in range(ncols)], 'formats': ncols * [A.dtype]} 78 | return np.in1d(A.view(dtype), B.view(dtype)) # boolean return 79 | 80 | 81 | class Relu: 82 | @staticmethod 83 | def activation(z): 84 | z[z < 0] = 0 85 | return z 86 | 87 | @staticmethod 88 | def prime(z): 89 | z[z < 0] = 0 90 | z[z > 0] = 1 91 | return z 92 | 93 | 94 | class Sigmoid: 95 | @staticmethod 96 | def activation(z): 97 | return 1 / (1 + np.exp(-z)) 98 | 99 | @staticmethod 100 | def prime(z): 101 | return Sigmoid.activation(z) * (1 - Sigmoid.activation(z)) 102 | 103 | 104 | class MSE: 105 | def __init__(self, activation_fn=None): 106 | """ 107 | 108 | :param activation_fn: Class object of the 
activation function. 109 | """ 110 | if activation_fn: 111 | self.activation_fn = activation_fn 112 | else: 113 | self.activation_fn = NoActivation 114 | 115 | def activation(self, z): 116 | return self.activation_fn.activation(z) 117 | 118 | @staticmethod 119 | def loss(y_true, y_pred): 120 | """ 121 | :param y_true: (array) One hot encoded truth vector. 122 | :param y_pred: (array) Prediction vector 123 | :return: (flt) 124 | """ 125 | return np.mean((y_pred - y_true) ** 2) 126 | 127 | @staticmethod 128 | def prime(y_true, y_pred): 129 | return y_pred - y_true 130 | 131 | def delta(self, y_true, y_pred): 132 | """ 133 | Back propagation error delta 134 | :return: (array) 135 | """ 136 | return self.prime(y_true, y_pred) * self.activation_fn.prime(y_pred) 137 | 138 | 139 | class NoActivation: 140 | """ 141 | This is a plugin function for no activation. 142 | 143 | f(x) = x * 1 144 | """ 145 | 146 | @staticmethod 147 | def activation(z): 148 | """ 149 | :param z: (array) w(x) + b 150 | :return: z (array) 151 | """ 152 | return z 153 | 154 | @staticmethod 155 | def prime(z): 156 | """ 157 | The prime of z * 1 = 1 158 | :param z: (array) 159 | :return: z': (array) 160 | """ 161 | return np.ones_like(z) 162 | 163 | 164 | class FixProb_MLP: 165 | def __init__(self, dimensions, activations, epsilon=20): 166 | """ 167 | :param dimensions: (tpl/ list) Dimensions of the neural net. (input, hidden layer, output) 168 | :param activations: (tpl/ list) Activations functions. 169 | 170 | Example of three hidden layer with 171 | - 3312 input features 172 | - 3000 hidden neurons 173 | - 3000 hidden neurons 174 | - 3000 hidden neurons 175 | - 5 output classes 176 | 177 | 178 | layers --> [1, 2, 3, 4, 5] 179 | ---------------------------------------- 180 | 181 | dimensions = (3312, 3000, 3000, 3000, 5) 182 | activations = ( Relu, Relu, Relu, Sigmoid) 183 | """ 184 | self.n_layers = len(dimensions) 185 | self.loss = None 186 | self.learning_rate = None 187 | self.momentum = None 188 | self.weight_decay = None 189 | self.epsilon = epsilon # control the sparsity level as discussed in the paper 190 | self.zeta = None # the fraction of the weights removed 191 | self.droprate = 0 # dropout rate 192 | self.dimensions = dimensions 193 | 194 | # Weights and biases are initiated by index. For a one hidden layer net you will have a w[1] and w[2] 195 | self.w = {} 196 | self.b = {} 197 | self.pdw = {} 198 | self.pdd = {} 199 | 200 | # Activations are also initiated by index. For the example we will have activations[2] and activations[3] 201 | self.activations = {} 202 | for i in range(len(dimensions) - 1): 203 | self.w[i + 1] = createSparseWeights(self.epsilon, dimensions[i], 204 | dimensions[i + 1]) # create sparse weight matrices 205 | self.b[i + 1] = np.zeros(dimensions[i + 1]) 206 | self.activations[i + 2] = activations[i] 207 | 208 | def _feed_forward(self, x, drop=False): 209 | """ 210 | Execute a forward feed through the network. 211 | :param x: (array) Batch of input data vectors. 212 | :return: (tpl) Node outputs and activations per layer. The numbering of the output is equivalent to the layer numbers. 213 | """ 214 | 215 | # w(x) + b 216 | z = {} 217 | 218 | # activations: f(z) 219 | a = {1: x} # First layer has no activations as input. The input x is the input. 
220 | 221 | for i in range(1, self.n_layers): 222 | z[i + 1] = a[i] @ self.w[i] + self.b[i] 223 | if (drop == False): 224 | if (i > 1): 225 | z[i + 1] = z[i + 1] * (1 - self.droprate) 226 | a[i + 1] = self.activations[i + 1].activation(z[i + 1]) 227 | if (drop): 228 | if (i < self.n_layers - 1): 229 | dropMask = np.random.rand(a[i + 1].shape[0], a[i + 1].shape[1]) 230 | dropMask[dropMask >= self.droprate] = 1 231 | dropMask[dropMask < self.droprate] = 0 232 | a[i + 1] = dropMask * a[i + 1] 233 | 234 | return z, a 235 | 236 | def _back_prop(self, z, a, y_true): 237 | """ 238 | The input dicts keys represent the layers of the net. 239 | 240 | a = { 1: x, 241 | 2: f(w1(x) + b1) 242 | 3: f(w2(a2) + b2) 243 | 4: f(w3(a3) + b3) 244 | 5: f(w4(a4) + b4) 245 | } 246 | 247 | :param z: (dict) w(x) + b 248 | :param a: (dict) f(z) 249 | :param y_true: (array) One hot encoded truth vector. 250 | :return: 251 | """ 252 | 253 | # Determine partial derivative and delta for the output layer. 254 | # delta output layer 255 | delta = self.loss.delta(y_true, a[self.n_layers]) 256 | dw = coo_matrix(self.w[self.n_layers - 1]) 257 | 258 | # compute backpropagation updates 259 | sparseoperations.backpropagation_updates_Cython(a[self.n_layers - 1], delta, dw.row, dw.col, dw.data) 260 | # If you have problems with Cython please use the backpropagation_updates_Numpy method by uncommenting the line below and commenting the one above. Please note that the running time will be much higher 261 | # backpropagation_updates_Numpy(a[self.n_layers - 1], delta, dw.row, dw.col, dw.data) 262 | 263 | update_params = { 264 | self.n_layers - 1: (dw.tocsr(), delta) 265 | } 266 | 267 | # In case of three layer net will iterate over i = 2 and i = 1 268 | # Determine partial derivative and delta for the rest of the layers. 269 | # Each iteration requires the delta from the previous layer, propagating backwards. 270 | for i in reversed(range(2, self.n_layers)): 271 | delta = (delta @ self.w[i].transpose()) * self.activations[i].prime(z[i]) 272 | dw = coo_matrix(self.w[i - 1]) 273 | 274 | # compute backpropagation updates 275 | sparseoperations.backpropagation_updates_Cython(a[i - 1], delta, dw.row, dw.col, dw.data) 276 | # If you have problems with Cython please use the backpropagation_updates_Numpy method by uncommenting the line below and commenting the one above. Please note that the running time will be much higher 277 | # backpropagation_updates_Numpy(a[i - 1], delta, dw.row, dw.col, dw.data) 278 | 279 | update_params[i - 1] = (dw.tocsr(), delta) 280 | for k, v in update_params.items(): 281 | self._update_w_b(k, v[0], v[1]) 282 | 283 | def _update_w_b(self, index, dw, delta): 284 | """ 285 | Update weights and biases. 286 | 287 | :param index: (int) Number of the layer 288 | :param dw: (array) Partial derivatives 289 | :param delta: (array) Delta error. 
290 | """ 291 | 292 | # perform the update with momentum 293 | if (index not in self.pdw): 294 | self.pdw[index] = -self.learning_rate * dw 295 | self.pdd[index] = - self.learning_rate * np.mean(delta, 0) 296 | else: 297 | self.pdw[index] = self.momentum * self.pdw[index] - self.learning_rate * dw 298 | self.pdd[index] = self.momentum * self.pdd[index] - self.learning_rate * np.mean(delta, 0) 299 | 300 | self.w[index] += self.pdw[index] - self.weight_decay * self.w[index] 301 | self.b[index] += self.pdd[index] - self.weight_decay * self.b[index] 302 | 303 | def fit(self, x, y_true, x_test, y_test, loss, epochs, batch_size, learning_rate=1e-3, momentum=0.9, 304 | weight_decay=0.0002, dropoutrate=0, testing=True, save_filename=""): 305 | """ 306 | :param x: (array) Containing parameters 307 | :param y_true: (array) Containing one hot encoded labels. 308 | :param loss: Loss class (MSE, CrossEntropy etc.) 309 | :param epochs: (int) Number of epochs. 310 | :param batch_size: (int) 311 | :param learning_rate: (flt) 312 | :param momentum: (flt) 313 | :param weight_decay: (flt) 314 | :param zeta: (flt) #control the fraction of weights removed 315 | :param droprate: (flt) 316 | :return (array) A 2D array of metrics (epochs, 3). 317 | """ 318 | if not x.shape[0] == y_true.shape[0]: 319 | raise ValueError("Length of x and y arrays don't match") 320 | # Initiate the loss object with the final activation function 321 | self.loss = loss(self.activations[self.n_layers]) 322 | self.learning_rate = learning_rate 323 | self.momentum = momentum 324 | self.weight_decay = weight_decay 325 | self.droprate = dropoutrate 326 | 327 | maximum_accuracy = 0 328 | 329 | metrics = np.zeros((epochs, 4)) 330 | 331 | for i in range(epochs): 332 | # Shuffle the data 333 | seed = np.arange(x.shape[0]) 334 | np.random.shuffle(seed) 335 | x_ = x[seed] 336 | y_ = y_true[seed] 337 | 338 | # training 339 | t1 = datetime.datetime.now() 340 | 341 | for j in range(x.shape[0] // batch_size): 342 | k = j * batch_size 343 | l = (j + 1) * batch_size 344 | z, a = self._feed_forward(x_[k:l], True) 345 | 346 | 347 | self._back_prop(z, a, y_[k:l]) 348 | 349 | t2 = datetime.datetime.now() 350 | 351 | print("\nFixProb-MLP Epoch ", i) 352 | print("Training time: ", t2 - t1) 353 | 354 | # test model performance on the test data at each epoch 355 | # this part is useful to understand model performance and can be commented for production settings 356 | if (testing): 357 | t3 = datetime.datetime.now() 358 | accuracy_test, activations_test = self.predict(x_test, y_test, batch_size) 359 | accuracy_train, activations_train = self.predict(x, y_true, batch_size) 360 | t4 = datetime.datetime.now() 361 | maximum_accuracy = max(maximum_accuracy, accuracy_test) 362 | loss_test = self.loss.loss(y_test, activations_test) 363 | loss_train = self.loss.loss(y_true, activations_train) 364 | metrics[i, 0] = loss_train 365 | metrics[i, 1] = loss_test 366 | metrics[i, 2] = accuracy_train 367 | metrics[i, 3] = accuracy_test 368 | print("Testing time: ", t4 - t3,"; Loss train: ", loss_train, "; Loss test: ", loss_test, "; Accuracy train: ", accuracy_train,"; Accuracy test: ", accuracy_test, 369 | "; Maximum accuracy test: ", maximum_accuracy) 370 | 371 | # save performance metrics values in a file 372 | if (save_filename != ""): 373 | np.savetxt(save_filename, metrics) 374 | 375 | return metrics 376 | 377 | 378 | def predict(self, x_test, y_test, batch_size=1): 379 | """ 380 | :param x_test: (array) Test input 381 | :param y_test: (array) Correct test output 382 | 
:param batch_size: 383 | :return: (flt) Classification accuracy 384 | :return: (array) A 2D array of shape (n_cases, n_classes). 385 | """ 386 | activations = np.zeros((y_test.shape[0], y_test.shape[1])) 387 | for j in range(x_test.shape[0] // batch_size): 388 | k = j * batch_size 389 | l = (j + 1) * batch_size 390 | _, a_test = self._feed_forward(x_test[k:l]) 391 | activations[k:l] = a_test[self.n_layers] 392 | correctClassification = 0 393 | for j in range(y_test.shape[0]): 394 | if (np.argmax(activations[j]) == np.argmax(y_test[j])): 395 | correctClassification += 1 396 | accuracy = correctClassification / y_test.shape[0] 397 | return accuracy, activations 398 | 399 | def load_fashion_mnist_data(noTrainingSamples,noTestingSamples): 400 | np.random.seed(0) 401 | 402 | data=np.load("data/fashion_mnist.npz") 403 | 404 | indexTrain=np.arange(data["X_train"].shape[0]) 405 | np.random.shuffle(indexTrain) 406 | 407 | indexTest=np.arange(data["X_test"].shape[0]) 408 | np.random.shuffle(indexTest) 409 | 410 | X_train=data["X_train"][indexTrain[0:noTrainingSamples],:] 411 | Y_train=data["Y_train"][indexTrain[0:noTrainingSamples],:] 412 | X_test=data["X_test"][indexTest[0:noTestingSamples],:] 413 | Y_test=data["Y_test"][indexTest[0:noTestingSamples],:] 414 | 415 | #normalize in 0..1 416 | X_train = X_train.astype('float64') / 255. 417 | X_test = X_test.astype('float64') / 255. 418 | 419 | return X_train,Y_train,X_test,Y_test 420 | 421 | if __name__ == "__main__": 422 | 423 | for i in range(1): 424 | #load data 425 | noTrainingSamples=2000 #max 60000 for Fashion MNIST 426 | noTestingSamples = 1000 # max 10000 for Fashion MNIST 427 | X_train, Y_train, X_test, Y_test = load_fashion_mnist_data(noTrainingSamples,noTestingSamples) 428 | 429 | #set model parameters 430 | noHiddenNeuronsLayer=1000 431 | epsilon=13 #set the sparsity level 432 | noTrainingEpochs=400 433 | batchSize=40 434 | dropoutRate=0.2 435 | learningRate=0.05 436 | momentum=0.9 437 | weightDecay=0.0002 438 | 439 | np.random.seed(i) 440 | 441 | # create FixProb-MLP (MLP with static sparse connectivity) 442 | fixprob_mlp = FixProb_MLP((X_train.shape[1], noHiddenNeuronsLayer, noHiddenNeuronsLayer,noHiddenNeuronsLayer, Y_train.shape[1]), (Relu, Relu,Relu, Sigmoid), epsilon=epsilon) 443 | 444 | # train FixProb-MLP 445 | fixprob_mlp.fit(X_train, Y_train, X_test, Y_test, loss=MSE, epochs=noTrainingEpochs, batch_size=batchSize, learning_rate=learningRate, 446 | momentum=momentum, weight_decay=weightDecay, dropoutrate=dropoutRate, testing=True, 447 | save_filename="Results/fixprob_mlp_"+str(noTrainingSamples)+"_training_samples_e"+str(epsilon)+"_rand"+str(i)+".txt") 448 | 449 | # test FixProb-MLP 450 | accuracy, _ = fixprob_mlp.predict(X_test, Y_test, batch_size=1) 451 | 452 | print("\nAccuracy of the last epoch on the testing data: ", accuracy) 453 | -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/plot_input_layer_connectivity.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.animation 3 | import matplotlib.colors as mcolors 4 | import numpy as np 5 | from mpl_toolkits.axes_grid1 import make_axes_locatable 6 | 7 | # update the data for each frame 8 | def anim(n): 9 | global data 10 | global allConnections 11 | data = allConnections[:,:,n] 12 | imobj.set_array(data) 13 | return imobj, 14 | 15 | 16 | 17 | for i in range(1): 18 | data = np.load("data/fashion_mnist.npz") 19 | 
connections=np.load("Pretrained_results/set_mlp_2000_training_samples_e13_rand"+str(i)+"_input_connections.npz")["inputLayerConnections"] 20 | 21 | allConnections=np.zeros((28,28,len(connections))) 22 | for j in range(len(connections)): 23 | connectionsEpoch=np.reshape(connections[j],(28,28)) 24 | allConnections[:,:,j]=connectionsEpoch 25 | 26 | fig = plt.figure() 27 | fig.suptitle('IJCAI 2019 tutorials\nScalable Deep Learning: from theory to practice', fontsize=14) 28 | 29 | ax1 = fig.add_subplot(121) 30 | ax1.imshow(np.reshape(data["X_train"][1,:],(28,28)),vmin=0,vmax=255,cmap="gray_r",interpolation=None) 31 | ax1.set_title("Fashion-MNIST example") 32 | 33 | ax2 = fig.add_subplot(122) 34 | data=allConnections[:,:,0] 35 | imobj = ax2.imshow(data,vmin=0,vmax=np.max(allConnections),cmap="jet",interpolation=None) 36 | ax2.set_title("Input connectivity pattern evolution\nwith SET-MLP") 37 | 38 | divider = make_axes_locatable(ax2) 39 | cax = divider.append_axes("right", size="5%", pad=0.05) 40 | 41 | cbar=fig.colorbar(imobj,cax=cax) 42 | cbar.set_label('Connections per input neuron (pixel)',size=8) 43 | 44 | fig.tight_layout() 45 | 46 | # create the animation 47 | ani = matplotlib.animation.FuncAnimation(fig, anim, frames=len(connections)) 48 | ani.save("Pretrained_results/fashion_mnist_connections_evolution_per_input_pixel_rand"+str(i)+".gif", writer='imagemagick',fps=24,codec=None) -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/plot_learning_curve.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | epsilon=13 7 | samples=2000 8 | 9 | set_mlp=np.loadtxt("Pretrained_results/set_mlp_"+str(samples)+"_training_samples_e"+str(epsilon)+"_rand0.txt") 10 | fixprob_mlp=np.loadtxt("Pretrained_results/fixprob_mlp_"+str(samples)+"_training_samples_e"+str(epsilon)+"_rand0.txt") 11 | fc_mlp=np.loadtxt("Pretrained_results/fc_mlp_"+str(samples)+"_training_samples_rand0.txt") 12 | 13 | """ 14 | for i in range(1,5): 15 | set_mlp = set_mlp + np.loadtxt( 16 | "Results/set_mlp_" + str(samples) + "_training_samples_e" + str(epsilon) + "_rand" + str(i) + ".txt") 17 | fixprob_mlp = fixprob_mlp + np.loadtxt( 18 | "Results/fixprob_mlp_" + str(samples) + "_training_samples_e" + str(epsilon) + "_rand" + str(i) + "") 19 | fc_mlp = fc_mlp + np.loadtxt("Pretrained_results/fc_mlp_" + str(samples) + "_training_samples_rand" + str(i) + ".txt") 20 | 21 | set_mlp/=5 22 | fixprob_mlp/=5 23 | fc_mlp/=5 24 | """ 25 | font = { 'size' : 9} 26 | fig = plt.figure(figsize=(10,5)) 27 | matplotlib.rc('font', **font) 28 | fig.subplots_adjust(wspace=0.2,hspace=0.05) 29 | 30 | ax1=fig.add_subplot(1,2,1) 31 | ax1.plot(set_mlp[:,2]*100, label="SET-MLP train accuracy", color="r") 32 | ax1.plot(set_mlp[:,3]*100, label="SET-MLP test accuracy", color="b") 33 | ax1.plot(fixprob_mlp[:,2]*100, label="MLP$_{FixProb}$ train accuracy", color="g") 34 | ax1.plot(fixprob_mlp[:,3]*100, label="MLP$_{FixProb}$ test accuracy", color="m") 35 | ax1.plot(fc_mlp[:,2]*100, label="FC-MLP train accuracy", color="y") 36 | ax1.plot(fc_mlp[:,3]*100, label="FC-MLP test accuracy", color="k") 37 | ax1.grid(True) 38 | ax1.set_ylabel("Fashion MNIST\nAccuracy [%]") 39 | ax1.set_xlabel("Epochs [#]") 40 | ax1.legend(loc=4,fontsize=8) 41 | 42 | ax2=fig.add_subplot(1,2,2) 43 | ax2.plot(set_mlp[:,0], label="SET-MLP train loss", color="r") 44 | 
ax2.plot(set_mlp[:,1], label="SET-MLP test loss", color="b") 45 | ax2.plot(fixprob_mlp[:,0], label="MLP$_{FixProb}$ train loss", color="g") 46 | ax2.plot(fixprob_mlp[:,1], label="MLP$_{FixProb}$ test loss", color="m") 47 | ax2.plot(fc_mlp[:,0], label="FC-MLP train loss", color="y") 48 | ax2.plot(fc_mlp[:,1], label="FC-MLP test loss", color="k") 49 | ax2.grid(True) 50 | ax2.set_ylabel("Loss (MSE)") 51 | ax2.set_xlabel("Epochs [#]") 52 | ax2.legend(loc=1,fontsize=8) 53 | 54 | 55 | plt.savefig("Pretrained_results/mnist_learning_curves_samples"+str(samples)+".pdf", bbox_inches='tight') 56 | 57 | plt.close() -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/sparseoperations.cpython-35m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dcmocanu/sparse-evolutionary-artificial-neural-networks/62ac9748258a06c2bf68c40cdd2f07e9119640dd/Tutorial-IJCAI-2019-Scalable-Deep-Learning/sparseoperations.cpython-35m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /Tutorial-IJCAI-2019-Scalable-Deep-Learning/sparseoperations.pyx: -------------------------------------------------------------------------------- 1 | # compile this file with: "cythonize -a -i sparseoperations.pyx" 2 | # This method was tested in Linux (Ubuntu). If you compile it in Windows you may need a workaround. 3 | 4 | cimport numpy as np 5 | 6 | def backpropagation_updates_Cython(np.ndarray[np.float64_t,ndim=2] a, np.ndarray[np.float64_t,ndim=2] delta, np.ndarray[int,ndim=1] rows, np.ndarray[int,ndim=1] cols,np.ndarray[np.float64_t,ndim=1] out): 7 | cdef: 8 | size_t i,j 9 | double s 10 | for i in range (out.shape[0]): 11 | s=0 12 | for j in range(a.shape[0]): 13 | s+=a[j,rows[i]]*delta[j, cols[i]] 14 | out[i]=s/a.shape[0] 15 | #return out 16 | 17 | --------------------------------------------------------------------------------
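Note: as a quick sanity check after building the extension above (e.g. with "cythonize -a -i sparseoperations.pyx"), the compiled kernel can be compared against the pure-NumPy fallback backpropagation_updates_Numpy that fc_mlp.py and fixprob_mlp.py already ship. The snippet below is only a minimal sketch and is not part of the repository; the toy sizes, the int32 cast for the COO indices, and the script name are assumptions, not something the tutorial code prescribes.

# check_sparseoperations.py (hypothetical helper, not in the repository)
# Compares sparseoperations.backpropagation_updates_Cython with a pure-NumPy
# re-implementation of the same loop on random data; run it from the tutorial
# folder after compiling the .pyx file.
import numpy as np
import sparseoperations

def backpropagation_updates_numpy(a, delta, rows, cols, out):
    # same logic as backpropagation_updates_Numpy in fc_mlp.py / fixprob_mlp.py:
    # for every stored connection (rows[i], cols[i]), average a[:, row] * delta[:, col]
    # over the batch dimension
    for i in range(out.shape[0]):
        s = 0.0
        for j in range(a.shape[0]):
            s += a[j, rows[i]] * delta[j, cols[i]]
        out[i] = s / a.shape[0]

batch, n_rows, n_cols, nnz = 8, 30, 20, 50                 # arbitrary toy sizes
a = np.random.randn(batch, n_rows)                         # activations of the previous layer
delta = np.random.randn(batch, n_cols)                     # backpropagated error of the current layer
rows = np.random.randint(0, n_rows, nnz).astype(np.int32)  # COO row indices (C int buffer in the .pyx)
cols = np.random.randint(0, n_cols, nnz).astype(np.int32)  # COO column indices
out_cython = np.zeros(nnz)
out_numpy = np.zeros(nnz)

sparseoperations.backpropagation_updates_Cython(a, delta, rows, cols, out_cython)
backpropagation_updates_numpy(a, delta, rows, cols, out_numpy)
print("max abs difference:", np.max(np.abs(out_cython - out_numpy)))  # should be ~0 if the build is fine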