├── .gitignore ├── LICENSE ├── README.md ├── autograd ├── README.md ├── demo_kaf_regression.py └── kafnets.py ├── eq1.png ├── eq2.png ├── keras ├── README.md ├── demo_kaf_convolutional.py ├── demo_kaf_feedforward.py └── kafnets.py ├── kernel_activation_functions.png ├── kernel_activation_functions_2D.png ├── pytorch ├── README.md ├── demo_kaf_convolutional.py ├── demo_kaf_feedforward.py └── kafnets.py └── tensorflow ├── README ├── demo_kaf_convolutional.py ├── demo_kaf_feedforward.py └── kafnets.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | 3 | \.idea/ 4 | 5 | pytorch/data/MNIST/raw/train-labels-idx1-ubyte 6 | 7 | pytorch/data/MNIST/processed/test\.pt 8 | 9 | pytorch/data/MNIST/processed/training\.pt 10 | 11 | pytorch/data/MNIST/raw/t10k-images-idx3-ubyte 12 | 13 | pytorch/data/MNIST/raw/t10k-labels-idx1-ubyte 14 | 15 | pytorch/data/MNIST/raw/train-images-idx3-ubyte 16 | 17 | tensorflow/MNIST_data/ 18 | 19 | tensorflow/logs/ 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ispamm Laboratory 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kernel Activation Functions 2 | 3 | This repository contains several implementations of the kernel activation functions (KAFs) described in the following paper ([link to the preprint](https://arxiv.org/abs/1707.04035)): 4 | 5 | Scardapane, S., Van Vaerenbergh, S., Totaro, S. and Uncini, A., 2019. 6 | Kafnets: Kernel-based non-parametric activation functions for neural networks. 7 | Neural Networks, 110, pp.19-32. 8 | 9 | ## Available implementations 10 | 11 | We currently provide the following stable implementations: 12 | 13 | * [PyTorch](/pytorch): feedforward and convolutional layers, three kernels (Gaussian/ReLU/Softplus), with random initialization or kernel ridge regression. 14 | * [Keras](/keras): same functionalities as the PyTorch implementation. 15 | * [TensorFlow](/tensorflow/): similar to the Keras implementation, but we use the internal tf.keras.Layer and the eager execution in the demos. 
16 | * [Autograd](/autograd): only feedforward layers with a Gaussian kernel and random initialization. 17 | 18 | More information for each implementation is given in the corresponding folder. The code should be relatively easy to plug into other architectures or projects. 19 | 20 | ## What is a KAF? 21 | 22 | Most neural networks work by interleaving linear projections and simple (fixed) activation functions, like the ReLU function: 23 | 24 | 

25 | $$\mathrm{ReLU}(s) = \max(0, s)$$
26 | 
27 | 
28 | A KAF is instead a non-parametric activation function defined as a one-dimensional kernel approximator:
29 | 
30 | 
31 | $$\mathrm{KAF}(s) = \sum_{i=1}^{D} \alpha_i \, \kappa(s, d_i)$$
32 | 
33 | 34 | where: 35 | 36 | * The dictionary of the kernel elements is fixed by sampling the x-axis with a uniform step around 0. 37 | * The user can select the kernel function (e.g., Gaussian, ReLU, Softplus) and the number of kernel elements D. 38 | * The linear coefficients are adapted independently at every neuron via standard back-propagation. 39 | 40 | In addition, the linear coefficients can be initialized using kernel ridge regression to behave similarly to a known function in the beginning of the optimization process. 41 | 42 |
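For concreteness, a single KAF layer can be sketched in a few lines of NumPy (shapes, names, and constants below are illustrative rather than part of the library; the framework implementations in this repository are the reference versions):

    import numpy as np

    N, H, D = 32, 10, 20                      # batch size, neurons, dictionary size
    d = np.linspace(-3.0, 3.0, D)             # fixed dictionary, sampled uniformly around 0
    gamma = 0.5 / (2 * (d[1] - d[0])) ** 2    # rule-of-thumb bandwidth for the Gaussian kernel
    alpha = 0.5 * np.random.randn(H, D)       # mixing coefficients, adapted by back-propagation

    def kaf(s):
        # Gaussian kernel between each activation in s (shape (N, H)) and
        # each dictionary element, giving an (N, H, D) array of kernel values
        K = np.exp(-gamma * (s[:, :, None] - d[None, None, :]) ** 2)
        # Mix the kernel values back into an (N, H) array of outputs
        return np.sum(K * alpha[None, :, :], axis=-1)

    y = kaf(np.random.randn(N, H))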

43 | ![Kernel activation functions](kernel_activation_functions.png)
44 | 
45 | Fig. 1. Examples of kernel activation functions learned on the Sensorless data set. The KAF after initialization is shown as a dashed red line, while the final KAF is shown as a solid green line. As a reference, the distribution of the activation values after training is shown in light blue.
46 | 
47 | ![Two-dimensional kernel activation functions](kernel_activation_functions_2D.png)
48 | 
49 | Fig. 2. Examples of two-dimensional kernel activation functions learned on the Sensorless data set.
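The kernel ridge regression initialization mentioned above amounts to solving a small regularized linear system on the dictionary points. A minimal NumPy sketch, mirroring the PyTorch implementation (the target function np.tanh and the constants are illustrative):

    import numpy as np

    D = 20
    d = np.linspace(-3.0, 3.0, D).reshape(-1, 1)      # dictionary as a column vector
    gamma = 0.5 / (2 * (d[1, 0] - d[0, 0])) ** 2
    K = np.exp(-gamma * (d - d.T) ** 2)               # Gram matrix of the dictionary
    # Regularized solve: the KAF now approximates tanh at the start of training
    alpha_init = np.linalg.solve(K + 1e-4 * np.eye(D), np.tanh(d)).ravel()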
51 | 52 | ## Contributing 53 | 54 | If you have an implementation for a different framework, or an enhanced version of the current code, feel free to contribute to the repository. For any issues related to the code you can use the issue tracker from GitHub. 55 | 56 | ## Citation 57 | 58 | If you use this code or a derivative thereof in your research, we would appreciate a citation to the original paper: 59 | 60 | @article{scardapane2019kafnets, 61 | title={Kafnets: Kernel-based non-parametric activation functions for neural networks}, 62 | author={Scardapane, Simone and Van Vaerenbergh, Steven and Totaro, Simone and Uncini, Aurelio}, 63 | journal={Neural Networks}, 64 | volume={110}, 65 | pages={19--32}, 66 | year={2019}, 67 | publisher={Elsevier} 68 | } 69 | 70 | ## License 71 | 72 | The code is released under the MIT License. See the attached LICENSE file. 73 | -------------------------------------------------------------------------------- /autograd/README.md: -------------------------------------------------------------------------------- 1 | ## Kernel activation functions (Autograd) 2 | 3 | The code customizes the example from here: 4 | https://github.com/HIPS/autograd/blob/master/examples/neural_net.py 5 | 6 | In the *kafnets* module you can find the code to initialize and run neural networks having KAF activation functions. 7 | Note that this is a regression example and we use custom functions also in the output layer. 8 | For classification, consider changing the final layer to a standard softmax. 9 | 10 | ## Requirements 11 | 12 | * autograd = 1.2. 13 | * scikit-learn = 0.20.1 (for demo_kaf_regression.py) -------------------------------------------------------------------------------- /autograd/demo_kaf_regression.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Imports from Python libraries 4 | import autograd.numpy as np 5 | from autograd import grad 6 | from autograd.misc.optimizers import adam 7 | from sklearn import datasets, preprocessing, model_selection 8 | 9 | # Custom imports 10 | from kafnets import init_kaf_nn, predict_kaf_nn 11 | 12 | # Extends this example: 13 | # https://github.com/HIPS/autograd/blob/master/examples/neural_net.py 14 | 15 | # Set seed for PRNG 16 | np.random.seed(1) 17 | 18 | # Size of the neural network's layers 19 | layers = [13, 10, 1] 20 | 21 | # Batch size 22 | B = 40 23 | 24 | # Load Boston dataset 25 | data = datasets.load_boston() 26 | X = preprocessing.MinMaxScaler(feature_range=(-1, +1)).fit_transform(data['data']) 27 | y = preprocessing.MinMaxScaler(feature_range=(-0.9, +0.9)).fit_transform(data['target'].reshape(-1, 1)) 28 | (X_train, X_test, y_train, y_test) = model_selection.train_test_split(X, y, test_size=0.25) 29 | 30 | # Initialize KAF neural network 31 | w, info = init_kaf_nn(layers) 32 | predict_fcn = lambda w, inputs: predict_kaf_nn(w, inputs, info) 33 | 34 | # Loss function (MSE) 35 | def loss_fcn(params, inputs, targets): 36 | return np.mean(np.square(predict_fcn(params, inputs) - targets)) 37 | 38 | # Iterator over mini-batches 39 | num_batches = int(np.ceil(X_train.shape[0] / B)) 40 | def batch_indices(iter): 41 | idx = iter % num_batches 42 | return slice(idx * B, (idx+1) * B) 43 | 44 | # Define training objective 45 | def objective(params, iter): 46 | idx = batch_indices(iter) 47 | return loss_fcn(params, X_train[idx], y_train[idx]) 48 | 49 | # Get gradient of objective using autograd. 
50 | objective_grad = grad(objective) 51 | 52 | # The optimizers provided can optimize lists, tuples, or dicts of parameters 53 | print('Optimizing the network...\n') 54 | w_final = adam(objective_grad, w, num_iters=1000) 55 | 56 | # Compute test accuracy 57 | print('Final test MSE is ', loss_fcn(w_final, X_test, y_test), '\n') -------------------------------------------------------------------------------- /autograd/kafnets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import autograd.numpy as np 4 | 5 | def init_kaf_nn(layer_sizes, scale=0.01, rs=np.random.RandomState(0), dict_size=20, boundary=3.0): 6 | """ 7 | Initialize the parameters of a KAF feedforward network. 8 | - dict_size: the size of the dictionary for every neuron. 9 | - boundary: the boundary for the activation functions. 10 | """ 11 | 12 | # Initialize the dictionary 13 | D = np.linspace(-boundary, boundary, dict_size).reshape(-1, 1) 14 | 15 | # Rule of thumb for gamma 16 | interval = D[1,0] - D[0,0]; 17 | gamma = 0.5/np.square(2*interval) 18 | D = D.reshape(1, 1, -1) 19 | 20 | # Initialize a list of parameters for the layer 21 | w = [(rs.randn(insize, outsize) * scale, # Weight matrix 22 | rs.randn(outsize) * scale, # Bias vector 23 | rs.randn(1, outsize, dict_size) * 0.5) # Mixing coefficients 24 | for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:])] 25 | 26 | return w, (D, gamma) 27 | 28 | def predict_kaf_nn(w, X, info): 29 | """ 30 | Compute the outputs of a KAF feedforward network. 31 | """ 32 | 33 | D, gamma = info 34 | for W, b, alpha in w: 35 | outputs = np.dot(X, W) + b 36 | K = gauss_kernel(outputs, D, gamma) 37 | X = np.sum(K*alpha, axis=2) 38 | return X 39 | 40 | def gauss_kernel(X, D, gamma=1.0): 41 | """ 42 | Compute the 1D Gaussian kernel between all elements of a 43 | NxH matrix and a fixed L-dimensional dictionary, resulting in a NxHxL matrix of kernel 44 | values. 45 | """ 46 | return np.exp(- gamma*np.square(X.reshape(-1, X.shape[1], 1) - D)) -------------------------------------------------------------------------------- /eq1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/kernel-activation-functions/e525f7e82b54508ac262f82c9557ad66be2732f3/eq1.png -------------------------------------------------------------------------------- /eq2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/kernel-activation-functions/e525f7e82b54508ac262f82c9557ad66be2732f3/eq2.png -------------------------------------------------------------------------------- /keras/README.md: -------------------------------------------------------------------------------- 1 | ## Kernel activation functions (Keras) 2 | 3 | In the *kafnets* module you can find the modules for defining KAF layers, both for feedforward networks and convolutional networks (using the flag 'conv' during initialization). 4 | The code has two demos to showcase the modules using the Keras Sequential model. 5 | 6 | ## Requirements 7 | 8 | * keras = 2.2.4 9 | * numpy = 1.15.4 10 | -------------------------------------------------------------------------------- /keras/demo_kaf_convolutional.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Simple demo using kernel activation functions with convolutional networks on the MNIST dataset. 
5 | """ 6 | 7 | # Keras imports 8 | from keras import datasets 9 | from keras.models import Sequential 10 | from keras.layers import Dense, Conv2D, MaxPooling2D, Flatten 11 | from keras.utils import to_categorical 12 | import keras.backend as K 13 | 14 | # Custom imports 15 | from kafnets import KAF 16 | 17 | # Load Breast Cancer dataset 18 | (X_train, y_train), (X_test, y_test) = datasets.mnist.load_data() 19 | 20 | # Preprocessing is taken from here: 21 | # https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py 22 | if K.image_data_format() == 'channels_first': 23 | X_train = X_train.reshape(X_train.shape[0], 1, 28, 28) 24 | X_test = X_test.reshape(X_test.shape[0], 1, 28, 28) 25 | else: 26 | X_train = X_train.reshape(X_train.shape[0], 28, 28, 1) 27 | X_test = X_test.reshape(X_test.shape[0], 28, 28, 1) 28 | 29 | X_train = X_train.astype('float32') 30 | X_test = X_test.astype('float32') 31 | X_train /= 255 32 | X_test /= 255 33 | 34 | # convert class vectors to binary class matrices 35 | y_train = to_categorical(y_train, 10) 36 | y_test = to_categorical(y_test, 10) 37 | 38 | # Initialize a KAF neural network 39 | kafnet = Sequential() 40 | kafnet.add(Conv2D(32, (3, 3), input_shape=(28, 28, 1))) 41 | kafnet.add(KAF(32, conv=True)) 42 | kafnet.add(Conv2D(32, (3, 3))) 43 | kafnet.add(KAF(32, conv=True)) 44 | kafnet.add(MaxPooling2D(pool_size=(2, 2))) 45 | kafnet.add(Flatten()) 46 | kafnet.add(Dense(100)) 47 | kafnet.add(KAF(100)) 48 | kafnet.add(Dense(10, activation='softmax')) 49 | 50 | # Training 51 | kafnet.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) 52 | kafnet.summary() 53 | kafnet.fit(X_train, y_train, epochs=5, batch_size=32, verbose=1) 54 | 55 | # Evaluation 56 | print('Final accuracy is: ' + str(kafnet.evaluate(X_test, y_test, batch_size=64)[1])) -------------------------------------------------------------------------------- /keras/demo_kaf_feedforward.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Simple demo using kernel activation functions on a basic regression dataset. 5 | """ 6 | 7 | # Keras imports 8 | from keras import datasets 9 | from keras.models import Sequential 10 | from keras.layers import Dense 11 | 12 | # Custom imports 13 | from kafnets import KAF 14 | 15 | # Load Breast Cancer dataset 16 | (X_train, y_train), (X_test, y_test) = datasets.boston_housing.load_data() 17 | 18 | # Initialize a KAF neural network 19 | kafnet = Sequential([ 20 | Dense(20, input_shape=(13,)), 21 | KAF(20), 22 | Dense(1), 23 | ]) 24 | 25 | #Uncomment to use KAF with Softplus kernel 26 | #kafnet = Sequential([ 27 | # Dense(20, input_shape=(13,)), 28 | # KAF(20, kernel='softplus', D=5), 29 | # Dense(1), 30 | #]) 31 | 32 | # Training 33 | kafnet.compile(optimizer='adam', loss='mse') 34 | kafnet.summary() 35 | kafnet.fit(X_train, y_train, epochs=250, batch_size=32, verbose=0) 36 | 37 | # Evaluation 38 | print('Final error is: ' + str(kafnet.evaluate(X_test, y_test, batch_size=64))) -------------------------------------------------------------------------------- /keras/kafnets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.layers import Layer 3 | from keras import backend as K 4 | 5 | class KAF(Layer): 6 | """ Implementation of the kernel activation function. 7 | 8 | Parameters 9 | ---------- 10 | num_parameters: int 11 | Size of the layer (number of neurons). 
12 | D: int, optional 13 | Size of the dictionary for each neuron. Default to 20. 14 | conv: bool, optional 15 | True if this is a convolutive layer, False for a feedforward layer. Default to False. 16 | boundary: float, optional 17 | Dictionary elements are sampled uniformly in [-boundary, boundary]. Default to 4.0. 18 | init_fcn: None or func, optional 19 | If None, elements are initialized randomly. Otherwise, elements are initialized to approximate given function. 20 | kernel: {'gauss', 'relu', 'softplus'}, optional 21 | Kernel function to be used. Defaults to 'gaussian'. 22 | 23 | Example 24 | ---------- 25 | Neural network with one hidden layer with KAF nonlinearities: 26 | 27 | >>> net = Sequential([Dense(10), KAF(10), Dense(10, 1)]) 28 | 29 | References 30 | ---------- 31 | [1] Scardapane, S., Van Vaerenbergh, S., Totaro, S. and Uncini, A., 2019. 32 | Kafnets: kernel-based non-parametric activation functions for neural networks. 33 | Neural Networks, 110, pp. 19-32. 34 | [2] Marra, G., Zanca, D., Betti, A. and Gori, M., 2018. 35 | Learning Neuron Non-Linearities with Kernel-Based Deep Neural Networks. 36 | arXiv preprint arXiv:1807.06302. 37 | """ 38 | 39 | def __init__(self, num_parameters, D=20, boundary=3.0, conv=False, init_fcn=None, kernel='gaussian', **kwargs): 40 | self.num_parameters = num_parameters 41 | self.D = D 42 | self.boundary = boundary 43 | self.init_fcn = init_fcn 44 | self.conv = conv 45 | if self.conv: 46 | self.unsqueeze_dim = 4 47 | else: 48 | self.unsqueeze_dim = 2 49 | self.kernel = kernel 50 | if not (kernel in ['gaussian', 'relu', 'softplus']): 51 | raise ValueError('Kernel not recognized (must be {gaussian, relu, softplus})') 52 | super().__init__(**kwargs) 53 | 54 | def build(self, input_shape): 55 | 56 | # Initialize the fixed dictionary 57 | d = np.linspace(-self.boundary, self.boundary, self.D).astype(np.float32).reshape(-1, 1) 58 | 59 | if self.conv: 60 | self.dict = self.add_weight(name='dict', 61 | shape=(1, 1, 1, 1, self.D), 62 | initializer='uniform', 63 | trainable=False) 64 | K.set_value(self.dict, d.reshape(1, 1, 1, 1, -1)) 65 | else: 66 | self.dict = self.add_weight(name='dict', 67 | shape=(1, 1, self.D), 68 | initializer='uniform', 69 | trainable=False) 70 | K.set_value(self.dict, d.reshape(1, 1, -1)) 71 | 72 | if self.kernel == 'gaussian': 73 | self.kernel_fcn = self.gaussian_kernel 74 | # Rule of thumb for gamma 75 | interval = (d[1] - d[0]) 76 | sigma = 2 * interval # empirically chosen 77 | self.gamma = 0.5 / np.square(sigma) 78 | elif self.kernel == 'softplus': 79 | self.kernel_fcn = self.softplus_kernel 80 | else: 81 | self.kernel_fcn = self.relu_kernel 82 | 83 | 84 | # Mixing coefficients 85 | if self.conv: 86 | self.alpha = self.add_weight(name='alpha', 87 | shape=(1, 1, 1, self.num_parameters, self.D), 88 | initializer='normal', 89 | trainable=True) 90 | else: 91 | self.alpha = self.add_weight(name='alpha', 92 | shape=(1, self.num_parameters, self.D), 93 | initializer='normal', 94 | trainable=True) 95 | 96 | # Optional initialization with kernel ridge regression 97 | if self.init_fcn is not None: 98 | if self.kernel == 'gaussian': 99 | kernel_matrix = np.exp(- self.gamma*(d - d.T) ** 2) 100 | elif self.kernel == 'softplus': 101 | kernel_matrix = np.log(np.exp(d - d.T) + 1.0) 102 | else: 103 | raise ValueError('Cannot perform kernel ridge regression with ReLU kernel (singular matrix)') 104 | 105 | alpha_init = np.linalg.solve(kernel_matrix + 1e-5*np.eye(self.D), self.init_fcn(d)).reshape(-1) 106 | if self.conv: 107 | 
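# The single KRR solution is replicated over every filter of alpha (axis 3 here; the feedforward branch below repeats it over axis 1).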
K.set_value(self.alpha, np.repeat(alpha_init.reshape(1, 1, 1, 1, -1), self.num_parameters, axis=3)) 108 | else: 109 | K.set_value(self.alpha, np.repeat(alpha_init.reshape(1, 1, -1), self.num_parameters, axis=1)) 110 | 111 | super(KAF, self).build(input_shape) 112 | 113 | def gaussian_kernel(self, x): 114 | return K.exp(- self.gamma * (K.expand_dims(x, axis=self.unsqueeze_dim) - self.dict) ** 2.0) 115 | 116 | def softplus_kernel(self, x): 117 | return K.softplus(K.expand_dims(x, axis=self.unsqueeze_dim) - self.dict) 118 | 119 | def relu_kernel(self, x): 120 | return K.relu(K.expand_dims(x, axis=self.unsqueeze_dim) - self.dict) 121 | 122 | def call(self, x): 123 | kernel_matrix = self.kernel_fcn(x) 124 | return K.sum(kernel_matrix * self.alpha, axis=self.unsqueeze_dim) 125 | 126 | def get_config(self): 127 | return {'num_parameters': self.num_parameters, 128 | 'D': self.D, 129 | 'boundary': self.boundary, 130 | 'conv': self.conv, 131 | 'init_fcn': self.init_fcn, 132 | 'kernel': self.kernel 133 | } 134 | -------------------------------------------------------------------------------- /kernel_activation_functions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/kernel-activation-functions/e525f7e82b54508ac262f82c9557ad66be2732f3/kernel_activation_functions.png -------------------------------------------------------------------------------- /kernel_activation_functions_2D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ispamm/kernel-activation-functions/e525f7e82b54508ac262f82c9557ad66be2732f3/kernel_activation_functions_2D.png -------------------------------------------------------------------------------- /pytorch/README.md: -------------------------------------------------------------------------------- 1 | ## Kernel activation functions (PyTorch) 2 | 3 | In the *kafnets* module you can find the modules for defining KAF layers, both for feedforward networks and convolutional networks (using the flag 'conv' during initialization). 4 | The code has two demos to showcase the modules using the PyTorch sequential class. 5 | 6 | ## Requirements 7 | 8 | * pytorch = 1.0.1 9 | * numpy = 1.15.4 10 | * tqdm = 4.28.1 (for demo_kaf_convolutional.py) 11 | * scikit-learn = 0.20.1 (for demo_feedforward.py) -------------------------------------------------------------------------------- /pytorch/demo_kaf_convolutional.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Simple demo using kernel activation functions with convolutional layers on the MNIST dataset. 
5 | """ 6 | 7 | # Imports from Python libraries 8 | import numpy as np 9 | import tqdm 10 | 11 | # PyTorch imports 12 | import torch 13 | import torch.utils.data 14 | from torchvision import datasets, transforms 15 | from torch.nn import Module 16 | 17 | # Custom imports 18 | from kafnets import KAF 19 | 20 | # Set seed for PRNG 21 | np.random.seed(1) 22 | torch.manual_seed(1) 23 | 24 | # Enable CUDA (optional) 25 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | 27 | # Load MNIST dataset 28 | train_loader = torch.utils.data.DataLoader(datasets.MNIST('data/MNIST', train=True, download=True, 29 | transform=transforms.Compose([transforms.ToTensor(), 30 | transforms.Normalize((0.1307,), (0.3081,))])), 31 | batch_size=32, shuffle=True) 32 | test_loader = torch.utils.data.DataLoader(datasets.MNIST('data/MNIST', train=False, transform=transforms.Compose( 33 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), batch_size=32, shuffle=True) 34 | 35 | class Flatten(Module): 36 | """ 37 | Simple flatten module, see this discussion: 38 | https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983 39 | """ 40 | def forward(self, input): 41 | return input.view(input.size(0), -1) 42 | 43 | # Initialize a KAF neural network 44 | kafnet = torch.nn.Sequential( 45 | torch.nn.Conv2d(1, 20, kernel_size=5, padding=(2,2)), 46 | torch.nn.MaxPool2d(3), 47 | KAF(20, conv=True), 48 | torch.nn.Conv2d(20, 20, kernel_size=5, padding=(2,2)), 49 | torch.nn.MaxPool2d(3), 50 | KAF(20, conv=True), 51 | Flatten(), 52 | torch.nn.Linear(180, 10), 53 | ) 54 | 55 | # Reset parameters 56 | for m in kafnet: 57 | if len(m._parameters) > 0: 58 | m.reset_parameters() 59 | 60 | print('Training: **KAFNET**', flush=True) 61 | 62 | # Loss function 63 | loss_fn = torch.nn.CrossEntropyLoss() 64 | 65 | # Build optimizer 66 | optimizer = torch.optim.Adam(kafnet.parameters(), weight_decay=1e-4) 67 | 68 | # Put model on GPU if needed 69 | kafnet.to(device) 70 | 71 | max_epochs = 10 72 | for idx_epoch in range(max_epochs): 73 | 74 | print('Epoch #', idx_epoch, ' of #', max_epochs) 75 | kafnet.train() 76 | 77 | for (X_batch, y_batch) in tqdm.tqdm(train_loader): 78 | 79 | # Eventually move mini-batch to GPU 80 | X_batch, y_batch = X_batch.to(device), y_batch.to(device) 81 | 82 | # Forward pass: compute predicted y by passing x to the model. 83 | y_pred = kafnet(X_batch) 84 | 85 | # Compute loss. 86 | loss = loss_fn(y_pred, y_batch) 87 | 88 | # Zeroes out all gradients 89 | optimizer.zero_grad() 90 | 91 | # Backward pass 92 | loss.backward() 93 | 94 | # Update parameters 95 | optimizer.step() 96 | 97 | # Compute final test score 98 | with torch.no_grad(): 99 | print('Computing test score for: **KAFNET**', flush=True) 100 | kafnet.eval() 101 | acc = 0 102 | for _, (X_batch, y_batch) in enumerate(test_loader): 103 | # Eventually move mini-batch to GPU 104 | X_batch = X_batch.to(device) 105 | acc += np.sum(y_batch.numpy() == np.argmax(kafnet(X_batch).cpu().numpy(), axis=1)) 106 | print('Final score on test set: ', acc / test_loader.dataset.__len__()) 107 | -------------------------------------------------------------------------------- /pytorch/demo_kaf_feedforward.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Simple demo using kernel activation functions on a basic classification dataset. 
5 | """ 6 | 7 | # Imports from Python libraries 8 | import numpy as np 9 | from sklearn import datasets, preprocessing, model_selection 10 | 11 | # PyTorch imports 12 | import torch 13 | from torch.utils.data import TensorDataset 14 | from torch.utils.data import DataLoader 15 | 16 | # Custom imports 17 | from kafnets import KAF 18 | 19 | # Set seed for PRNG 20 | np.random.seed(1) 21 | torch.manual_seed(1) 22 | 23 | # Batch size 24 | B = 40 25 | 26 | # Load Breast Cancer dataset 27 | data = datasets.load_breast_cancer() 28 | X = preprocessing.MinMaxScaler(feature_range=(-1, +1)).fit_transform(data['data']).astype(np.float32) 29 | (X_train, X_test, y_train, y_test) = model_selection.train_test_split(X, data['target'].astype(np.float32).reshape(-1, 1), test_size=0.25) 30 | 31 | # Load in PyTorch data loader 32 | data_train = DataLoader(TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train)), shuffle=True, batch_size=64) 33 | data_test = DataLoader(TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test)), batch_size=100) 34 | 35 | # Initialize a KAF neural network 36 | kafnet = torch.nn.Sequential( 37 | torch.nn.Linear(30, 20), 38 | KAF(20), 39 | torch.nn.Linear(20, 1), 40 | ) 41 | 42 | # Uncomment to use KAF with custom initialization 43 | #kafnet = torch.nn.Sequential( 44 | # torch.nn.Linear(30, 20), 45 | # KAF(20, init_fcn=np.tanh), 46 | # torch.nn.Linear(20, 1), 47 | #) 48 | 49 | #Uncomment to use KAF with Softplus kernel 50 | #kafnet = torch.nn.Sequential( 51 | # torch.nn.Linear(30, 20), 52 | # KAF(20, kernel='softplus'), 53 | # torch.nn.Linear(20, 1), 54 | #) 55 | 56 | # Reset parameters 57 | for m in kafnet: 58 | if len(m._parameters) > 0: 59 | m.reset_parameters() 60 | 61 | print('Training: **KAFNET**', flush=True) 62 | 63 | # Loss function 64 | loss_fn = torch.nn.BCEWithLogitsLoss() 65 | 66 | # Build optimizer 67 | optimizer = torch.optim.Adam(kafnet.parameters(), weight_decay=1e-4) 68 | 69 | for idx_epoch in range(100): 70 | 71 | kafnet.train() 72 | 73 | for _, (X_batch, y_batch) in enumerate(data_train): 74 | 75 | # Forward pass: compute predicted y by passing x to the model. 76 | y_pred = kafnet(X_batch) 77 | 78 | # Compute loss. 79 | loss = loss_fn(y_pred, y_batch) 80 | 81 | # Zeroes out all gradients 82 | optimizer.zero_grad() 83 | 84 | # Backward pass 85 | loss.backward() 86 | 87 | # Update parameters 88 | optimizer.step() 89 | 90 | with torch.no_grad(): 91 | # Compute final test score 92 | print('Computing test score for: **KAFNET**', flush=True) 93 | kafnet.eval() 94 | acc = 0 95 | for _, (X_batch, y_batch) in enumerate(data_test): 96 | acc += np.sum(y_batch.numpy() == np.round(torch.sigmoid(kafnet(X_batch)).numpy())) 97 | print('Final score on test set: ', acc/data_test.dataset.__len__()) 98 | -------------------------------------------------------------------------------- /pytorch/kafnets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.init import normal_ 8 | import torch.nn.functional as F 9 | 10 | 11 | class KAF(nn.Module): 12 | """ Implementation of the kernel activation function. 13 | 14 | Parameters 15 | ---------- 16 | num_parameters: int 17 | Size of the layer (number of neurons). 18 | D: int, optional 19 | Size of the dictionary for each neuron. Default to 20. 
20 | conv: bool, optional 21 | True if this is a convolutive layer, False for a feedforward layer. Default to False. 22 | boundary: float, optional 23 | Dictionary elements are sampled uniformly in [-boundary, boundary]. Default to 4.0. 24 | init_fcn: None or func, optional 25 | If None, elements are initialized randomly. Otherwise, elements are initialized to approximate given function. 26 | kernel: {'gauss', 'relu', 'softplus'}, optional 27 | Kernel function to be used. Defaults to 'gaussian'. 28 | 29 | Example 30 | ---------- 31 | Neural network with one hidden layer with KAF nonlinearities: 32 | 33 | >>> net = Sequential([nn.Linear(10, 20), KAF(20), nn.Linear(20, 1)]) 34 | 35 | References 36 | ---------- 37 | [1] Scardapane, S., Van Vaerenbergh, S., Totaro, S. and Uncini, A., 2019. 38 | Kafnets: kernel-based non-parametric activation functions for neural networks. 39 | Neural Networks, 110, pp. 19-32. 40 | [2] Marra, G., Zanca, D., Betti, A. and Gori, M., 2018. 41 | Learning Neuron Non-Linearities with Kernel-Based Deep Neural Networks. 42 | arXiv preprint arXiv:1807.06302. 43 | """ 44 | 45 | def __init__(self, num_parameters, D=20, conv=False, boundary=4.0, init_fcn=None, kernel='gaussian'): 46 | 47 | super().__init__() 48 | self.num_parameters, self.D, self.conv = num_parameters, D, conv 49 | 50 | # Initialize the dictionary (NumPy) 51 | self.dict_numpy = np.linspace(-boundary, boundary, self.D).astype(np.float32).reshape(-1, 1) 52 | 53 | # Save the dictionary 54 | if self.conv: 55 | self.register_buffer('dict', torch.from_numpy(self.dict_numpy).view(1, 1, 1, 1, -1)) 56 | self.unsqueeze_dim = 4 57 | else: 58 | self.register_buffer('dict', torch.from_numpy(self.dict_numpy).view(1, -1)) 59 | self.unsqueeze_dim = 2 60 | 61 | # Select appropriate kernel function 62 | if not (kernel in ['gaussian', 'relu', 'softplus']): 63 | raise ValueError('Kernel not recognized (must be {gaussian, relu, softplus})') 64 | 65 | if kernel == 'gaussian': 66 | self.kernel_fcn = self.gaussian_kernel 67 | # Rule of thumb for gamma (only needed for Gaussian kernel) 68 | interval = (self.dict_numpy[1] - self.dict_numpy[0]) 69 | sigma = 2 * interval # empirically chosen 70 | self.gamma_init = float(0.5 / np.square(sigma)) 71 | 72 | # Initialize gamma 73 | if self.conv: 74 | self.register_buffer('gamma', torch.from_numpy(np.ones((1, 1, 1, 1, self.D), dtype=np.float32)*self.gamma_init)) 75 | else: 76 | self.register_buffer('gamma', torch.from_numpy(np.ones((1, 1, self.D), dtype=np.float32)*self.gamma_init)) 77 | 78 | elif kernel == 'relu': 79 | self.kernel_fcn = self.relu_kernel 80 | else: 81 | self.kernel_fcn = self.softplus_kernel 82 | 83 | # Initialize mixing coefficients 84 | if self.conv: 85 | self.alpha = Parameter(torch.FloatTensor(1, self.num_parameters, 1, 1, self.D)) 86 | else: 87 | self.alpha = Parameter(torch.FloatTensor(1, self.num_parameters, self.D)) 88 | 89 | # Eventually: initialization with kernel ridge regression 90 | self.init_fcn = init_fcn 91 | if init_fcn != None: 92 | 93 | if kernel == 'gaussian': 94 | K = np.exp(- self.gamma_init*(self.dict_numpy - self.dict_numpy.T) ** 2) 95 | elif kernel == 'softplus': 96 | K = np.log(np.exp(self.dict_numpy - self.dict_numpy.T) + 1.0) 97 | else: 98 | #K = np.maximum(self.dict_numpy - self.dict_numpy.T, 0) 99 | raise ValueError('Cannot perform kernel ridge regression with ReLU kernel (singular matrix)') 100 | 101 | self.alpha_init = np.linalg.solve(K + 1e-4 * np.eye(self.D), self.init_fcn(self.dict_numpy)).reshape(-1).astype(np.float32) 102 | 103 | else: 
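# The solve above is kernel ridge regression on the dictionary: alpha is chosen so that the KAF initially reproduces init_fcn at every dictionary point.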
104 | self.alpha_init = None 105 | 106 | # Reset the parameters 107 | self.reset_parameters() 108 | 109 | def reset_parameters(self): 110 | if self.init_fcn != None: 111 | if self.conv: 112 | self.alpha.data = torch.from_numpy(self.alpha_init).repeat(1, self.num_parameters, 1, 1, 1) 113 | else: 114 | self.alpha.data = torch.from_numpy(self.alpha_init).repeat(1, self.num_parameters, 1) 115 | else: 116 | normal_(self.alpha.data, std=0.8) 117 | 118 | def gaussian_kernel(self, input): 119 | return torch.exp(- torch.mul((torch.add(input.unsqueeze(self.unsqueeze_dim), - self.dict))**2, self.gamma)) 120 | 121 | def relu_kernel(self, input): 122 | return F.relu(input.unsqueeze(self.unsqueeze_dim) - self.dict) 123 | 124 | def softplus_kernel(self, input): 125 | return F.softplus(input.unsqueeze(self.unsqueeze_dim) - self.dict) 126 | 127 | def forward(self, input): 128 | K = self.kernel_fcn(input) 129 | y = torch.sum(K*self.alpha, self.unsqueeze_dim) 130 | return y 131 | 132 | def __repr__(self): 133 | return self.__class__.__name__ + ' (' \ 134 | + str(self.num_parameters) + ')' -------------------------------------------------------------------------------- /tensorflow/README: -------------------------------------------------------------------------------- 1 | ## Kernel activation functions (TensorFlow) 2 | 3 | In the *kafnets* module you can find the modules for defining KAF layers, both for feedforward networks and convolutional networks (using the flag 'conv' during initialization). 4 | The code has two demos to showcase the modules using TensorFlow layers. 5 | 6 | ## Requirements 7 | 8 | * tensorflow = 1.1.13 9 | * numpy = 1.15.4 -------------------------------------------------------------------------------- /tensorflow/demo_kaf_convolutional.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Simple demo using kernel activation functions with convolutional networks on the MNIST dataset. 
5 | """ 6 | 7 | # Import TensorFlow 8 | import numpy as np 9 | import tensorflow as tf 10 | import tensorflow.contrib.eager as tfe 11 | tf.enable_eager_execution() 12 | 13 | # Keras imports 14 | from tensorflow.keras import datasets 15 | from tensorflow.keras.models import Sequential 16 | from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten 17 | 18 | # Custom imports 19 | from kafnets import KAF 20 | import tqdm 21 | 22 | # Load Breast Cancer dataset 23 | (X_train, y_train), (X_test, y_test) = datasets.mnist.load_data() 24 | 25 | # Preprocessing is taken from here: 26 | # https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py 27 | X_train = X_train.reshape(X_train.shape[0], 28, 28, 1) 28 | X_test = X_test.reshape(X_test.shape[0], 28, 28, 1) 29 | 30 | X_train = X_train.astype('float32') 31 | X_test = X_test.astype('float32') 32 | X_train /= 255 33 | X_test /= 255 34 | 35 | 36 | # Initialize a KAF neural network 37 | kafnet = Sequential() 38 | kafnet.add(Conv2D(32, (3, 3), input_shape=(28, 28, 1))) 39 | kafnet.add(KAF(32, conv=True)) 40 | kafnet.add(Conv2D(32, (3, 3))) 41 | kafnet.add(KAF(32, conv=True)) 42 | kafnet.add(MaxPooling2D(pool_size=(2, 2))) 43 | kafnet.add(Flatten()) 44 | kafnet.add(Dense(100)) 45 | kafnet.add(KAF(100)) 46 | kafnet.add(Dense(10, activation='softmax')) 47 | 48 | # Use tf.data DataLoader 49 | train_data = tf.data.Dataset.from_tensor_slices((X_train.astype(np.float32), y_train.astype(np.int64))) 50 | test_data = tf.data.Dataset.from_tensor_slices((X_test.astype(np.float32), y_test.astype(np.int64))) 51 | 52 | # Optimizer 53 | opt = tf.train.AdamOptimizer() 54 | 55 | # Training 56 | for e in tqdm.trange(5, desc='Training'): 57 | 58 | for xb, yb in train_data.shuffle(1000).batch(32): 59 | 60 | with tfe.GradientTape() as tape: 61 | loss = tf.losses.sparse_softmax_cross_entropy(yb, kafnet(xb)) 62 | g = tape.gradient(loss, kafnet.variables) 63 | opt.apply_gradients(zip(g, kafnet.variables)) 64 | 65 | # Evaluation 66 | acc = tfe.metrics.Accuracy() 67 | for xb, yb in test_data.batch(32): 68 | acc(yb, tf.argmax(kafnet(xb), axis=1)) 69 | tqdm.tqdm.write('Test accuracy after epoch {} is: '.format(e+1) + str(acc.result())) -------------------------------------------------------------------------------- /tensorflow/demo_kaf_feedforward.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Simple demo using kernel activation functions on a basic regression dataset. 
5 | """ 6 | 7 | # Import TensorFlow 8 | import numpy as np 9 | import tensorflow as tf 10 | import tensorflow.contrib.eager as tfe 11 | tf.enable_eager_execution() 12 | 13 | # Keras imports 14 | from tensorflow.keras import datasets 15 | from tensorflow.keras.models import Sequential 16 | from tensorflow.keras.layers import Dense 17 | 18 | # Custom imports 19 | from kafnets import KAF 20 | import tqdm 21 | 22 | # Load Breast Cancer dataset 23 | (X_train, y_train), (X_test, y_test) = datasets.boston_housing.load_data() 24 | 25 | # Initialize a KAF neural network 26 | kafnet = Sequential([ 27 | Dense(20, input_shape=(13,)), 28 | KAF(20), 29 | Dense(1), 30 | ]) 31 | 32 | #Uncomment to use KAF with Softplus kernel 33 | #kafnet = Sequential([ 34 | # Dense(20, input_shape=(13,)), 35 | # KAF(20, kernel='softplus', D=5), 36 | # Dense(1), 37 | #]) 38 | 39 | # Use tf.data DataLoader 40 | train_data = tf.data.Dataset.from_tensor_slices((X_train.astype(np.float32), y_train.reshape(-1, 1))) 41 | test_data = tf.data.Dataset.from_tensor_slices((X_test.astype(np.float32), y_test.astype(np.float32).reshape(-1, 1))) 42 | 43 | # Optimizer 44 | opt = tf.train.AdamOptimizer() 45 | 46 | # Training 47 | for e in tqdm.trange(300, desc='Training'): 48 | 49 | for xb, yb in train_data.shuffle(1000).batch(32): 50 | 51 | with tfe.GradientTape() as tape: 52 | loss = tf.losses.mean_squared_error(yb, kafnet(xb)) 53 | g = tape.gradient(loss, kafnet.variables) 54 | opt.apply_gradients(zip(g, kafnet.variables)) 55 | 56 | # Evaluation 57 | err = tfe.metrics.Mean() 58 | for xb, yb in test_data.batch(32): 59 | err((yb - kafnet(xb))**2) 60 | print('Final error is: ' + str(err.result())) 61 | -------------------------------------------------------------------------------- /tensorflow/kafnets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras.layers import Layer 3 | import tensorflow as tf 4 | 5 | class KAF(Layer): 6 | """ Implementation of the kernel activation function. 7 | 8 | Parameters 9 | ---------- 10 | num_parameters: int 11 | Size of the layer (number of neurons). 12 | D: int, optional 13 | Size of the dictionary for each neuron. Default to 20. 14 | conv: bool, optional 15 | True if this is a convolutive layer, False for a feedforward layer. Default to False. 16 | boundary: float, optional 17 | Dictionary elements are sampled uniformly in [-boundary, boundary]. Default to 4.0. 18 | init_fcn: None or func, optional 19 | If None, elements are initialized randomly. Otherwise, elements are initialized to approximate given function. 20 | kernel: {'gauss', 'relu', 'softplus'}, optional 21 | Kernel function to be used. Defaults to 'gaussian'. 22 | 23 | Example 24 | ---------- 25 | Neural network with one hidden layer with KAF nonlinearities: 26 | 27 | >>> net = Sequential([Dense(10), KAF(10), Dense(10, 1)]) 28 | 29 | References 30 | ---------- 31 | [1] Scardapane, S., Van Vaerenbergh, S., Totaro, S. and Uncini, A., 2019. 32 | Kafnets: kernel-based non-parametric activation functions for neural networks. 33 | Neural Networks, 110, pp. 19-32. 34 | [2] Marra, G., Zanca, D., Betti, A. and Gori, M., 2018. 35 | Learning Neuron Non-Linearities with Kernel-Based Deep Neural Networks. 36 | arXiv preprint arXiv:1807.06302. 
37 | """ 38 | 39 | def __init__(self, num_parameters, D=20, boundary=3.0, conv=False, init_fcn=None, kernel='gaussian', **kwargs): 40 | self.num_parameters = num_parameters 41 | self.D = D 42 | self.boundary = boundary 43 | self.init_fcn = init_fcn 44 | self.conv = conv 45 | if self.conv: 46 | self.unsqueeze_dim = 4 47 | else: 48 | self.unsqueeze_dim = 2 49 | self.kernel = kernel 50 | if not (kernel in ['gaussian', 'relu', 'softplus']): 51 | raise ValueError('Kernel not recognized (must be {gaussian, relu, softplus})') 52 | super().__init__(**kwargs) 53 | 54 | def build(self, input_shape): 55 | 56 | # Initialize the fixed dictionary 57 | d = np.linspace(-self.boundary, self.boundary, self.D).astype(np.float32).reshape(-1, 1) 58 | 59 | if self.conv: 60 | self.dict = self.add_weight(name='dict', 61 | shape=(1, 1, 1, 1, self.D), 62 | initializer='uniform', 63 | trainable=False) 64 | tf.assign(self.dict, d.reshape(1, 1, 1, 1, -1)) 65 | else: 66 | self.dict = self.add_weight(name='dict', 67 | shape=(1, 1, self.D), 68 | initializer='uniform', 69 | trainable=False) 70 | tf.assign(self.dict, d.reshape(1, 1, -1)) 71 | 72 | if self.kernel == 'gaussian': 73 | self.kernel_fcn = self.gaussian_kernel 74 | # Rule of thumb for gamma 75 | interval = (d[1] - d[0]) 76 | sigma = 2 * interval # empirically chosen 77 | self.gamma = 0.5 / np.square(sigma) 78 | elif self.kernel == 'softplus': 79 | self.kernel_fcn = self.softplus_kernel 80 | else: 81 | self.kernel_fcn = self.relu_kernel 82 | 83 | 84 | # Mixing coefficients 85 | if self.conv: 86 | self.alpha = self.add_weight(name='alpha', 87 | shape=(1, 1, 1, self.num_parameters, self.D), 88 | initializer='normal', 89 | trainable=True) 90 | else: 91 | self.alpha = self.add_weight(name='alpha', 92 | shape=(1, self.num_parameters, self.D), 93 | initializer='normal', 94 | trainable=True) 95 | 96 | # Optional initialization with kernel ridge regression 97 | if self.init_fcn is not None: 98 | if self.kernel == 'gaussian': 99 | kernel_matrix = np.exp(- self.gamma*(d - d.T) ** 2) 100 | elif self.kernel == 'softplus': 101 | kernel_matrix = np.log(np.exp(d - d.T) + 1.0) 102 | else: 103 | raise ValueError('Cannot perform kernel ridge regression with ReLU kernel (singular matrix)') 104 | 105 | alpha_init = np.linalg.solve(kernel_matrix + 1e-5*np.eye(self.D), self.init_fcn(d)).reshape(-1) 106 | if self.conv: 107 | tf.assign(self.alpha, np.repeat(alpha_init.reshape(1, 1, 1, 1, -1), self.num_parameters, axis=3)) 108 | else: 109 | tf.assign(self.alpha, np.repeat(alpha_init.reshape(1, 1, -1), self.num_parameters, axis=1)) 110 | 111 | super(KAF, self).build(input_shape) 112 | 113 | def gaussian_kernel(self, x): 114 | return tf.exp(- self.gamma * (tf.expand_dims(x, axis=self.unsqueeze_dim) - self.dict) ** 2.0) 115 | 116 | def softplus_kernel(self, x): 117 | return tf.nn.softplus(tf.expand_dims(x, axis=self.unsqueeze_dim) - self.dict) 118 | 119 | def relu_kernel(self, x): 120 | return tf.nn.relu(tf.expand_dims(x, axis=self.unsqueeze_dim) - self.dict) 121 | 122 | def call(self, x): 123 | kernel_matrix = self.kernel_fcn(x) 124 | return tf.reduce_sum(kernel_matrix * self.alpha, axis=self.unsqueeze_dim) --------------------------------------------------------------------------------