├── .gitattributes
├── .gitignore
├── BinaryNet.py
├── CompositeLayers
│   ├── BinaryNetConvBNReluLayer.py
│   ├── ConvBNReluLayer.py
│   ├── XNORConvLayer.py
│   └── __init__.py
├── CustomLayers
│   ├── BinaryNetLayer.py
│   ├── CustomLayersDictionary.py
│   ├── XNORNetLayer.py
│   └── __init__.py
├── CustomOps
│   ├── __init__.py
│   ├── customOps.py
│   ├── tensorflowOps.py
│   └── theanoOps.py
├── NetworkParameters.py
└── readme.md

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/BinaryNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from CustomOps.customOps import SetSession
5 |
6 | # Call this first, to make sure that TensorFlow registers our custom ops properly
7 | SetSession()
8 |
9 | from keras.models import Model, load_model
10 | from keras.layers import Dense, Flatten, Input
11 | from keras.optimizers import Adam
12 | from keras.datasets import mnist, cifar100, cifar10
13 | from keras.utils import np_utils
14 | from keras.callbacks import ModelCheckpoint
15 | from keras import backend as K
16 |
17 | from CompositeLayers.ConvBNReluLayer import ConvBNReluLayer
18 | from CompositeLayers.BinaryNetConvBNReluLayer import BinaryNetConvBNReluLayer, BinaryNetActivation
19 | from CustomLayers.CustomLayersDictionary import customLayersDictionary
20 | from CompositeLayers.XNORConvLayer import XNORConvBNReluLayer, BNXNORConvReluLayer
21 | from NetworkParameters import NetworkParameters
22 | from CustomLayers.CustomLayersDictionary import customLayerCallbacks
23 |
24 | np.random.seed(1337)  # for reproducibility
25 |
26 |
27 | def CreateModel(input_shape, nb_classes, parameters):
28 |     model_input = Input(shape=input_shape)
29 |
30 |     output = model_input
31 |
32 |     if parameters.binarisation_type == 'BinaryNet':
33 |         print('Using BinaryNet binary convolution layers')
34 |         layerType = BinaryNetConvBNReluLayer
35 |     elif parameters.binarisation_type == 'XNORNet':
36 |         print('Using XNORNet binary convolution layers')
37 |         layerType = BNXNORConvReluLayer
38 |     else:
39 |         assert False, 'Unsupported binarisation type!'
40 |
41 |     # As per the paper, the first layer can't be binary
42 |     output = ConvBNReluLayer(input=output, nb_filters=16, border='valid', kernel_size=(3, 3), stride=(1, 1))
43 |
44 |     # Add an extra binarisation layer here, as Theano needs explicit input binarisation
45 |     if K.backend() == 'theano':
46 |         output = BinaryNetActivation()(output)
47 |
48 |     output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))
49 |     output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))
50 |     output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))
51 |     output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))
52 |     output = layerType(input=output, nb_filters=32, border='valid', kernel_size=(3, 3), stride=(1, 1))
53 |
54 |     output = Flatten()(output)
55 |     output = Dense(nb_classes, use_bias=True, activation='softmax')(output)
56 |
57 |     model = Model(inputs=model_input, outputs=output)
58 |
59 |     model.summary()
60 |
61 |     return model
62 |
63 |
64 |
65 | ############################
66 | # Parameters
67 |
68 | modelDirectory = os.getcwd()
69 |
70 | parameters = NetworkParameters(modelDirectory)
71 | parameters.nb_epochs = 1
72 | parameters.batch_size = 32
73 | parameters.lr = 0.0005
74 | parameters.batch_scale_factor = 8
75 | parameters.decay = 0.001
76 |
77 | parameters.binarisation_type = 'BinaryNet'  # Either 'BinaryNet' or 'XNORNet'
78 |
79 | parameters.lr *= parameters.batch_scale_factor
80 | parameters.batch_size *= parameters.batch_scale_factor
81 |
82 | print('Learning rate is: %f' % parameters.lr)
83 | print('Batch size is: %d' % parameters.batch_size)
84 |
85 | optimiser = Adam(lr=parameters.lr, decay=parameters.decay)
86 |
87 | ############################
88 | # Data
89 |
90 | (X_train, y_train), (X_test, y_test) = mnist.load_data()
91 |
92 | y_train = np.squeeze(y_train)
93 | y_test = np.squeeze(y_test)
94 |
95 | if len(X_train.shape) < 4:
96 |     X_train = np.expand_dims(X_train, -1)
97 |     X_test = np.expand_dims(X_test, -1)
98 |
99 | input_shape = X_train.shape[1:]
100 |
101 | X_train = X_train.astype('float32')
102 | X_test = X_test.astype('float32')
103 |
104 | X_train = X_train / 256.0
105 | X_test = X_test / 256.0
106 |
107 | nb_classes = y_train.max() + 1
108 |
109 | y_test_cat = np_utils.to_categorical(y_test, nb_classes)
110 | y_train_cat = np_utils.to_categorical(y_train, nb_classes)
111 |
112 |
113 | ############################
114 | # Training
115 |
116 | model = CreateModel(input_shape=input_shape, nb_classes=nb_classes, parameters=parameters)
117 |
118 | model.compile(loss='categorical_crossentropy',
119 |               optimizer=optimiser,
120 |               metrics=['accuracy'])
121 |
122 | checkpointCallback = ModelCheckpoint(filepath=parameters.modelSaveName, verbose=1)
123 | bestCheckpointCallback = ModelCheckpoint(filepath=parameters.bestModelSaveName, verbose=1, save_best_only=True)
124 |
125 | model.fit(x=X_train,
126 |           y=y_train_cat,
127 |           batch_size=parameters.batch_size,
128 |           epochs=parameters.nb_epochs,
129 |           callbacks=[checkpointCallback, bestCheckpointCallback] + customLayerCallbacks,
130 |           validation_data=(X_test, y_test_cat),
131 |           shuffle=True,
132 |           verbose=1
133 |           )
134 |
135 |
136 | print('Testing')
137 | modelTest = load_model(filepath=parameters.bestModelSaveName, custom_objects=customLayersDictionary)
138 |
139 | validationAccuracy = modelTest.evaluate(X_test, y_test_cat, verbose=0)
140 | print('\nBest Keras validation accuracy is: %f\n' % (100.0 * validationAccuracy[1]))
141 |
--------------------------------------------------------------------------------
/CompositeLayers/BinaryNetConvBNReluLayer.py:
--------------------------------------------------------------------------------
1 | from keras.layers import BatchNormalization, Activation
2 | from CustomLayers.BinaryNetLayer import BinaryNetConv2D, BinaryNetActivation
3 | from keras import backend as K
4 |
5 | def BinaryNetConvBNReluLayer(input, nb_filters, border, kernel_size, stride, use_bias=True, data_format='channels_last', use_activation=False):
6 |     output = input
7 |
8 |     # BinaryNet uses binarisation as the activation.
9 |     # To get the graphs to compile properly, binarisation is added as a separate layer on the input for Theano;
10 |     # the TensorFlow implementation contains the input binarisation inside the layer definition
11 |     if K.backend() == 'theano':
12 |         output = BinaryNetActivation()(output)
13 |
14 |     output = BinaryNetConv2D(nb_filters,
15 |                              kernel_size,
16 |                              use_bias=use_bias,
17 |                              padding=border,
18 |                              strides=stride,
19 |                              data_format=data_format,
20 |                              )(output)
21 |
22 |     # Add output binarisation as a separate layer for Theano
23 |     if K.backend() == 'theano':
24 |         output = BinaryNetActivation()(output)
25 |
26 |     output = BatchNormalization()(output)
27 |
28 |     if use_activation:
29 |         output = Activation('relu')(output)
30 |
31 |     return output
--------------------------------------------------------------------------------
/CompositeLayers/ConvBNReluLayer.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Activation, BatchNormalization
2 | from keras.layers import Convolution2D
3 |
4 | def ConvBNReluLayer(input, nb_filters, border, kernel_size, stride, use_bias=True, data_format='channels_last'):
5 |
6 |     output = Convolution2D(filters=nb_filters,
7 |                            kernel_size=kernel_size,
8 |                            strides=stride,
9 |                            padding=border,
10 |                            data_format=data_format,
11 |                            use_bias=use_bias
12 |                            )(input)
13 |
14 |     output = BatchNormalization()(output)
15 |     output = Activation('relu')(output)
16 |
17 |     return output
18 |
--------------------------------------------------------------------------------
/CompositeLayers/XNORConvLayer.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Activation, BatchNormalization
2 | from CustomLayers.XNORNetLayer import XNORNetConv2D
3 |
4 | def BNXNORConvReluLayer(input,
5 |                         nb_filters,
6 |                         border,
7 |                         kernel_size,
8 |                         stride,
9 |                         use_BN=True,
10 |                         use_bias=False,
11 |                         use_activation=True,
12 |                         binarise_input=True,
13 |                         data_format='channels_last'):
14 |
15 |     output = input
16 |
17 |     if use_BN:
18 |         output = BatchNormalization()(output)
19 |
20 |     output = XNORNetConv2D(filters=nb_filters,
21 |                            kernel_size=kernel_size,
22 |                            use_bias=use_bias,
23 |                            padding=border,
24 |                            strides=stride,
25 |                            data_format=data_format,
26 |                            binarise_input=binarise_input
27 |                            )(output)
28 |
29 |     if use_activation:
30 |         output = Activation('relu')(output)
31 |
32 |     return output
33 |
34 |
35 | def XNORConvBNReluLayer(input,
36 |                         nb_filters,
37 |                         border,
38 |                         kernel_size,
39 |                         stride,
40 |                         use_BN=True,
41 |                         use_bias=False,
42 |                         use_activation=True,
43 |                         binarise_input=True,
44 |                         data_format='channels_last'):
45 |
46 |     output = input
47 |
48 |     output = XNORNetConv2D(filters=nb_filters,
49 |                            kernel_size=kernel_size,
50 |                            use_bias=use_bias,
51 |                            padding=border,
52 |                            strides=stride,
53 |                            data_format=data_format,
54 |                            binarise_input=binarise_input
55 |                            )(output)
56 |
57 |     if use_BN:
58 |         output = BatchNormalization()(output)
59 |
60 |     if use_activation:
61 |         output = Activation('relu')(output)
62 |
63 |     return output
64 |
--------------------------------------------------------------------------------
/CompositeLayers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaysummeriscoming/BinaryNet_and_XNORNet/070d8289249e432d187a588f0e26b2679b0d43f4/CompositeLayers/__init__.py
--------------------------------------------------------------------------------
/CustomLayers/BinaryNetLayer.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | import numpy as np
3 | from keras.engine import InputSpec
4 | from keras.engine import Layer
5 | from keras.layers import Convolution2D
6 |
7 | from CustomOps.customOps import passthroughSign
8 |
9 |
10 | class BinaryNetActivation(Layer):
11 |
12 |     def __init__(self, **kwargs):
13 |         super(BinaryNetActivation, self).__init__(**kwargs)
14 |         # self.supports_masking = True
15 |
16 |     def build(self, input_shape):
17 |         super(BinaryNetActivation, self).build(input_shape)  # Be sure to call this somewhere!
18 |
19 |     def call(self, inputs):
20 |         # In BinaryNet, the output activation is binarised (normally done at the input to each layer in our implementation)
21 |         return passthroughSign(inputs)
22 |
23 |     def get_config(self):
24 |         base_config = super(BinaryNetActivation, self).get_config()
25 |         return base_config
26 |
27 |     def compute_output_shape(self, input_shape):
28 |         return input_shape
29 |
30 | class BinaryNetConv2D(Convolution2D):
31 |     """2D binary convolution layer (e.g. spatial convolution over images).
32 |
33 |     This is an implementation of the BinaryNet layer described in:
34 |     Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
35 |
36 |     It's based on the Convolution2D class, featuring an identical argument list.
37 |
38 |     NOTE: The weight binarisation functionality is implemented using an 'on batch end' function,
39 |     which must be called at the end of every batch (ideally using a callback). Currently this functionality
40 |     is implemented using Numpy. In practice this incurs a negligible performance penalty,
41 |     as this function uses far fewer operations than the base convolution operation.
42 |
43 |     # Arguments
44 |         Same as base Convolution2D layer
45 |
46 |     # Input shape
47 |         4D tensor with shape:
48 |         `(samples, channels, rows, cols)` if data_format='channels_first'
49 |         or 4D tensor with shape:
50 |         `(samples, rows, cols, channels)` if data_format='channels_last'.
51 |
52 |     # Output shape
53 |         4D tensor with shape:
54 |         `(samples, filters, new_rows, new_cols)` if data_format='channels_first'
55 |         or 4D tensor with shape:
56 |         `(samples, new_rows, new_cols, filters)` if data_format='channels_last'.
57 |         `rows` and `cols` values might have changed due to padding.
58 |     """
59 |
60 |     def build(self, input_shape):
61 |         # Call the build function of the base class (in this case, convolution)
62 |         # super(BinaryNetConv2D, self).build(input_shape)
63 |         super().build(input_shape)  # Be sure to call this somewhere!
64 |
65 |         # Get the initialised weights and save them as the 'full precision' weights
66 |         weights = K.get_value(self.weights[0])
67 |         self.fullPrecisionWeights = weights.copy()
68 |
69 |         # Compute the binary approximated weights & save ready for the first batch
70 |         B = np.sign(self.fullPrecisionWeights)
71 |         self.lastIterationWeights = B.copy()
72 |         K.set_value(self.weights[0], B)
73 |
74 |
75 |     def call(self, inputs):
76 |
77 |         # For theano, binarisation is done as a separate layer
78 |         if K.backend() == 'tensorflow':
79 |             binarisedInput = passthroughSign(inputs)
80 |         else:
81 |             binarisedInput = inputs
82 |
83 |         return super().call(binarisedInput)
84 |
85 |
86 |     def on_batch_end(self):
87 |         # Weight arrangement is: (kernel_size, kernel_size, num_input_channels, num_output_channels)
88 |         # for both data formats in keras 2 notation
89 |
90 |         # Work out the weights update from the last batch and then apply this to the full precision weights.
91 |         # The current weights correspond to the binarised weights + last batch update
92 |         newWeights = K.get_value(self.weights[0])
93 |         weightsUpdate = newWeights - self.lastIterationWeights
94 |         self.fullPrecisionWeights += weightsUpdate
95 |         self.fullPrecisionWeights = np.clip(self.fullPrecisionWeights, -1., 1.)
96 |
97 |         # Work out new approximated weights based off the full precision values
98 |         B = np.sign(self.fullPrecisionWeights)
99 |
100 |         # Save the weights, both in the keras kernel and a reference variable,
101 |         # so that we can compute the weights update that keras makes
102 |         self.lastIterationWeights = B.copy()
103 |         K.set_value(self.weights[0], B)
104 |
--------------------------------------------------------------------------------
/CustomLayers/CustomLayersDictionary.py:
--------------------------------------------------------------------------------
1 | from keras.callbacks import Callback
2 |
3 | from CustomLayers.XNORNetLayer import XNORNetConv2D
4 | from CustomLayers.BinaryNetLayer import BinaryNetConv2D, BinaryNetActivation
5 |
6 |
7 | # This file holds a dictionary of all custom layers, for use when loading a Keras model
8 | customLayersDictionary = {
9 |     "XNORNetConv2D": XNORNetConv2D,
10 |     "BinaryNetActivation": BinaryNetActivation,
11 |     "BinaryNetConv2D": BinaryNetConv2D,
12 | }
13 |
14 |
15 | class CustomLayerUpdate(Callback):
16 |
17 |     def on_batch_begin(self, batch, logs=None):
18 |         for curLayer in self.model.layers:
19 |             CallMethodName(object=curLayer, fn_name='on_batch_begin')
20 |
21 |     def on_batch_end(self, batch, logs=None):
22 |         for curLayer in self.model.layers:
23 |             CallMethodName(object=curLayer, fn_name='on_batch_end')
24 |
25 |     def on_epoch_begin(self, epoch, logs=None):
26 |         for curLayer in self.model.layers:
27 |             CallMethodName(object=curLayer, fn_name='on_epoch_begin')
28 |
29 |     def on_epoch_end(self, epoch, logs=None):
30 |         for curLayer in self.model.layers:
31 |             CallMethodName(object=curLayer, fn_name='on_epoch_end')
32 |
33 |
34 | # Call the desired class method if it exists
35 | def CallMethodName(object, fn_name):
36 |     fn = getattr(object, fn_name, None)
37 |     if callable(fn):
38 |         fn()
39 |
40 | # Callbacks to run custom layer specific code at batch and epoch boundaries
41 | customLayerCallbacks = [CustomLayerUpdate()]
--------------------------------------------------------------------------------
/CustomLayers/XNORNetLayer.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | from keras.layers import Convolution2D
3 | from keras.engine import InputSpec
4 | import numpy as np
5 |
6 | from CustomOps.customOps import passthroughSign
7 |
8 | class XNORNetConv2D(Convolution2D):
9 |     """2D 'XNORNet' convolution layer (e.g. spatial convolution over images).
10 |
11 |     This is an implementation of the XNORNet layer described in:
12 |     XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks
13 |
14 |     It's based on the Convolution2D class, featuring an identical argument list, with the addition of
15 |     a 'binarise_input' parameter.
16 |
17 |     NOTE: The weight binarisation functionality is implemented using an 'on batch end' function,
18 |     which must be called at the end of every batch (ideally using a callback). Currently this functionality
19 |     is implemented using Numpy. In practice this incurs a negligible performance penalty,
20 |     as this function uses far fewer operations than the base convolution operation.
21 |
22 |     # Arguments
23 |         Same as base Convolution2D layer, except:
24 |         binarise_input: This controls whether we operate with just binary weights, or with binarised activations as well
25 |
26 |     # Input shape
27 |         4D tensor with shape:
28 |         `(samples, channels, rows, cols)` if data_format='channels_first'
29 |         or 4D tensor with shape:
30 |         `(samples, rows, cols, channels)` if data_format='channels_last'.
31 |
32 |     # Output shape
33 |         4D tensor with shape:
34 |         `(samples, filters, new_rows, new_cols)` if data_format='channels_first'
35 |         or 4D tensor with shape:
36 |         `(samples, new_rows, new_cols, filters)` if data_format='channels_last'.
37 |         `rows` and `cols` values might have changed due to padding.
38 |     """
39 |
40 |     def __init__(self,
41 |                  binarise_input=True,
42 |                  **kwargs):
43 |
44 |         super().__init__(**kwargs)
45 |         self.input_spec = InputSpec(ndim=4)
46 |
47 |         self.binarise_input = binarise_input
48 |
49 |
50 |     def build(self, input_shape):
51 |         # Call the build function of the base class (in this case, convolution)
52 |         super(XNORNetConv2D, self).build(input_shape)  # Be sure to call this somewhere!
53 |
54 |         # The k filter should be of shape (filter_size, filter_size, 1, 1), following standard keras 2 notation
55 |         k_numpy = np.ones(shape=(self.kernel_size[0], self.kernel_size[1], 1, 1))
56 |         k_numpy = k_numpy / np.sum(k_numpy)
57 |         self.k_filter = K.variable(k_numpy, dtype='float32')
58 |
59 |         weights = K.get_value(self.weights[0])
60 |         self.fullPrecisionWeights = weights.copy()
61 |
62 |         B = np.sign(self.fullPrecisionWeights)
63 |
64 |         # Calculate a separate alpha value for each filter
65 |         alpha = np.mean(np.abs(self.fullPrecisionWeights), axis=(0, 1, 2))
66 |         alphaB = np.broadcast_to(alpha, B.shape)
67 |
68 |         newApproximatedWeights = np.multiply(alphaB, B)
69 |         self.lastIterationWeights = newApproximatedWeights.copy()
70 |         K.set_value(self.weights[0], newApproximatedWeights)
71 |
72 |
73 |     def call(self, inputs):
74 |         # Channels first arrangement: (batch_size, num_input_channels, width, height)
75 |         # Channels last arrangement: (batch_size, width, height, num_input_channels)
76 |
77 |         # If activation quantisation is enabled
78 |         if self.binarise_input:
79 |
80 |             # Compute the axis ID of the channels. Use the tensorflow channels last arrangement as standard
81 |             channels_axis = 3
82 |
83 |             if self.data_format == 'channels_first':
84 |                 channels_axis = 1
85 |
86 |             # Compute A, which is the average across channels.
87 |             # The input will thus reduce to a single-channel image,
88 |             # i.e. (minibatch_size, 1, height, width) in channels first notation
89 |             A = K.mean(K.abs(inputs), axis=channels_axis, keepdims=True)
90 |
91 |             # The k filter should be of shape (filter_size, filter_size, 1, 1) as per keras 2 notation.
92 |             # K is of shape (batch_size, 1, width, height) (using the channels first data format)
93 |             K_variable = K.conv2d(A,
94 |                                   self.k_filter,
95 |                                   strides=self.strides,
96 |                                   padding=self.padding,
97 |                                   data_format=self.data_format,
98 |                                   dilation_rate=self.dilation_rate)
99 |
100 |             # Binarise the input
101 |             binarisedInput = passthroughSign(inputs)
102 |
103 |             # Call the base convolution operation.
104 |             # The convolution output will be of shape (batch_size, num_output_channels, width, height) (channels first)
105 |             convolutionOutput = K.conv2d(
106 |                 binarisedInput,
107 |                 self.kernel,
108 |                 strides=self.strides,
109 |                 padding=self.padding,
110 |                 data_format=self.data_format,
111 |                 dilation_rate=self.dilation_rate)
112 |
113 |             # Copy K for each output channel.
114 |             # K will thus go from shape (batch_size, 1, width, height) to (batch_size, num_output_channels, width, height)
115 |             # (with the channels_first data format)
116 |             if K.backend() == 'tensorflow':
117 |                 K_variable = K.repeat_elements(K_variable, K.int_shape(convolutionOutput)[channels_axis], axis=channels_axis)
118 |             else:
119 |                 K_variable = K.repeat_elements(K_variable, K.shape(convolutionOutput)[channels_axis], axis=channels_axis)
120 |
121 |             outputs = K_variable * convolutionOutput
122 |
123 |             return outputs
124 |
125 |         else:
126 |             # Call the base convolution operation. Only the weights are quantised in this case
127 |             return super(XNORNetConv2D, self).call(inputs)
128 |
129 |
130 |     def on_batch_end(self):
131 |         # Weight arrangement is: (kernel_size, kernel_size, num_input_channels, num_output_channels)
132 |         # for both data formats in keras 2 notation
133 |
134 |         # Work out the weights update from the last batch and then apply this to the full precision weights.
135 |         # The current weights correspond to the binarised weights + last batch update
136 |         newWeights = K.get_value(self.weights[0])
137 |         weightsUpdate = newWeights - self.lastIterationWeights
138 |         self.fullPrecisionWeights += weightsUpdate
139 |
140 |         # Calculate the binary 'B' and 'alpha' scaling factors for each filter
141 |         B = np.sign(self.fullPrecisionWeights)
142 |         alpha = np.mean(np.abs(self.fullPrecisionWeights), axis=(0, 1, 2))
143 |         alphaB = np.broadcast_to(alpha, B.shape)
144 |
145 |         # Save the weights, both in the keras kernel and a reference variable,
146 |         # so that we can compute the weights update that keras makes
147 |         newApproximatedWeights = np.multiply(alphaB, B)
148 |         self.lastIterationWeights = newApproximatedWeights.copy()
149 |
150 |         K.set_value(self.weights[0], newApproximatedWeights)
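151 |
152 |
153 | # Illustrative sketch (not part of the original layer logic): the XNOR-Net weight
154 | # approximation used in build() and on_batch_end() above, shown standalone in Numpy.
155 | # Each filter W is approximated as alpha * B, where B = sign(W) and alpha is the
156 | # mean absolute value of that filter's full precision weights.
157 | if __name__ == '__main__':
158 |     w = np.random.randn(3, 3, 2, 4)  # (kernel_h, kernel_w, num_input_channels, num_output_channels)
159 |
160 |     B = np.sign(w)
161 |     alpha = np.mean(np.abs(w), axis=(0, 1, 2))  # one scaling factor per output filter
162 |     w_approx = np.broadcast_to(alpha, B.shape) * B
163 |
164 |     # The approximation minimises ||W - alpha * B||^2 separately for each filter
165 |     print('Full precision weights of filter 0:\n', w[..., 0])
166 |     print('Binary approximation of filter 0:\n', w_approx[..., 0])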
--------------------------------------------------------------------------------
/CustomLayers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaysummeriscoming/BinaryNet_and_XNORNet/070d8289249e432d187a588f0e26b2679b0d43f4/CustomLayers/__init__.py
--------------------------------------------------------------------------------
/CustomOps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaysummeriscoming/BinaryNet_and_XNORNet/070d8289249e432d187a588f0e26b2679b0d43f4/CustomOps/__init__.py
--------------------------------------------------------------------------------
/CustomOps/customOps.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 |
3 | if K.backend() == 'tensorflow':
4 |
5 |     import CustomOps.tensorflowOps as tensorflowOps
6 |
7 |     def passthroughSign(x):
8 |         return tensorflowOps.passthroughSignTF(x)
9 |
10 |     def passthroughTanh(x):
11 |         return tensorflowOps.passthroughTanhTF(x)
12 |
13 |     def SetSession():
14 |         tensorflowOps.SetSession()
15 |
16 | elif K.backend() == 'theano':
17 |
18 |     from CustomOps.theanoOps import BinaryTanh
19 |
20 |     def passthroughSign(x):
21 |         return BinaryTanh(x)
22 |
23 |     def passthroughTanh(x):
24 |         raise NotImplementedError("This op hasn't been programmed for theano yet")
25 |
26 |     def SetSession():
27 |         pass  # Nothing to do for the theano backend
28 |
29 |
30 | else:
31 |     raise NameError('backend not supported')
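32 |
33 |
34 | # Illustrative smoke test (not part of the original module): check that the forward
35 | # pass of passthroughSign matches sign(), and that the gradient passes through.
36 | # This assumes the tensorflow backend; with theano the gradient follows hard-tanh instead.
37 | if __name__ == '__main__':
38 |     SetSession()
39 |
40 |     x = K.variable([-2.5, -0.3, 0.7, 3.0])
41 |     y = passthroughSign(x)
42 |     grads = K.gradients(K.sum(y), [x])
43 |
44 |     print('forward: ', K.eval(y))         # expect [-1. -1.  1.  1.]
45 |     print('gradient:', K.eval(grads[0]))  # expect the clipped passthrough gradient [1. 1. 1. 1.]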
--------------------------------------------------------------------------------
/CustomOps/tensorflowOps.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | import tensorflow as tf
3 | from tensorflow.python.framework import function
4 |
5 | def clipped_passthrough_grad_multiply(op, grad):
6 |     return [K.clip(grad, -1., 1.), K.clip(grad, -1., 1.)]
7 |
8 | def clipped_passthrough_grad(op, grad):
9 |     return K.clip(grad, -1., 1.)
10 |
11 |
12 | def variable_tanh_grad(op, grad):
13 |
14 |     # tanh_grad = 1. - (grad * grad)
15 |     #
16 |     # return (tanh_grad + grad) / 2.
17 |
18 |     tanh_grad = (1. - (op.outputs[0] * op.outputs[0])) * grad  # d/dx tanh(x) = 1 - tanh(x)^2
19 |
20 |     return tanh_grad
21 |
22 |
23 | def identity(op):
24 |     return op
25 |
26 | # COMEBACKTO_PBL
27 | # Some defun examples:
28 | # https://stackoverflow.com/questions/38833934/write-custom-python-based-gradient-function-for-an-operation-without-c-imple
29 | # http://programtalk.com/python-examples/tensorflow.python.framework.function.Defun/
30 | #
31 | # https://stackoverflow.com/questions/39605798/treating-a-tensorflow-defun-as-a-closure
32 |
33 |
34 |
35 |
36 |
37 |
38 | # @function.Defun(tf.float32, tf.float32, python_grad_func=clipped_passthrough_grad_multiply, func_name="passthroughMultiplyTF")
39 | @function.Defun(tf.float32, tf.float32, func_name="passthroughMultiplyTF")
40 | def passthroughMultiplyTF(x, y):
41 |     x_new = tf.identity(x)
42 |     y_new = tf.identity(y)
43 |     output = x_new * y_new
44 |     # output = tf.multiply(x_new, y_new)
45 |     realOutput = tf.identity(output)
46 |
47 |     return realOutput
48 |
49 |
50 |
51 | # @function.Defun(tf.float32, python_grad_func=sign_grad, shape_func=identity, func_name="passthroughSign")
52 | # @function.Defun(tf.float32, func_name="passthroughSign")
53 | @function.Defun(tf.float32, python_grad_func=clipped_passthrough_grad, func_name="passthroughSignTF")
54 | def passthroughSignTF(x):
55 |     x_new = tf.identity(x)
56 |     output = tf.sign(x_new)
57 |     realOutput = tf.identity(output)
58 |
59 |     return realOutput
60 |
61 |
62 | @function.Defun(tf.float32, python_grad_func=clipped_passthrough_grad, func_name="passthroughTanhTF")
63 | def passthroughTanhTF(x):
64 |     x_new = tf.identity(x)
65 |     output = tf.tanh(x_new)
66 |     realOutput = tf.identity(output)
67 |
68 |     return realOutput
69 |
70 |
71 | def SetSession():
72 |     print(tf.__version__)
73 |     a = tf.Variable(tf.constant([-5., 4., -3., 2., 1.], dtype=tf.float32))
74 |
75 |     # Make sure there's a reference to our custom passthroughSign function so that tensorflow includes it
76 |     grad = tf.gradients(passthroughSignTF(a), [a])
77 |     grad1 = tf.gradients(passthroughTanhTF(a), [a])
78 |     grad2 = tf.gradients(passthroughMultiplyTF(a, a), [a])
79 |
80 |     # COMEBACKTO_PBL: Testing multi-core usage
81 |     jobs = 8
82 |
83 |     config = tf.ConfigProto(intra_op_parallelism_threads=jobs,
84 |                             inter_op_parallelism_threads=jobs,
85 |                             allow_soft_placement=True,
86 |                             device_count={'CPU': jobs})
87 |
88 |     # Set a new keras tensorflow session so that all of our custom tensorflow code is included
89 |     sess = tf.Session(config=config)
90 |     K.set_session(sess)
--------------------------------------------------------------------------------
/CustomOps/theanoOps.py:
--------------------------------------------------------------------------------
1 | from keras.backend import theano_backend as T
2 | from theano.scalar.basic import UnaryScalarOp, same_out_nocomplex
3 | from theano.tensor.elemwise import Elemwise
4 |
5 | # Our own rounding function that does not set the gradient to 0 like Theano's
6 | class __Round(UnaryScalarOp):
7 |
8 |     def c_code(self, node, name, inputs, outputs, sub):
9 |         x, = inputs
10 |         z, = outputs
11 |         return "%(z)s = round(%(x)s);" % locals()
12 |
13 |     def grad(self, inputs, gout):
14 |         (gz,) = gout
15 |         return gz,
16 |
17 | __round_scalar = __Round(same_out_nocomplex, name='__round')
18 | __round = Elemwise(__round_scalar)
19 |
20 | def HardSigmoid(x):
21 |     return T.clip((x + 1.) / 2., 0, 1)
22 |
23 | # The neurons' activation binarisation function.
24 | # It behaves like the sign function during forward propagation,
25 | # and like:
26 | #   hard_tanh(x) = 2 * hard_sigmoid(x) - 1
27 | # during back propagation
28 | def BinaryTanh(x):
29 |     return 2. * __round(HardSigmoid(x)) - 1.
--------------------------------------------------------------------------------
/NetworkParameters.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | class NetworkParameters:
4 |     def __init__(self, modelDirectory):
5 |         self.modelDirectory = modelDirectory
6 |
7 |
8 |         if not os.path.exists(self.modelDirectory):
9 |             os.mkdir(self.modelDirectory)
10 |
11 |         self.checkpointedModelDir = os.path.join(self.modelDirectory, 'savedModels')
12 |
13 |         if not os.path.exists(self.checkpointedModelDir):
14 |             os.mkdir(self.checkpointedModelDir)
15 |
16 |         self.modelSaveName = os.path.join(self.checkpointedModelDir, 'model_{epoch:02d}.hdf5')
17 |         self.bestModelSaveName = os.path.join(self.checkpointedModelDir, 'best_model.hdf5')
18 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | This project contains Keras implementations of the BinaryNet and XNORNet papers:
2 |
3 | [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830)
4 |
5 |
6 | [XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks](https://arxiv.org/abs/1603.05279)
7 |
8 | The code supports both the TensorFlow and Theano backends.
9 |
10 | The most difficult part of coding these implementations was the sign function gradient. I’ve used the clipped ‘passthrough’ sign implementation detailed in the BinaryNet paper. The XNOR-Net paper doesn’t specify a gradient for the sign function, so I’ve used the same implementation there too.
11 |
12 | NOTE: This code is Python 3 compatible only!
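13 |
14 | For reference, the estimator can be sketched in plain Keras backend code. This is an illustration, not the code path the layers actually use — they call the custom ops registered in CustomOps. The sketch follows the hard-tanh variant used by the Theano backend; the TensorFlow implementation instead clips the incoming gradient to [-1, 1]:
15 |
16 | ```python
17 | import keras.backend as K
18 |
19 | def straight_through_sign(x):
20 |     backward = K.clip(x, -1., 1.)  # hard-tanh: gradient is 1 inside [-1, 1], 0 outside
21 |     forward = K.sign(x)            # binarise on the forward pass
22 |     return backward + K.stop_gradient(forward - backward)
23 | ```
24 |
--------------------------------------------------------------------------------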