├── .gitignore
├── LICENSE
├── README.md
├── autograd
│   ├── README.md
│   ├── demo_kaf_regression.py
│   └── kafnets.py
├── eq1.png
├── eq2.png
├── keras
│   ├── README.md
│   ├── demo_kaf_convolutional.py
│   ├── demo_kaf_feedforward.py
│   └── kafnets.py
├── kernel_activation_functions.png
├── kernel_activation_functions_2D.png
├── pytorch
│   ├── README.md
│   ├── demo_kaf_convolutional.py
│   ├── demo_kaf_feedforward.py
│   └── kafnets.py
└── tensorflow
    ├── README
    ├── demo_kaf_convolutional.py
    ├── demo_kaf_feedforward.py
    └── kafnets.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__*
2 |
3 | .idea/
4 |
5 | pytorch/data/MNIST/raw/train-labels-idx1-ubyte
6 |
7 | pytorch/data/MNIST/processed/test.pt
8 |
9 | pytorch/data/MNIST/processed/training.pt
10 |
11 | pytorch/data/MNIST/raw/t10k-images-idx3-ubyte
12 |
13 | pytorch/data/MNIST/raw/t10k-labels-idx1-ubyte
14 |
15 | pytorch/data/MNIST/raw/train-images-idx3-ubyte
16 |
17 | tensorflow/MNIST_data/
18 |
19 | tensorflow/logs/
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Ispamm Laboratory
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Kernel Activation Functions
2 |
3 | This repository contains several implementations of the kernel activation functions (KAFs) described in the following paper ([link to the preprint](https://arxiv.org/abs/1707.04035)):
4 |
5 | Scardapane, S., Van Vaerenbergh, S., Totaro, S. and Uncini, A., 2019.
6 | Kafnets: Kernel-based non-parametric activation functions for neural networks.
7 | Neural Networks, 110, pp.19-32.
8 |
9 | ## Available implementations
10 |
11 | We currently provide the following stable implementations:
12 |
13 | * [PyTorch](/pytorch): feedforward and convolutional layers, three kernels (Gaussian/ReLU/Softplus), with random initialization or kernel ridge regression.
14 | * [Keras](/keras): same functionalities as the PyTorch implementation.
15 | * [TensorFlow](/tensorflow/): similar to the Keras implementation, but it uses the internal tf.keras.layers.Layer class and eager execution in the demos.
16 | * [Autograd](/autograd): only feedforward layers with a Gaussian kernel and random initialization.
17 |
18 | More information for each implementation is given in the corresponding folder. The code should be relatively easy to plug into other architectures or projects.
19 |
20 | ## What is a KAF?
21 |
22 | Most neural networks work by interleaving linear projections and simple (fixed) activation functions, like the ReLU function:
23 |
24 | ![ReLU activation function](eq1.png)
25 |
26 |
27 |
28 | A KAF is instead a non-parametric activation function defined as a one-dimensional kernel approximator:
29 |
30 | ![Kernel activation function (KAF) definition](eq2.png)
31 |
32 |
33 |
34 | where:
35 |
36 | * The dictionary of the kernel elements is fixed by sampling the x-axis with a uniform step around 0.
37 | * The user can select the kernel function (e.g., Gaussian, ReLU, Softplus) and the number of kernel elements D.
38 | * The linear coefficients are adapted independently at every neuron via standard back-propagation.
39 |
40 | In addition, the linear coefficients can be initialized using kernel ridge regression, so that the KAF behaves similarly to a known function at the beginning of the optimization process (see the sketch below).
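To make the definition concrete, here is a minimal NumPy sketch of a single KAF with a Gaussian kernel, including the optional kernel ridge regression initialization. It is an illustration only, independent of the framework implementations below; the choices D=20, boundary=3, and the rule of thumb for gamma mirror the defaults used in this repository, while the target function (tanh) is just an example:

```python
import numpy as np

# Fixed dictionary: D elements sampled uniformly in [-boundary, boundary]
D, boundary = 20, 3.0
d = np.linspace(-boundary, boundary, D)

# Rule of thumb for the Gaussian kernel bandwidth (as in the code)
gamma = 0.5 / (2.0 * (d[1] - d[0])) ** 2

def kaf(s, alpha):
    """Evaluate g(s) = sum_i alpha_i * exp(-gamma * (s - d_i)^2) element-wise."""
    K = np.exp(-gamma * (s[..., None] - d) ** 2)  # kernel values, shape (..., D)
    return K @ alpha

# Optional: kernel ridge regression on the dictionary points, so that the
# KAF starts out close to a known function (here tanh)
K_dict = np.exp(-gamma * (d[:, None] - d[None, :]) ** 2)
alpha0 = np.linalg.solve(K_dict + 1e-5 * np.eye(D), np.tanh(d))

s = np.linspace(-3.0, 3.0, 7)
print(np.abs(kaf(s, alpha0) - np.tanh(s)).max())  # small: the KAF approximates tanh
```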
41 |
42 |
43 | ![Kernel activation functions](kernel_activation_functions.png)
44 | Fig. 1. Examples of kernel activation functions learned on the Sensorless data set. The KAF after initialization is shown as a dashed red line, while the final KAF is shown as a solid green line. As a reference, the distribution of activation values after training is shown in light blue.
45 |
46 |
47 |
48 | ![Two-dimensional kernel activation functions](kernel_activation_functions_2D.png)
49 | Fig. 2. Examples of two-dimensional kernel activation functions learned on the Sensorless data set.
50 |
51 |
52 | ## Contributing
53 |
54 | If you have an implementation for a different framework, or an enhanced version of the current code, feel free to contribute to the repository. For any issue related to the code, you can use the GitHub issue tracker.
55 |
56 | ## Citation
57 |
58 | If you use this code or a derivative thereof in your research, we would appreciate a citation to the original paper:
59 |
60 | @article{scardapane2019kafnets,
61 | title={Kafnets: Kernel-based non-parametric activation functions for neural networks},
62 | author={Scardapane, Simone and Van Vaerenbergh, Steven and Totaro, Simone and Uncini, Aurelio},
63 | journal={Neural Networks},
64 | volume={110},
65 | pages={19--32},
66 | year={2019},
67 | publisher={Elsevier}
68 | }
69 |
70 | ## License
71 |
72 | The code is released under the MIT License. See the attached LICENSE file.
73 |
--------------------------------------------------------------------------------
/autograd/README.md:
--------------------------------------------------------------------------------
1 | ## Kernel activation functions (Autograd)
2 |
3 | The code customizes the example from here:
4 | https://github.com/HIPS/autograd/blob/master/examples/neural_net.py
5 |
6 | In the *kafnets* module you can find the code to initialize and run neural networks having KAF activation functions.
7 | Note that this is a regression example, and KAF nonlinearities are used in the output layer as well.
8 | For classification, consider changing the final layer to a standard softmax. A minimal usage sketch is shown below.
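The sketch below uses the two functions exported by the *kafnets* module, with the same layer sizes as demo_kaf_regression.py; the random input is purely illustrative:

```python
import autograd.numpy as np
from kafnets import init_kaf_nn, predict_kaf_nn

# Network with 13 inputs, one hidden layer of 10 KAF neurons, and 1 output
w, info = init_kaf_nn([13, 10, 1])

X = np.random.randn(5, 13)
y = predict_kaf_nn(w, X, info)  # shape (5, 1)
```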
9 |
10 | ## Requirements
11 |
12 | * autograd = 1.2
13 | * scikit-learn = 0.20.1 (for demo_kaf_regression.py)
--------------------------------------------------------------------------------
/autograd/demo_kaf_regression.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Imports from Python libraries
4 | import autograd.numpy as np
5 | from autograd import grad
6 | from autograd.misc.optimizers import adam
7 | from sklearn import datasets, preprocessing, model_selection
8 |
9 | # Custom imports
10 | from kafnets import init_kaf_nn, predict_kaf_nn
11 |
12 | # Extends this example:
13 | # https://github.com/HIPS/autograd/blob/master/examples/neural_net.py
14 |
15 | # Set seed for PRNG
16 | np.random.seed(1)
17 |
18 | # Size of the neural network's layers
19 | layers = [13, 10, 1]
20 |
21 | # Batch size
22 | B = 40
23 |
24 | # Load Boston dataset
25 | data = datasets.load_boston()
26 | X = preprocessing.MinMaxScaler(feature_range=(-1, +1)).fit_transform(data['data'])
27 | y = preprocessing.MinMaxScaler(feature_range=(-0.9, +0.9)).fit_transform(data['target'].reshape(-1, 1))
28 | (X_train, X_test, y_train, y_test) = model_selection.train_test_split(X, y, test_size=0.25)
29 |
30 | # Initialize KAF neural network
31 | w, info = init_kaf_nn(layers)
32 | predict_fcn = lambda w, inputs: predict_kaf_nn(w, inputs, info)
33 |
34 | # Loss function (MSE)
35 | def loss_fcn(params, inputs, targets):
36 | return np.mean(np.square(predict_fcn(params, inputs) - targets))
37 |
38 | # Iterator over mini-batches
39 | num_batches = int(np.ceil(X_train.shape[0] / B))
40 | def batch_indices(iter):
41 | idx = iter % num_batches
42 | return slice(idx * B, (idx+1) * B)
43 |
44 | # Define training objective
45 | def objective(params, iter):
46 | idx = batch_indices(iter)
47 | return loss_fcn(params, X_train[idx], y_train[idx])
48 |
49 | # Get gradient of objective using autograd.
50 | objective_grad = grad(objective)
51 |
52 | # The optimizers provided can optimize lists, tuples, or dicts of parameters
53 | print('Optimizing the network...\n')
54 | w_final = adam(objective_grad, w, num_iters=1000)
55 |
56 | # Compute final test MSE
57 | print('Final test MSE is ', loss_fcn(w_final, X_test, y_test), '\n')
--------------------------------------------------------------------------------
/autograd/kafnets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import autograd.numpy as np
4 |
5 | def init_kaf_nn(layer_sizes, scale=0.01, rs=np.random.RandomState(0), dict_size=20, boundary=3.0):
6 | """
7 | Initialize the parameters of a KAF feedforward network.
8 | - dict_size: the size of the dictionary for every neuron.
9 | - boundary: dictionary elements are sampled uniformly in [-boundary, boundary].
10 | """
11 |
12 | # Initialize the dictionary
13 | D = np.linspace(-boundary, boundary, dict_size).reshape(-1, 1)
14 |
15 | # Rule of thumb for gamma
16 | interval = D[1,0] - D[0,0]
17 | gamma = 0.5/np.square(2*interval)
18 | D = D.reshape(1, 1, -1)
19 |
20 | # Initialize a list of parameters for the layer
21 | w = [(rs.randn(insize, outsize) * scale, # Weight matrix
22 | rs.randn(outsize) * scale, # Bias vector
23 | rs.randn(1, outsize, dict_size) * 0.5) # Mixing coefficients
24 | for insize, outsize in zip(layer_sizes[:-1], layer_sizes[1:])]
25 |
26 | return w, (D, gamma)
27 |
28 | def predict_kaf_nn(w, X, info):
29 | """
30 | Compute the outputs of a KAF feedforward network.
31 | """
32 |
33 | D, gamma = info
34 | for W, b, alpha in w:
35 | outputs = np.dot(X, W) + b
36 | K = gauss_kernel(outputs, D, gamma)
37 | X = np.sum(K*alpha, axis=2)
38 | return X
39 |
40 | def gauss_kernel(X, D, gamma=1.0):
41 | """
42 | Compute the 1D Gaussian kernel between all elements of a
43 | NxH matrix and a fixed L-dimensional dictionary, resulting in a NxHxL matrix of kernel
44 | values.
45 | """
46 | return np.exp(- gamma*np.square(X.reshape(-1, X.shape[1], 1) - D))
--------------------------------------------------------------------------------
/eq1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/kernel-activation-functions/e525f7e82b54508ac262f82c9557ad66be2732f3/eq1.png
--------------------------------------------------------------------------------
/eq2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/kernel-activation-functions/e525f7e82b54508ac262f82c9557ad66be2732f3/eq2.png
--------------------------------------------------------------------------------
/keras/README.md:
--------------------------------------------------------------------------------
1 | ## Kernel activation functions (Keras)
2 |
3 | In the *kafnets* module you can find the modules for defining KAF layers, both for feedforward and convolutional networks (using the flag `conv` during initialization).
4 | The code has two demos to showcase the modules using the Keras Sequential model; a minimal usage sketch is shown below.
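The sketch below only assumes what the demos in this folder already use (layer sizes included) and is meant as a quick orientation, not a complete program:

```python
from keras.models import Sequential
from keras.layers import Conv2D, Dense
from kafnets import KAF

# Feedforward: a KAF layer after a Dense layer
net = Sequential([Dense(20, input_shape=(13,)), KAF(20), Dense(1)])

# Convolutional: pass conv=True so the KAF is applied across the feature maps
cnn = Sequential([Conv2D(32, (3, 3), input_shape=(28, 28, 1)), KAF(32, conv=True)])
```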
5 |
6 | ## Requirements
7 |
8 | * keras = 2.2.4
9 | * numpy = 1.15.4
10 |
--------------------------------------------------------------------------------
/keras/demo_kaf_convolutional.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Simple demo using kernel activation functions with convolutional networks on the MNIST dataset.
5 | """
6 |
7 | # Keras imports
8 | from keras import datasets
9 | from keras.models import Sequential
10 | from keras.layers import Dense, Conv2D, MaxPooling2D, Flatten
11 | from keras.utils import to_categorical
12 | import keras.backend as K
13 |
14 | # Custom imports
15 | from kafnets import KAF
16 |
17 | # Load MNIST dataset
18 | (X_train, y_train), (X_test, y_test) = datasets.mnist.load_data()
19 |
20 | # Preprocessing is taken from here:
21 | # https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
22 | if K.image_data_format() == 'channels_first':
23 | X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
24 | X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
25 | else:
26 | X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
27 | X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
28 |
29 | X_train = X_train.astype('float32')
30 | X_test = X_test.astype('float32')
31 | X_train /= 255
32 | X_test /= 255
33 |
34 | # convert class vectors to binary class matrices
35 | y_train = to_categorical(y_train, 10)
36 | y_test = to_categorical(y_test, 10)
37 |
38 | # Initialize a KAF neural network
39 | kafnet = Sequential()
40 | kafnet.add(Conv2D(32, (3, 3), input_shape=(28, 28, 1)))
41 | kafnet.add(KAF(32, conv=True))
42 | kafnet.add(Conv2D(32, (3, 3)))
43 | kafnet.add(KAF(32, conv=True))
44 | kafnet.add(MaxPooling2D(pool_size=(2, 2)))
45 | kafnet.add(Flatten())
46 | kafnet.add(Dense(100))
47 | kafnet.add(KAF(100))
48 | kafnet.add(Dense(10, activation='softmax'))
49 |
50 | # Training
51 | kafnet.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
52 | kafnet.summary()
53 | kafnet.fit(X_train, y_train, epochs=5, batch_size=32, verbose=1)
54 |
55 | # Evaluation
56 | print('Final accuracy is: ' + str(kafnet.evaluate(X_test, y_test, batch_size=64)[1]))
--------------------------------------------------------------------------------
/keras/demo_kaf_feedforward.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Simple demo using kernel activation functions on a basic regression dataset.
5 | """
6 |
7 | # Keras imports
8 | from keras import datasets
9 | from keras.models import Sequential
10 | from keras.layers import Dense
11 |
12 | # Custom imports
13 | from kafnets import KAF
14 |
15 | # Load Boston housing dataset
16 | (X_train, y_train), (X_test, y_test) = datasets.boston_housing.load_data()
17 |
18 | # Initialize a KAF neural network
19 | kafnet = Sequential([
20 | Dense(20, input_shape=(13,)),
21 | KAF(20),
22 | Dense(1),
23 | ])
24 |
25 | #Uncomment to use KAF with Softplus kernel
26 | #kafnet = Sequential([
27 | # Dense(20, input_shape=(13,)),
28 | # KAF(20, kernel='softplus', D=5),
29 | # Dense(1),
30 | #])
31 |
32 | # Training
33 | kafnet.compile(optimizer='adam', loss='mse')
34 | kafnet.summary()
35 | kafnet.fit(X_train, y_train, epochs=250, batch_size=32, verbose=0)
36 |
37 | # Evaluation
38 | print('Final error is: ' + str(kafnet.evaluate(X_test, y_test, batch_size=64)))
--------------------------------------------------------------------------------
/keras/kafnets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from keras.layers import Layer
3 | from keras import backend as K
4 |
5 | class KAF(Layer):
6 | """ Implementation of the kernel activation function.
7 |
8 | Parameters
9 | ----------
10 | num_parameters: int
11 | Size of the layer (number of neurons).
12 | D: int, optional
13 | Size of the dictionary for each neuron. Defaults to 20.
14 | conv: bool, optional
15 | True if this is a convolutional layer, False for a feedforward layer. Defaults to False.
16 | boundary: float, optional
17 | Dictionary elements are sampled uniformly in [-boundary, boundary]. Defaults to 3.0.
18 | init_fcn: None or func, optional
19 | If None, the mixing coefficients are initialized randomly. Otherwise, they are initialized to approximate the given function.
20 | kernel: {'gaussian', 'relu', 'softplus'}, optional
21 | Kernel function to be used. Defaults to 'gaussian'.
22 |
23 | Example
24 | ----------
25 | Neural network with one hidden layer with KAF nonlinearities:
26 |
27 | >>> net = Sequential([Dense(10), KAF(10), Dense(1)])
28 |
29 | References
30 | ----------
31 | [1] Scardapane, S., Van Vaerenbergh, S., Totaro, S. and Uncini, A., 2019.
32 | Kafnets: kernel-based non-parametric activation functions for neural networks.
33 | Neural Networks, 110, pp. 19-32.
34 | [2] Marra, G., Zanca, D., Betti, A. and Gori, M., 2018.
35 | Learning Neuron Non-Linearities with Kernel-Based Deep Neural Networks.
36 | arXiv preprint arXiv:1807.06302.
37 | """
38 |
39 | def __init__(self, num_parameters, D=20, boundary=3.0, conv=False, init_fcn=None, kernel='gaussian', **kwargs):
40 | self.num_parameters = num_parameters
41 | self.D = D
42 | self.boundary = boundary
43 | self.init_fcn = init_fcn
44 | self.conv = conv
45 | if self.conv:
46 | self.unsqueeze_dim = 4
47 | else:
48 | self.unsqueeze_dim = 2
49 | self.kernel = kernel
50 | if not (kernel in ['gaussian', 'relu', 'softplus']):
51 | raise ValueError('Kernel not recognized (must be {gaussian, relu, softplus})')
52 | super().__init__(**kwargs)
53 |
54 | def build(self, input_shape):
55 |
56 | # Initialize the fixed dictionary
57 | d = np.linspace(-self.boundary, self.boundary, self.D).astype(np.float32).reshape(-1, 1)
58 |
59 | if self.conv:
60 | self.dict = self.add_weight(name='dict',
61 | shape=(1, 1, 1, 1, self.D),
62 | initializer='uniform',
63 | trainable=False)
64 | K.set_value(self.dict, d.reshape(1, 1, 1, 1, -1))
65 | else:
66 | self.dict = self.add_weight(name='dict',
67 | shape=(1, 1, self.D),
68 | initializer='uniform',
69 | trainable=False)
70 | K.set_value(self.dict, d.reshape(1, 1, -1))
71 |
72 | if self.kernel == 'gaussian':
73 | self.kernel_fcn = self.gaussian_kernel
74 | # Rule of thumb for gamma
75 | interval = (d[1] - d[0])
76 | sigma = 2 * interval # empirically chosen
77 | self.gamma = 0.5 / np.square(sigma)
78 | elif self.kernel == 'softplus':
79 | self.kernel_fcn = self.softplus_kernel
80 | else:
81 | self.kernel_fcn = self.relu_kernel
82 |
83 |
84 | # Mixing coefficients
85 | if self.conv:
86 | self.alpha = self.add_weight(name='alpha',
87 | shape=(1, 1, 1, self.num_parameters, self.D),
88 | initializer='normal',
89 | trainable=True)
90 | else:
91 | self.alpha = self.add_weight(name='alpha',
92 | shape=(1, self.num_parameters, self.D),
93 | initializer='normal',
94 | trainable=True)
95 |
96 | # Optional initialization with kernel ridge regression
97 | if self.init_fcn is not None:
98 | if self.kernel == 'gaussian':
99 | kernel_matrix = np.exp(- self.gamma*(d - d.T) ** 2)
100 | elif self.kernel == 'softplus':
101 | kernel_matrix = np.log(np.exp(d - d.T) + 1.0)
102 | else:
103 | raise ValueError('Cannot perform kernel ridge regression with ReLU kernel (singular matrix)')
104 |
105 | alpha_init = np.linalg.solve(kernel_matrix + 1e-5*np.eye(self.D), self.init_fcn(d)).reshape(-1)
106 | if self.conv:
107 | K.set_value(self.alpha, np.repeat(alpha_init.reshape(1, 1, 1, 1, -1), self.num_parameters, axis=3))
108 | else:
109 | K.set_value(self.alpha, np.repeat(alpha_init.reshape(1, 1, -1), self.num_parameters, axis=1))
110 |
111 | super(KAF, self).build(input_shape)
112 |
113 | def gaussian_kernel(self, x):
114 | return K.exp(- self.gamma * (K.expand_dims(x, axis=self.unsqueeze_dim) - self.dict) ** 2.0)
115 |
116 | def softplus_kernel(self, x):
117 | return K.softplus(K.expand_dims(x, axis=self.unsqueeze_dim) - self.dict)
118 |
119 | def relu_kernel(self, x):
120 | return K.relu(K.expand_dims(x, axis=self.unsqueeze_dim) - self.dict)
121 |
122 | def call(self, x):
123 | kernel_matrix = self.kernel_fcn(x)
124 | return K.sum(kernel_matrix * self.alpha, axis=self.unsqueeze_dim)
125 |
126 | def get_config(self):
127 | return {'num_parameters': self.num_parameters,
128 | 'D': self.D,
129 | 'boundary': self.boundary,
130 | 'conv': self.conv,
131 | 'init_fcn': self.init_fcn,
132 | 'kernel': self.kernel
133 | }
134 |
--------------------------------------------------------------------------------
/kernel_activation_functions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/kernel-activation-functions/e525f7e82b54508ac262f82c9557ad66be2732f3/kernel_activation_functions.png
--------------------------------------------------------------------------------
/kernel_activation_functions_2D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ispamm/kernel-activation-functions/e525f7e82b54508ac262f82c9557ad66be2732f3/kernel_activation_functions_2D.png
--------------------------------------------------------------------------------
/pytorch/README.md:
--------------------------------------------------------------------------------
1 | ## Kernel activation functions (PyTorch)
2 |
3 | In the *kafnets* module you can find the modules for defining KAF layers, both for feedforward and convolutional networks (using the flag `conv` during initialization).
4 | The code has two demos to showcase the modules using the PyTorch sequential container; a minimal usage sketch is shown below.
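The sketch below only assumes what the demos in this folder already use (layer sizes included) and is meant as a quick orientation, not a complete program:

```python
import torch
from kafnets import KAF

# Feedforward: a KAF layer after a linear layer
net = torch.nn.Sequential(torch.nn.Linear(30, 20), KAF(20), torch.nn.Linear(20, 1))

# Convolutional: pass conv=True so the KAF is applied across the feature maps
cnn = torch.nn.Sequential(torch.nn.Conv2d(1, 20, kernel_size=5), KAF(20, conv=True))
```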
5 |
6 | ## Requirements
7 |
8 | * pytorch = 1.0.1
9 | * numpy = 1.15.4
10 | * tqdm = 4.28.1 (for demo_kaf_convolutional.py)
11 | * scikit-learn = 0.20.1 (for demo_kaf_feedforward.py)
--------------------------------------------------------------------------------
/pytorch/demo_kaf_convolutional.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Simple demo using kernel activation functions with convolutional layers on the MNIST dataset.
5 | """
6 |
7 | # Imports from Python libraries
8 | import numpy as np
9 | import tqdm
10 |
11 | # PyTorch imports
12 | import torch
13 | import torch.utils.data
14 | from torchvision import datasets, transforms
15 | from torch.nn import Module
16 |
17 | # Custom imports
18 | from kafnets import KAF
19 |
20 | # Set seed for PRNG
21 | np.random.seed(1)
22 | torch.manual_seed(1)
23 |
24 | # Enable CUDA (optional)
25 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
26 |
27 | # Load MNIST dataset
28 | train_loader = torch.utils.data.DataLoader(datasets.MNIST('data/MNIST', train=True, download=True,
29 | transform=transforms.Compose([transforms.ToTensor(),
30 | transforms.Normalize((0.1307,), (0.3081,))])),
31 | batch_size=32, shuffle=True)
32 | test_loader = torch.utils.data.DataLoader(datasets.MNIST('data/MNIST', train=False, transform=transforms.Compose(
33 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), batch_size=32, shuffle=True)
34 |
35 | class Flatten(Module):
36 | """
37 | Simple flatten module, see this discussion:
38 | https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983
39 | """
40 | def forward(self, input):
41 | return input.view(input.size(0), -1)
42 |
43 | # Initialize a KAF neural network
44 | kafnet = torch.nn.Sequential(
45 | torch.nn.Conv2d(1, 20, kernel_size=5, padding=(2,2)),
46 | torch.nn.MaxPool2d(3),
47 | KAF(20, conv=True),
48 | torch.nn.Conv2d(20, 20, kernel_size=5, padding=(2,2)),
49 | torch.nn.MaxPool2d(3),
50 | KAF(20, conv=True),
51 | Flatten(),
52 | torch.nn.Linear(180, 10),
53 | )
54 |
55 | # Reset parameters
56 | for m in kafnet:
57 | if len(m._parameters) > 0:
58 | m.reset_parameters()
59 |
60 | print('Training: **KAFNET**', flush=True)
61 |
62 | # Loss function
63 | loss_fn = torch.nn.CrossEntropyLoss()
64 |
65 | # Build optimizer
66 | optimizer = torch.optim.Adam(kafnet.parameters(), weight_decay=1e-4)
67 |
68 | # Put model on GPU if needed
69 | kafnet.to(device)
70 |
71 | max_epochs = 10
72 | for idx_epoch in range(max_epochs):
73 |
74 | print('Epoch #', idx_epoch, ' of #', max_epochs)
75 | kafnet.train()
76 |
77 | for (X_batch, y_batch) in tqdm.tqdm(train_loader):
78 |
79 | # Move the mini-batch to the GPU, if needed
80 | X_batch, y_batch = X_batch.to(device), y_batch.to(device)
81 |
82 | # Forward pass: compute predicted y by passing x to the model.
83 | y_pred = kafnet(X_batch)
84 |
85 | # Compute loss.
86 | loss = loss_fn(y_pred, y_batch)
87 |
88 | # Zeroes out all gradients
89 | optimizer.zero_grad()
90 |
91 | # Backward pass
92 | loss.backward()
93 |
94 | # Update parameters
95 | optimizer.step()
96 |
97 | # Compute final test score
98 | with torch.no_grad():
99 | print('Computing test score for: **KAFNET**', flush=True)
100 | kafnet.eval()
101 | acc = 0
102 | for _, (X_batch, y_batch) in enumerate(test_loader):
103 | # Move the mini-batch to the GPU, if needed
104 | X_batch = X_batch.to(device)
105 | acc += np.sum(y_batch.numpy() == np.argmax(kafnet(X_batch).cpu().numpy(), axis=1))
106 | print('Final score on test set: ', acc / len(test_loader.dataset))
107 |
--------------------------------------------------------------------------------
/pytorch/demo_kaf_feedforward.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Simple demo using kernel activation functions on a basic classification dataset.
5 | """
6 |
7 | # Imports from Python libraries
8 | import numpy as np
9 | from sklearn import datasets, preprocessing, model_selection
10 |
11 | # PyTorch imports
12 | import torch
13 | from torch.utils.data import TensorDataset
14 | from torch.utils.data import DataLoader
15 |
16 | # Custom imports
17 | from kafnets import KAF
18 |
19 | # Set seed for PRNG
20 | np.random.seed(1)
21 | torch.manual_seed(1)
22 |
23 | # Batch size
24 | B = 40
25 |
26 | # Load Breast Cancer dataset
27 | data = datasets.load_breast_cancer()
28 | X = preprocessing.MinMaxScaler(feature_range=(-1, +1)).fit_transform(data['data']).astype(np.float32)
29 | (X_train, X_test, y_train, y_test) = model_selection.train_test_split(X, data['target'].astype(np.float32).reshape(-1, 1), test_size=0.25)
30 |
31 | # Load in PyTorch data loader
32 | data_train = DataLoader(TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train)), shuffle=True, batch_size=B)
33 | data_test = DataLoader(TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test)), batch_size=100)
34 |
35 | # Initialize a KAF neural network
36 | kafnet = torch.nn.Sequential(
37 | torch.nn.Linear(30, 20),
38 | KAF(20),
39 | torch.nn.Linear(20, 1),
40 | )
41 |
42 | # Uncomment to use KAF with custom initialization
43 | #kafnet = torch.nn.Sequential(
44 | # torch.nn.Linear(30, 20),
45 | # KAF(20, init_fcn=np.tanh),
46 | # torch.nn.Linear(20, 1),
47 | #)
48 |
49 | #Uncomment to use KAF with Softplus kernel
50 | #kafnet = torch.nn.Sequential(
51 | # torch.nn.Linear(30, 20),
52 | # KAF(20, kernel='softplus'),
53 | # torch.nn.Linear(20, 1),
54 | #)
55 |
56 | # Reset parameters
57 | for m in kafnet:
58 | if len(m._parameters) > 0:
59 | m.reset_parameters()
60 |
61 | print('Training: **KAFNET**', flush=True)
62 |
63 | # Loss function
64 | loss_fn = torch.nn.BCEWithLogitsLoss()
65 |
66 | # Build optimizer
67 | optimizer = torch.optim.Adam(kafnet.parameters(), weight_decay=1e-4)
68 |
69 | for idx_epoch in range(100):
70 |
71 | kafnet.train()
72 |
73 | for _, (X_batch, y_batch) in enumerate(data_train):
74 |
75 | # Forward pass: compute predicted y by passing x to the model.
76 | y_pred = kafnet(X_batch)
77 |
78 | # Compute loss.
79 | loss = loss_fn(y_pred, y_batch)
80 |
81 | # Zeroes out all gradients
82 | optimizer.zero_grad()
83 |
84 | # Backward pass
85 | loss.backward()
86 |
87 | # Update parameters
88 | optimizer.step()
89 |
90 | with torch.no_grad():
91 | # Compute final test score
92 | print('Computing test score for: **KAFNET**', flush=True)
93 | kafnet.eval()
94 | acc = 0
95 | for _, (X_batch, y_batch) in enumerate(data_test):
96 | acc += np.sum(y_batch.numpy() == np.round(torch.sigmoid(kafnet(X_batch)).numpy()))
97 | print('Final score on test set: ', acc / len(data_test.dataset))
98 |
--------------------------------------------------------------------------------
/pytorch/kafnets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.nn.parameter import Parameter
7 | from torch.nn.init import normal_
8 | import torch.nn.functional as F
9 |
10 |
11 | class KAF(nn.Module):
12 | """ Implementation of the kernel activation function.
13 |
14 | Parameters
15 | ----------
16 | num_parameters: int
17 | Size of the layer (number of neurons).
18 | D: int, optional
19 | Size of the dictionary for each neuron. Defaults to 20.
20 | conv: bool, optional
21 | True if this is a convolutional layer, False for a feedforward layer. Defaults to False.
22 | boundary: float, optional
23 | Dictionary elements are sampled uniformly in [-boundary, boundary]. Defaults to 4.0.
24 | init_fcn: None or func, optional
25 | If None, the mixing coefficients are initialized randomly. Otherwise, they are initialized to approximate the given function.
26 | kernel: {'gaussian', 'relu', 'softplus'}, optional
27 | Kernel function to be used. Defaults to 'gaussian'.
28 |
29 | Example
30 | ----------
31 | Neural network with one hidden layer with KAF nonlinearities:
32 |
33 | >>> net = nn.Sequential(nn.Linear(10, 20), KAF(20), nn.Linear(20, 1))
34 |
35 | References
36 | ----------
37 | [1] Scardapane, S., Van Vaerenbergh, S., Totaro, S. and Uncini, A., 2019.
38 | Kafnets: kernel-based non-parametric activation functions for neural networks.
39 | Neural Networks, 110, pp. 19-32.
40 | [2] Marra, G., Zanca, D., Betti, A. and Gori, M., 2018.
41 | Learning Neuron Non-Linearities with Kernel-Based Deep Neural Networks.
42 | arXiv preprint arXiv:1807.06302.
43 | """
44 |
45 | def __init__(self, num_parameters, D=20, conv=False, boundary=4.0, init_fcn=None, kernel='gaussian'):
46 |
47 | super().__init__()
48 | self.num_parameters, self.D, self.conv = num_parameters, D, conv
49 |
50 | # Initialize the dictionary (NumPy)
51 | self.dict_numpy = np.linspace(-boundary, boundary, self.D).astype(np.float32).reshape(-1, 1)
52 |
53 | # Save the dictionary
54 | if self.conv:
55 | self.register_buffer('dict', torch.from_numpy(self.dict_numpy).view(1, 1, 1, 1, -1))
56 | self.unsqueeze_dim = 4
57 | else:
58 | self.register_buffer('dict', torch.from_numpy(self.dict_numpy).view(1, -1))
59 | self.unsqueeze_dim = 2
60 |
61 | # Select appropriate kernel function
62 | if not (kernel in ['gaussian', 'relu', 'softplus']):
63 | raise ValueError('Kernel not recognized (must be {gaussian, relu, softplus})')
64 |
65 | if kernel == 'gaussian':
66 | self.kernel_fcn = self.gaussian_kernel
67 | # Rule of thumb for gamma (only needed for Gaussian kernel)
68 | interval = (self.dict_numpy[1] - self.dict_numpy[0])
69 | sigma = 2 * interval # empirically chosen
70 | self.gamma_init = float(0.5 / np.square(sigma))
71 |
72 | # Initialize gamma
73 | if self.conv:
74 | self.register_buffer('gamma', torch.from_numpy(np.ones((1, 1, 1, 1, self.D), dtype=np.float32)*self.gamma_init))
75 | else:
76 | self.register_buffer('gamma', torch.from_numpy(np.ones((1, 1, self.D), dtype=np.float32)*self.gamma_init))
77 |
78 | elif kernel == 'relu':
79 | self.kernel_fcn = self.relu_kernel
80 | else:
81 | self.kernel_fcn = self.softplus_kernel
82 |
83 | # Initialize mixing coefficients
84 | if self.conv:
85 | self.alpha = Parameter(torch.FloatTensor(1, self.num_parameters, 1, 1, self.D))
86 | else:
87 | self.alpha = Parameter(torch.FloatTensor(1, self.num_parameters, self.D))
88 |
89 | # Optional: initialization with kernel ridge regression
90 | self.init_fcn = init_fcn
91 | if init_fcn is not None:
92 |
93 | if kernel == 'gaussian':
94 | K = np.exp(- self.gamma_init*(self.dict_numpy - self.dict_numpy.T) ** 2)
95 | elif kernel == 'softplus':
96 | K = np.log(np.exp(self.dict_numpy - self.dict_numpy.T) + 1.0)
97 | else:
98 | #K = np.maximum(self.dict_numpy - self.dict_numpy.T, 0)
99 | raise ValueError('Cannot perform kernel ridge regression with ReLU kernel (singular matrix)')
100 |
101 | self.alpha_init = np.linalg.solve(K + 1e-4 * np.eye(self.D), self.init_fcn(self.dict_numpy)).reshape(-1).astype(np.float32)
102 |
103 | else:
104 | self.alpha_init = None
105 |
106 | # Reset the parameters
107 | self.reset_parameters()
108 |
109 | def reset_parameters(self):
110 | if self.init_fcn is not None:
111 | if self.conv:
112 | self.alpha.data = torch.from_numpy(self.alpha_init).repeat(1, self.num_parameters, 1, 1, 1)
113 | else:
114 | self.alpha.data = torch.from_numpy(self.alpha_init).repeat(1, self.num_parameters, 1)
115 | else:
116 | normal_(self.alpha.data, std=0.8)
117 |
118 | def gaussian_kernel(self, input):
119 | return torch.exp(- torch.mul((torch.add(input.unsqueeze(self.unsqueeze_dim), - self.dict))**2, self.gamma))
120 |
121 | def relu_kernel(self, input):
122 | return F.relu(input.unsqueeze(self.unsqueeze_dim) - self.dict)
123 |
124 | def softplus_kernel(self, input):
125 | return F.softplus(input.unsqueeze(self.unsqueeze_dim) - self.dict)
126 |
127 | def forward(self, input):
128 | K = self.kernel_fcn(input)
129 | y = torch.sum(K*self.alpha, self.unsqueeze_dim)
130 | return y
131 |
132 | def __repr__(self):
133 | return self.__class__.__name__ + ' (' \
134 | + str(self.num_parameters) + ')'
--------------------------------------------------------------------------------
/tensorflow/README:
--------------------------------------------------------------------------------
1 | ## Kernel activation functions (TensorFlow)
2 |
3 | In the *kafnets* module you can find the modules for defining KAF layers, both for feedforward and convolutional networks (using the flag `conv` during initialization).
4 | The code has two demos to showcase the modules using tf.keras layers and eager execution; a minimal usage sketch is shown below.
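The sketch below only assumes what the demos in this folder already use (layer sizes included) and is meant as a quick orientation, not a complete program:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from kafnets import KAF

# Feedforward: a KAF layer after a Dense layer; pass conv=True after a Conv2D layer
net = Sequential([Dense(20, input_shape=(13,)), KAF(20), Dense(1)])
```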
5 |
6 | ## Requirements
7 |
8 | * tensorflow = 1.13.1
9 | * numpy = 1.15.4
--------------------------------------------------------------------------------
/tensorflow/demo_kaf_convolutional.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Simple demo using kernel activation functions with convolutional networks on the MNIST dataset.
5 | """
6 |
7 | # Import TensorFlow
8 | import numpy as np
9 | import tensorflow as tf
10 | import tensorflow.contrib.eager as tfe
11 | tf.enable_eager_execution()
12 |
13 | # Keras imports
14 | from tensorflow.keras import datasets
15 | from tensorflow.keras.models import Sequential
16 | from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten
17 |
18 | # Custom imports
19 | from kafnets import KAF
20 | import tqdm
21 |
22 | # Load MNIST dataset
23 | (X_train, y_train), (X_test, y_test) = datasets.mnist.load_data()
24 |
25 | # Preprocessing is taken from here:
26 | # https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
27 | X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
28 | X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
29 |
30 | X_train = X_train.astype('float32')
31 | X_test = X_test.astype('float32')
32 | X_train /= 255
33 | X_test /= 255
34 |
35 |
36 | # Initialize a KAF neural network
37 | kafnet = Sequential()
38 | kafnet.add(Conv2D(32, (3, 3), input_shape=(28, 28, 1)))
39 | kafnet.add(KAF(32, conv=True))
40 | kafnet.add(Conv2D(32, (3, 3)))
41 | kafnet.add(KAF(32, conv=True))
42 | kafnet.add(MaxPooling2D(pool_size=(2, 2)))
43 | kafnet.add(Flatten())
44 | kafnet.add(Dense(100))
45 | kafnet.add(KAF(100))
46 | kafnet.add(Dense(10))  # outputs logits; the softmax is applied inside the loss below
47 |
48 | # Use tf.data DataLoader
49 | train_data = tf.data.Dataset.from_tensor_slices((X_train.astype(np.float32), y_train.astype(np.int64)))
50 | test_data = tf.data.Dataset.from_tensor_slices((X_test.astype(np.float32), y_test.astype(np.int64)))
51 |
52 | # Optimizer
53 | opt = tf.train.AdamOptimizer()
54 |
55 | # Training
56 | for e in tqdm.trange(5, desc='Training'):
57 |
58 | for xb, yb in train_data.shuffle(1000).batch(32):
59 |
60 | with tfe.GradientTape() as tape:
61 | loss = tf.losses.sparse_softmax_cross_entropy(yb, kafnet(xb))
62 | g = tape.gradient(loss, kafnet.variables)
63 | opt.apply_gradients(zip(g, kafnet.variables))
64 |
65 | # Evaluation
66 | acc = tfe.metrics.Accuracy()
67 | for xb, yb in test_data.batch(32):
68 | acc(yb, tf.argmax(kafnet(xb), axis=1))
69 | tqdm.tqdm.write('Test accuracy after epoch {} is: '.format(e+1) + str(acc.result()))
--------------------------------------------------------------------------------
/tensorflow/demo_kaf_feedforward.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Simple demo using kernel activation functions on a basic regression dataset.
5 | """
6 |
7 | # Import TensorFlow
8 | import numpy as np
9 | import tensorflow as tf
10 | import tensorflow.contrib.eager as tfe
11 | tf.enable_eager_execution()
12 |
13 | # Keras imports
14 | from tensorflow.keras import datasets
15 | from tensorflow.keras.models import Sequential
16 | from tensorflow.keras.layers import Dense
17 |
18 | # Custom imports
19 | from kafnets import KAF
20 | import tqdm
21 |
22 | # Load Boston housing dataset
23 | (X_train, y_train), (X_test, y_test) = datasets.boston_housing.load_data()
24 |
25 | # Initialize a KAF neural network
26 | kafnet = Sequential([
27 | Dense(20, input_shape=(13,)),
28 | KAF(20),
29 | Dense(1),
30 | ])
31 |
32 | #Uncomment to use KAF with Softplus kernel
33 | #kafnet = Sequential([
34 | # Dense(20, input_shape=(13,)),
35 | # KAF(20, kernel='softplus', D=5),
36 | # Dense(1),
37 | #])
38 |
39 | # Use tf.data DataLoader
40 | train_data = tf.data.Dataset.from_tensor_slices((X_train.astype(np.float32), y_train.astype(np.float32).reshape(-1, 1)))
41 | test_data = tf.data.Dataset.from_tensor_slices((X_test.astype(np.float32), y_test.astype(np.float32).reshape(-1, 1)))
42 |
43 | # Optimizer
44 | opt = tf.train.AdamOptimizer()
45 |
46 | # Training
47 | for e in tqdm.trange(300, desc='Training'):
48 |
49 | for xb, yb in train_data.shuffle(1000).batch(32):
50 |
51 | with tfe.GradientTape() as tape:
52 | loss = tf.losses.mean_squared_error(yb, kafnet(xb))
53 | g = tape.gradient(loss, kafnet.variables)
54 | opt.apply_gradients(zip(g, kafnet.variables))
55 |
56 | # Evaluation
57 | err = tfe.metrics.Mean()
58 | for xb, yb in test_data.batch(32):
59 | err((yb - kafnet(xb))**2)
60 | print('Final error is: ' + str(err.result()))
61 |
--------------------------------------------------------------------------------
/tensorflow/kafnets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow.keras.layers import Layer
3 | import tensorflow as tf
4 |
5 | class KAF(Layer):
6 | """ Implementation of the kernel activation function.
7 |
8 | Parameters
9 | ----------
10 | num_parameters: int
11 | Size of the layer (number of neurons).
12 | D: int, optional
13 | Size of the dictionary for each neuron. Defaults to 20.
14 | conv: bool, optional
15 | True if this is a convolutional layer, False for a feedforward layer. Defaults to False.
16 | boundary: float, optional
17 | Dictionary elements are sampled uniformly in [-boundary, boundary]. Defaults to 3.0.
18 | init_fcn: None or func, optional
19 | If None, the mixing coefficients are initialized randomly. Otherwise, they are initialized to approximate the given function.
20 | kernel: {'gaussian', 'relu', 'softplus'}, optional
21 | Kernel function to be used. Defaults to 'gaussian'.
22 |
23 | Example
24 | ----------
25 | Neural network with one hidden layer with KAF nonlinearities:
26 |
27 | >>> net = Sequential([Dense(10), KAF(10), Dense(1)])
28 |
29 | References
30 | ----------
31 | [1] Scardapane, S., Van Vaerenbergh, S., Totaro, S. and Uncini, A., 2019.
32 | Kafnets: kernel-based non-parametric activation functions for neural networks.
33 | Neural Networks, 110, pp. 19-32.
34 | [2] Marra, G., Zanca, D., Betti, A. and Gori, M., 2018.
35 | Learning Neuron Non-Linearities with Kernel-Based Deep Neural Networks.
36 | arXiv preprint arXiv:1807.06302.
37 | """
38 |
39 | def __init__(self, num_parameters, D=20, boundary=3.0, conv=False, init_fcn=None, kernel='gaussian', **kwargs):
40 | self.num_parameters = num_parameters
41 | self.D = D
42 | self.boundary = boundary
43 | self.init_fcn = init_fcn
44 | self.conv = conv
45 | if self.conv:
46 | self.unsqueeze_dim = 4
47 | else:
48 | self.unsqueeze_dim = 2
49 | self.kernel = kernel
50 | if not (kernel in ['gaussian', 'relu', 'softplus']):
51 | raise ValueError('Kernel not recognized (must be {gaussian, relu, softplus})')
52 | super().__init__(**kwargs)
53 |
54 | def build(self, input_shape):
55 |
56 | # Initialize the fixed dictionary
57 | d = np.linspace(-self.boundary, self.boundary, self.D).astype(np.float32).reshape(-1, 1)
58 |
59 | if self.conv:
60 | self.dict = self.add_weight(name='dict',
61 | shape=(1, 1, 1, 1, self.D),
62 | initializer='uniform',
63 | trainable=False)
64 | tf.assign(self.dict, d.reshape(1, 1, 1, 1, -1))
65 | else:
66 | self.dict = self.add_weight(name='dict',
67 | shape=(1, 1, self.D),
68 | initializer='uniform',
69 | trainable=False)
70 | tf.assign(self.dict, d.reshape(1, 1, -1))
71 |
72 | if self.kernel == 'gaussian':
73 | self.kernel_fcn = self.gaussian_kernel
74 | # Rule of thumb for gamma
75 | interval = (d[1] - d[0])
76 | sigma = 2 * interval # empirically chosen
77 | self.gamma = 0.5 / np.square(sigma)
78 | elif self.kernel == 'softplus':
79 | self.kernel_fcn = self.softplus_kernel
80 | else:
81 | self.kernel_fcn = self.relu_kernel
82 |
83 |
84 | # Mixing coefficients
85 | if self.conv:
86 | self.alpha = self.add_weight(name='alpha',
87 | shape=(1, 1, 1, self.num_parameters, self.D),
88 | initializer='normal',
89 | trainable=True)
90 | else:
91 | self.alpha = self.add_weight(name='alpha',
92 | shape=(1, self.num_parameters, self.D),
93 | initializer='normal',
94 | trainable=True)
95 |
96 | # Optional initialization with kernel ridge regression
97 | if self.init_fcn is not None:
98 | if self.kernel == 'gaussian':
99 | kernel_matrix = np.exp(- self.gamma*(d - d.T) ** 2)
100 | elif self.kernel == 'softplus':
101 | kernel_matrix = np.log(np.exp(d - d.T) + 1.0)
102 | else:
103 | raise ValueError('Cannot perform kernel ridge regression with ReLU kernel (singular matrix)')
104 |
105 | alpha_init = np.linalg.solve(kernel_matrix + 1e-5*np.eye(self.D), self.init_fcn(d)).reshape(-1)
106 | if self.conv:
107 | tf.assign(self.alpha, np.repeat(alpha_init.reshape(1, 1, 1, 1, -1), self.num_parameters, axis=3))
108 | else:
109 | tf.assign(self.alpha, np.repeat(alpha_init.reshape(1, 1, -1), self.num_parameters, axis=1))
110 |
111 | super(KAF, self).build(input_shape)
112 |
113 | def gaussian_kernel(self, x):
114 | return tf.exp(- self.gamma * (tf.expand_dims(x, axis=self.unsqueeze_dim) - self.dict) ** 2.0)
115 |
116 | def softplus_kernel(self, x):
117 | return tf.nn.softplus(tf.expand_dims(x, axis=self.unsqueeze_dim) - self.dict)
118 |
119 | def relu_kernel(self, x):
120 | return tf.nn.relu(tf.expand_dims(x, axis=self.unsqueeze_dim) - self.dict)
121 |
122 | def call(self, x):
123 | kernel_matrix = self.kernel_fcn(x)
124 | return tf.reduce_sum(kernel_matrix * self.alpha, axis=self.unsqueeze_dim)
--------------------------------------------------------------------------------