├── .gitignore ├── Examples └── LeNet │ ├── main.py │ └── model.py ├── Kernels ├── Makefile ├── QuantOp_binary.cc ├── QuantOp_halffp.cc ├── QuantOp_log.cc ├── QuantOp_sparse.cc ├── QuantOp_ternary.cc ├── README.md ├── RoundOp_down.cc ├── RoundOp_nearest.cc ├── RoundOp_stochastic.cc ├── RoundOp_zero.cc └── compile.sh ├── LICENSE ├── Quantize ├── .goutputstream-27Z54Y ├── FixedPoint.py ├── QLayer.py ├── QuantKernelWrapper.py ├── Quantizers.py ├── README.md ├── __init__.py ├── override.py ├── override_functions.py └── utils.py ├── README.md └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | tmp/ 2 | *.so 3 | __pycache__ 4 | *.pyc 5 | *.o 6 | *.out 7 | startExpPipeline.sh 8 | *.log 9 | slim/*.sh 10 | experiment_results/ 11 | old/ 12 | .goutputstream-* 13 | -------------------------------------------------------------------------------- /Examples/LeNet/main.py: -------------------------------------------------------------------------------- 1 | #import sys 2 | #sys.path.append('..') 3 | 4 | # LeNet for MNIST using Keras and TensorFlow 5 | import tensorflow as tf 6 | 7 | from tensorflow.keras.optimizers import SGD 8 | from tensorflow.keras.datasets import mnist 9 | import numpy as np 10 | 11 | import model 12 | 13 | # Add this for TensorQuant 14 | from TensorQuant.Quantize import override 15 | 16 | def main(): 17 | 18 | # TensorQuant 19 | # Make sure the overrides are set before the model is created!
20 | # QQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ 21 | override.extr_q_map={"Conv1" : "nearest,12,11"} 22 | override.weight_q_map={ "Conv1" : "nearest,32,16", "Dense3" : "nearest,32,16"} 23 | # QQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ 24 | 25 | # Download the MNIST dataset 26 | dataset = mnist.load_data() 27 | 28 | train_data = dataset[0][0] 29 | train_labels = dataset[0][1] 30 | 31 | test_data = dataset[1][0] 32 | test_labels = dataset[1][1] 33 | 34 | # Reshape to (num_samples, 28, 28, 1) tensors (60000 train, 10000 test) and scale pixels to [0, 1] 35 | train_data = train_data.reshape([*train_data.shape,1]) / 255.0 36 | 37 | test_data = test_data.reshape([*test_data.shape,1]) / 255.0 38 | 39 | # Transform training labels to one-hot encoding 40 | train_labels = np.eye(10)[train_labels] 41 | 42 | # Transform test labels to one-hot encoding 43 | test_labels = np.eye(10)[test_labels] 44 | 45 | lenet = model.LeNet() 46 | 47 | lenet.summary() 48 | 49 | optimizer = SGD(learning_rate=0.01) 50 | 51 | # Compile the network 52 | lenet.compile( 53 | loss = "categorical_crossentropy", 54 | optimizer = optimizer, 55 | metrics = ["accuracy"]) 56 | 57 | # Callbacks 58 | callbacks_list=[] 59 | #callbacks_list.append(callbacks.WriteTrace("timeline_%02d.json"%(myRank), run_metadata) ) 60 | 61 | # Train the model 62 | lenet.fit( 63 | train_data, 64 | train_labels, 65 | batch_size = 128, 66 | epochs = 1, 67 | verbose = 1, 68 | callbacks=callbacks_list) 69 | 70 | # Evaluate the model 71 | (loss, accuracy) = lenet.evaluate( 72 | test_data, 73 | test_labels, 74 | batch_size = 128, 75 | verbose = 1) 76 | # Print the model's accuracy 77 | print("Test accuracy: %.2f"%(accuracy)) 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /Examples/LeNet/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def LeNet(): 4 | # TensorQuant is sensitive to the exact identifiers.
5 | # It is advised to use the full name ('tf.keras.layers.SomeLayer') or use aliases as shown here. 6 | Convolution2D = tf.keras.layers.Convolution2D 7 | MaxPooling2D = tf.keras.layers.MaxPooling2D 8 | Flatten = tf.keras.layers.Flatten 9 | Dense = tf.keras.layers.Dense 10 | 11 | model = tf.keras.models.Sequential() 12 | 13 | with tf.name_scope("LeNet"): 14 | with tf.name_scope("Convolution_Block"): 15 | # Add the first convolution layer 16 | model.add(Convolution2D( 17 | filters = 20, 18 | kernel_size = (5, 5), 19 | padding = "same", 20 | input_shape = (28, 28, 1), 21 | activation="relu", 22 | name="Conv1")) 23 | 24 | # Add a pooling layer 25 | model.add(MaxPooling2D( 26 | pool_size = (2, 2), 27 | strides = (2, 2), 28 | name="MaxPool1")) 29 | 30 | # Add the second convolution layer 31 | model.add(Convolution2D( 32 | filters = 50, 33 | kernel_size = (5, 5), 34 | padding = "same", 35 | activation="relu", 36 | name="Conv2")) 37 | 38 | # Add a second pooling layer 39 | model.add(MaxPooling2D( 40 | pool_size = (2, 2), 41 | strides = (2, 2), 42 | name="MaxPool2")) 43 | 44 | # Flatten the network 45 | model.add(Flatten()) 46 | 47 | with tf.name_scope("Dense_Block"): 48 | # Add a fully-connected hidden layer 49 | model.add(Dense(500, 50 | activation="relu", 51 | name="Dense3")) 52 | 53 | # Add a fully-connected output layer 54 | model.add(Dense(10, 55 | activation="softmax", 56 | name="Dense4")) 57 | return model 58 | -------------------------------------------------------------------------------- /Kernels/Makefile: -------------------------------------------------------------------------------- 1 | CC=g++ 2 | TF_CFLAGS=$(shell python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') 3 | TF_LFLAGS=$(shell python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') 4 | CFLAGS=-std=c++11 -shared -fPIC $(TF_CFLAGS) $(TF_LFLAGS) -O2 5 | 6 | DEPS= 7 | OBJ=$(patsubst %.cc,%.so,$(wildcard *.cc))
8 | 9 | 10 | all: $(OBJ) 11 | 12 | %.so: %.cc $(DEPS) 13 | $(CC) $< -o $@ $(CFLAGS) 14 | 15 | clean: 16 | rm -f *.so 17 | 18 | .PHONY: all clean 19 | -------------------------------------------------------------------------------- /Kernels/QuantOp_binary.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include 5 | 6 | using namespace tensorflow; 7 | 8 | REGISTER_OP("QuantBinary") 9 | .Attr("marginal: float") 10 | .Attr("T: {float, double}") 11 | .Input("to_reshape: T") 12 | .Output("reshaped: T") 13 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 14 | c->set_output(0, c->input(0)); 15 | return Status::OK(); 16 | }); 17 | 18 | template 19 | class QuantOp : public OpKernel { 20 | public: 21 | explicit QuantOp(OpKernelConstruction* context) : OpKernel(context) { 22 | // Get Attributes 23 | OP_REQUIRES_OK(context, 24 | context->GetAttr("marginal", &marginal)); 25 | // Check Attributes 26 | OP_REQUIRES(context, marginal >= 0, 27 | errors::InvalidArgument("marginal needs to be positive, got ", 28 | marginal)); 29 | } 30 | 31 | void Compute(OpKernelContext* context) override { 32 | // Grab the input tensor 33 | const Tensor& input_tensor = context->input(0); 34 | auto input = input_tensor.flat(); 35 | 36 | // Create an output tensor 37 | Tensor* output_tensor = NULL; 38 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 39 | &output_tensor)); 40 | auto output = output_tensor->flat(); 41 | 42 | // set elements >=0 to marginal, or -marginal else 43 | const int N = input.size(); 44 | for (int i = 0; i < N; i++) { 45 | output(i) = input(i)>=0 ?
marginal : -marginal; 46 | } 47 | } 48 | 49 | private: 50 | float marginal; 51 | }; 52 | 53 | REGISTER_KERNEL_BUILDER( 54 | Name("QuantBinary") 55 | .Device(DEVICE_CPU) 56 | .TypeConstraint("T"), 57 | QuantOp); 58 | 59 | REGISTER_KERNEL_BUILDER( 60 | Name("QuantBinary") 61 | .Device(DEVICE_CPU) 62 | .TypeConstraint("T"), 63 | QuantOp); 64 | -------------------------------------------------------------------------------- /Kernels/QuantOp_halffp.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include 5 | 6 | using namespace tensorflow; 7 | 8 | REGISTER_OP("QuantHalffp") 9 | .Attr("T: {float}") 10 | .Input("to_reshape: T") 11 | .Output("reshaped: T") 12 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 13 | c->set_output(0, c->input(0)); 14 | return Status::OK(); 15 | }); 16 | 17 | template 18 | class QuantOp : public OpKernel { 19 | private: 20 | static int basetable[512]; 21 | static unsigned int masktable[512]; 22 | public: 23 | explicit QuantOp(OpKernelConstruction* context) : OpKernel(context) { 24 | 25 | // basetable and masktable for half-precision floating point rounding 26 | 27 | 28 | for(unsigned int i=0; i<256; ++i){ 29 | int e=i-127; 30 | if(e<-24){ 31 | // Very small numbers map to zero 32 | basetable[i|0x000]=0x00000000; 33 | basetable[i|0x100]=0x80000000; 34 | masktable[i|0x000]=0x00000000; 35 | masktable[i|0x100]=0x00000000; 36 | } 37 | 38 | else if(e<-14){ // TODO: denorms in FP16 are cut at the end, but normalized 39 | // Small numbers map to denorms 40 | basetable[i|0x000]=((e+127)<<23); 41 | basetable[i|0x100]=((e+127)<<23) | 0x80000000; 42 | masktable[i|0x000]=(1<<23)-(1<<(-e-14+13)); 43 | masktable[i|0x100]=(1<<23)-(1<<(-e-14+13)); 44 | } 45 | else if(e<=15){ // TODO: cut normalized numbers 46 | // Normal numbers just lose
precision 47 | basetable[i|0x000]=((e+127)<<23); 48 | basetable[i|0x100]=((e+127)<<23) | 0x80000000; 49 | masktable[i|0x000]=(1<<23)-(1<<(13)); 50 | masktable[i|0x100]=(1<<23)-(1<<(13)); 51 | } 52 | else if(e<128){ // TODO: FP32 infinity 53 | // Large numbers map to Infinity 54 | basetable[i|0x000]=0x7F800000; 55 | basetable[i|0x100]=0xFF800000; 56 | masktable[i|0x000]=0x00000000; 57 | masktable[i|0x100]=0x00000000; 58 | } 59 | else{ // TODO: do nothing 60 | // Infinity and NaN's stay Infinity and NaN's 61 | basetable[i|0x000]=0x7F800000; 62 | basetable[i|0x100]=0xFF800000; 63 | masktable[i|0x000]=(1<<23)-(1<<(13)); 64 | masktable[i|0x100]=(1<<23)-(1<<(13)); 65 | } 66 | } 67 | } 68 | 69 | 70 | void Compute(OpKernelContext* context) override { 71 | // Grab the input tensor 72 | const Tensor& input_tensor = context->input(0); 73 | auto input = input_tensor.flat(); 74 | 75 | // Create an output tensor 76 | Tensor* output_tensor = NULL; 77 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 78 | &output_tensor)); 79 | auto output = output_tensor->flat(); 80 | 81 | // change every element to half-precision floating point 82 | const int N = input.size(); 83 | unsigned int f, index, h; 84 | 85 | for (int i = 0; i < N; i++) { 86 | f= *(unsigned int*)&input(i); //converter.i; 87 | index = (f>>23)&0x1ff; 88 | h=basetable[index]+(f & masktable[index]); 89 | T result = *(T*)&h; 90 | output(i) = result; 91 | } 92 | } 93 | }; 94 | template 95 | int QuantOp::basetable[512]; 96 | template 97 | unsigned int QuantOp::masktable[512]; 98 | 99 | REGISTER_KERNEL_BUILDER( 100 | Name("QuantHalffp") 101 | .Device(DEVICE_CPU) 102 | .TypeConstraint("T"), 103 | QuantOp); 104 | /* 105 | REGISTER_KERNEL_BUILDER( 106 | Name("QuantHalffp") 107 | .Device(DEVICE_CPU) 108 | .TypeConstraint("T"), 109 | QuantOp);*/ 110 | -------------------------------------------------------------------------------- /Kernels/QuantOp_log.cc:
-------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include 5 | 6 | using namespace tensorflow; 7 | 8 | #define EPS (0.000000001) 9 | 10 | REGISTER_OP("QuantLog") 11 | .Attr("T: {float, double}") 12 | .Input("to_reshape: T") 13 | .Output("reshaped: T") 14 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 15 | c->set_output(0, c->input(0)); 16 | return Status::OK(); 17 | }); 18 | 19 | template 20 | class QuantOp : public OpKernel { 21 | public: 22 | explicit QuantOp(OpKernelConstruction* context) : OpKernel(context) { 23 | } 24 | 25 | void Compute(OpKernelContext* context) override { 26 | // Grab the input tensor 27 | const Tensor& input_tensor = context->input(0); 28 | auto input = input_tensor.flat(); 29 | 30 | // Create an output tensor 31 | Tensor* output_tensor = NULL; 32 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 33 | &output_tensor)); 34 | auto output = output_tensor->flat(); 35 | 36 | // change every element to a power-of-two value 2^i; the exponent comes from std::ilogb, i.e. floor(log2(|x|+EPS)), so magnitudes are rounded down rather than to the nearest power of two (NOTE(review): confirm intent). 37 | const int N = input.size(); 38 | for (int i = 0; i < N; i++) { 39 | int exp=std::ilogb(std::abs(input(i))+EPS); 40 | float number = exp>0 ?
1<("T"), 50 | QuantOp); 51 | 52 | REGISTER_KERNEL_BUILDER( 53 | Name("QuantLog") 54 | .Device(DEVICE_CPU) 55 | .TypeConstraint("T"), 56 | QuantOp); 57 | -------------------------------------------------------------------------------- /Kernels/QuantOp_sparse.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include 5 | 6 | using namespace tensorflow; 7 | 8 | REGISTER_OP("QuantSparse") 9 | .Attr("threshold: float") 10 | .Attr("T: {float, double}") 11 | .Input("to_reshape: T") 12 | .Output("reshaped: T") 13 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 14 | c->set_output(0, c->input(0)); 15 | return Status::OK(); 16 | }); 17 | 18 | template 19 | class QuantOp : public OpKernel { 20 | public: 21 | explicit QuantOp(OpKernelConstruction* context) : OpKernel(context) { 22 | // Get Attributes 23 | OP_REQUIRES_OK(context, 24 | context->GetAttr("threshold", &threshold)); 25 | // Check Attributes 26 | OP_REQUIRES(context, threshold >= 0, 27 | errors::InvalidArgument("threshold needs to be positive, got ", 28 | threshold)); 29 | } 30 | 31 | void Compute(OpKernelContext* context) override { 32 | // Grab the input tensor 33 | const Tensor& input_tensor = context->input(0); 34 | auto input = input_tensor.flat(); 35 | 36 | // Create an output tensor 37 | Tensor* output_tensor = NULL; 38 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 39 | &output_tensor)); 40 | auto output = output_tensor->flat(); 41 | 42 | // every element whose magnitude is at or below the threshold is set to 0. 43 | const int N = input.size(); 44 | for (int i = 0; i < N; i++) { 45 | output(i) = std::abs(input(i))>threshold ?
input(i) : 0; 46 | } 47 | } 48 | 49 | private: 50 | float threshold; 51 | }; 52 | 53 | REGISTER_KERNEL_BUILDER( 54 | Name("QuantSparse") 55 | .Device(DEVICE_CPU) 56 | .TypeConstraint("T"), 57 | QuantOp); 58 | 59 | REGISTER_KERNEL_BUILDER( 60 | Name("QuantSparse") 61 | .Device(DEVICE_CPU) 62 | .TypeConstraint("T"), 63 | QuantOp); 64 | -------------------------------------------------------------------------------- /Kernels/QuantOp_ternary.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include 5 | #include 6 | 7 | using namespace tensorflow; 8 | 9 | REGISTER_OP("QuantTernary") 10 | .Attr("marginal: float") 11 | .Attr("auto_threshold: bool=true") 12 | .Attr("threshold: float=0.5") 13 | .Attr("T: {float, double}") 14 | .Input("to_reshape: T") 15 | .Output("reshaped: T") 16 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 17 | c->set_output(0, c->input(0)); 18 | return Status::OK(); 19 | }); 20 | 21 | template 22 | class QuantOp : public OpKernel { 23 | public: 24 | explicit QuantOp(OpKernelConstruction* context) : OpKernel(context) { 25 | // Get Attributes 26 | OP_REQUIRES_OK(context, 27 | context->GetAttr("marginal", &marginal)); 28 | OP_REQUIRES_OK(context, 29 | context->GetAttr("auto_threshold", &auto_threshold)); 30 | OP_REQUIRES_OK(context, 31 | context->GetAttr("threshold", &threshold)); 32 | // Check Attributes 33 | OP_REQUIRES(context, marginal >= 0, 34 | errors::InvalidArgument("marginal needs to be positive, got ", 35 | marginal)); 36 | OP_REQUIRES(context, threshold >= 0, 37 | errors::InvalidArgument("threshold needs to be positive, got ", 38 | threshold)); 39 | } 40 | 41 | void Compute(OpKernelContext* context) override { 42 | // Grab the input tensor 43 | const Tensor& input_tensor = context->input(0); 44 | auto input =
input_tensor.flat(); 45 | //Eigen::Tensor input = input_tensor.flat(); 46 | // Create an output tensor 47 | Tensor* output_tensor = NULL; 48 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 49 | &output_tensor)); 50 | auto output = output_tensor->flat(); 51 | float effective_threshold = threshold; // local copy: Compute() may run concurrently on one kernel instance, so it must not write the member 52 | // set elements >threshold to marginal, > 58 | ( input.data(), input.size() ); 59 | 60 | float abs_sum = abs_input.abs().sum(); 61 | 62 | effective_threshold = 0.7 * abs_sum / N; 63 | 64 | } 65 | for (int i = 0; i < N; i++) { 66 | if(input(i)>effective_threshold) { 67 | output(i)=marginal; 68 | } 69 | else if(input(i)<-effective_threshold) { 70 | output(i)=-marginal; 71 | } 72 | else { 73 | output(i)=0; 74 | } 75 | } 76 | } 77 | 78 | private: 79 | float marginal; 80 | bool auto_threshold; 81 | float threshold; 82 | }; 83 | 84 | REGISTER_KERNEL_BUILDER( 85 | Name("QuantTernary") 86 | .Device(DEVICE_CPU) 87 | .TypeConstraint("T"), 88 | QuantOp); 89 | 90 | REGISTER_KERNEL_BUILDER( 91 | Name("QuantTernary") 92 | .Device(DEVICE_CPU) 93 | .TypeConstraint("T"), 94 | QuantOp); 95 | -------------------------------------------------------------------------------- /Kernels/README.md: -------------------------------------------------------------------------------- 1 | # Kernels 2 | This folder contains the C++ kernels used by the quantizers. 3 | 4 | ## Structure 5 | **compile.sh** - This script can be used to compile a single kernel file. 6 | 7 | **Makefile** - This makefile is set up to compile all files ending in '.cc' as TensorFlow kernels. 8 | 9 | **KernelName.cc** - The C++ kernel sources. 10 | 11 | ## Compilation 12 | Simply run 13 | ``` 14 | make all 15 | ``` 16 | in this folder to compile all kernels. Compiler flags can cause build problems (especially -D_GLIBCXX_USE_CXX11_ABI=0); see the section 'Build the op library' in [this guide](https://www.tensorflow.org/guide/create_op) for help.
17 | 18 | Single kernels can also be compiled with: 19 | ``` 20 | source compile.sh kernel_name 21 | ``` 22 | Pass the kernel name without the .cc extension (e.g. `source compile.sh QuantOp_binary`). 23 | 24 | ## Details 25 | A description of how to implement custom kernels is given [here](https://www.tensorflow.org/guide/create_op). 26 | -------------------------------------------------------------------------------- /Kernels/RoundOp_down.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include 5 | 6 | using namespace tensorflow; 7 | 8 | REGISTER_OP("RoundDown") 9 | .Attr("fixed_size: int") 10 | .Attr("fixed_prec: int") 11 | .Attr("T: {float, double}") 12 | .Input("to_reshape: T") 13 | .Output("reshaped: T") 14 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 15 | c->set_output(0, c->input(0)); 16 | return Status::OK(); 17 | }); 18 | 19 | template 20 | class RoundOp : public OpKernel { 21 | public: 22 | explicit RoundOp(OpKernelConstruction* context) : OpKernel(context) { 23 | // Get Attributes 24 | OP_REQUIRES_OK(context, 25 | context->GetAttr("fixed_size", &fixed_size)); 26 | OP_REQUIRES_OK(context, 27 | context->GetAttr("fixed_prec", &fixed_prec)); 28 | // Check Attributes 29 | OP_REQUIRES(context, fixed_size > 0, 30 | errors::InvalidArgument("fixed_size needs to be bigger than 0, got ", 31 | fixed_size)); 32 | OP_REQUIRES(context, fixed_prec >= 0 && fixed_prec < fixed_size, 33 | errors::InvalidArgument("fixed_prec needs to be between 0 and fixed_size, got ", 34 | fixed_prec)); 35 | 36 | } 37 | 38 | void Compute(OpKernelContext* context) override { 39 | // Grab the input tensor 40 | const Tensor& input_tensor = context->input(0); 41 | auto input = input_tensor.flat(); 42 | 43 | // Create an output tensor 44 | Tensor* output_tensor = NULL; 45 | OP_REQUIRES_OK(context,
context->allocate_output(0, input_tensor.shape(), 46 | &output_tensor)); 47 | auto output = output_tensor->flat(); 48 | 49 | // truncate every element of the tensor to fixed point 50 | const int N = input.size(); 51 | const T fixed_max_signed = ((T)(1UL<<(fixed_size-1))-1)/(1UL< -0.5 ) 57 | T fixed_number = floor(input(i)*(1UL<("T"), 72 | RoundOp); 73 | 74 | REGISTER_KERNEL_BUILDER( 75 | Name("RoundDown") 76 | .Device(DEVICE_CPU) 77 | .TypeConstraint("T"), 78 | RoundOp); 79 | -------------------------------------------------------------------------------- /Kernels/RoundOp_nearest.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include 5 | 6 | using namespace tensorflow; 7 | 8 | REGISTER_OP("RoundNearest") 9 | .Attr("fixed_size: int") 10 | .Attr("fixed_prec: int") 11 | .Attr("T: {float, double}") 12 | .Input("to_reshape: T") 13 | .Output("reshaped: T") 14 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 15 | c->set_output(0, c->input(0)); 16 | return Status::OK(); 17 | }); 18 | 19 | template 20 | class RoundOp : public OpKernel { 21 | public: 22 | explicit RoundOp(OpKernelConstruction* context) : OpKernel(context) { 23 | // Get Attributes 24 | OP_REQUIRES_OK(context, 25 | context->GetAttr("fixed_size", &fixed_size)); 26 | OP_REQUIRES_OK(context, 27 | context->GetAttr("fixed_prec", &fixed_prec)); 28 | // Check Attributes 29 | OP_REQUIRES(context, fixed_size > 0, 30 | errors::InvalidArgument("fixed_size needs to be bigger than 0, got ", 31 | fixed_size)); 32 | OP_REQUIRES(context, fixed_prec >= 0 && fixed_prec < fixed_size, 33 | errors::InvalidArgument("fixed_prec needs to be between 0 and fixed_size, got ", 34 | fixed_prec)); 35 | 36 | } 37 | 38 | void Compute(OpKernelContext* context) override { 39 | // Grab the input tensor 40 | const Tensor& 
input_tensor = context->input(0); 41 | auto input = input_tensor.flat(); 42 | 43 | // Create an output tensor 44 | Tensor* output_tensor = NULL; 45 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 46 | &output_tensor)); 47 | auto output = output_tensor->flat(); 48 | 49 | // round every element of the tensor to the nearest fixed-point value 50 | const int N = input.size(); 51 | const T fixed_max_signed = ((T)(1UL<<(fixed_size-1))-1)/(1UL< 0, -0.5 ->-1 ) 57 | T fixed_number = round(input(i)*(1UL<("T"), 72 | RoundOp); 73 | 74 | REGISTER_KERNEL_BUILDER( 75 | Name("RoundNearest") 76 | .Device(DEVICE_CPU) 77 | .TypeConstraint("T"), 78 | RoundOp); 79 | -------------------------------------------------------------------------------- /Kernels/RoundOp_stochastic.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include 5 | #include /* srand, rand */ 6 | #include /* time */ 7 | 8 | #define RND_ENTRIES 1000 9 | 10 | using namespace tensorflow; 11 | 12 | REGISTER_OP("RoundStochastic") 13 | .Attr("fixed_size: int") 14 | .Attr("fixed_prec: int") 15 | .Attr("T: {float, double}") 16 | .Input("to_reshape: T") 17 | .Output("reshaped: T") 18 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 19 | c->set_output(0, c->input(0)); 20 | return Status::OK(); 21 | }); 22 | 23 | template 24 | class RoundOp : public OpKernel { 25 | public: 26 | explicit RoundOp(OpKernelConstruction* context) : OpKernel(context) { 27 | // Get Attributes 28 | OP_REQUIRES_OK(context, 29 | context->GetAttr("fixed_size", &fixed_size)); 30 | OP_REQUIRES_OK(context, 31 | context->GetAttr("fixed_prec", &fixed_prec)); 32 | // Check Attributes 33 | OP_REQUIRES(context, fixed_size > 0, 34 | errors::InvalidArgument("fixed_size needs to be bigger than 0, got ", 35 | fixed_size)); 36 |
OP_REQUIRES(context, fixed_prec >= 0 && fixed_prec < fixed_size, 37 | errors::InvalidArgument("fixed_prec needs to be between 0 and fixed_size, got ", 38 | fixed_prec)); 39 | rnd_counter=0; 40 | srand (time(NULL)); 41 | for(int i=0; iinput(0); 49 | auto input = input_tensor.flat(); 50 | 51 | // Create an output tensor 52 | Tensor* output_tensor = NULL; 53 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 54 | &output_tensor)); 55 | auto output = output_tensor->flat(); 56 | 57 | // truncate every element of the tensor to fixed point 58 | const int N = input.size(); 59 | const T fixed_max_signed = ((T)(1UL<<(fixed_size-1))-1)/(1UL<("T"), 91 | RoundOp); 92 | 93 | REGISTER_KERNEL_BUILDER( 94 | Name("RoundStochastic") 95 | .Device(DEVICE_CPU) 96 | .TypeConstraint("T"), 97 | RoundOp); 98 | -------------------------------------------------------------------------------- /Kernels/RoundOp_zero.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/shape_inference.h" 3 | #include "tensorflow/core/framework/op_kernel.h" 4 | #include 5 | 6 | using namespace tensorflow; 7 | 8 | REGISTER_OP("RoundZero") 9 | .Attr("fixed_size: int") 10 | .Attr("fixed_prec: int") 11 | .Attr("T: {float, double}") 12 | .Input("to_reshape: T") 13 | .Output("reshaped: T") 14 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 15 | c->set_output(0, c->input(0)); 16 | return Status::OK(); 17 | }); 18 | 19 | template 20 | class RoundOp : public OpKernel { 21 | public: 22 | explicit RoundOp(OpKernelConstruction* context) : OpKernel(context) { 23 | // Get Attributes 24 | OP_REQUIRES_OK(context, 25 | context->GetAttr("fixed_size", &fixed_size)); 26 | OP_REQUIRES_OK(context, 27 | context->GetAttr("fixed_prec", &fixed_prec)); 28 | // Check Attributes 29 | OP_REQUIRES(context, fixed_size > 0, 30 | errors::InvalidArgument("fixed_size needs to be bigger
than 0, got ", 31 | fixed_size)); 32 | OP_REQUIRES(context, fixed_prec >= 0 && fixed_prec < fixed_size, 33 | errors::InvalidArgument("fixed_prec needs to be between 0 and fixed_size, got ", 34 | fixed_prec)); 35 | 36 | } 37 | 38 | void Compute(OpKernelContext* context) override { 39 | // Grab the input tensor 40 | const Tensor& input_tensor = context->input(0); 41 | auto input = input_tensor.flat(); 42 | 43 | // Create an output tensor 44 | Tensor* output_tensor = NULL; 45 | OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), 46 | &output_tensor)); 47 | auto output = output_tensor->flat(); 48 | 49 | // truncate every element of the tensor to fixed point 50 | const int N = input.size(); 51 | const T fixed_max_signed = ((T)(1UL<<(fixed_size-1))-1)/(1UL< 0 ) 57 | // T fixed_number = floor(std::abs(input(i))*(1UL<("T"), 73 | RoundOp); 74 | 75 | REGISTER_KERNEL_BUILDER( 76 | Name("RoundZero") 77 | .Device(DEVICE_CPU) 78 | .TypeConstraint("T"), 79 | RoundOp); 80 | -------------------------------------------------------------------------------- /Kernels/compile.sh: -------------------------------------------------------------------------------- 1 | TF_CFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') 2 | TF_LFLAGS=$(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') 3 | g++ -std=c++11 -shared "$1.cc" -o "$1.so" -fPIC $TF_CFLAGS $TF_LFLAGS -O2 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018 Fraunhofer ITWM. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions.
10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. 
For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. 
Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. 
In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "{}" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. 
We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2018 Dominik Marek Loroch, Fraunhofer ITWM 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /Quantize/.goutputstream-27Z54Y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cc-hpc-itwm/TensorQuant/bb14aacb489d8b5838c141c82b5b5d8c605202ba/Quantize/.goutputstream-27Z54Y -------------------------------------------------------------------------------- /Quantize/FixedPoint.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.python.framework import ops 4 | import os 5 | 6 | local_dir = os.path.dirname(__file__) 7 | 8 | round_module_zero = tf.load_op_library(local_dir+'/../Kernels/RoundOp_zero.so') 9 | round_module_down = tf.load_op_library(local_dir+'/../Kernels/RoundOp_down.so') 10 | round_module_nearest = tf.load_op_library(local_dir+'/../Kernels/RoundOp_nearest.so') 11 | round_module_stochastic = tf.load_op_library(local_dir+'/../Kernels/RoundOp_stochastic.so') 12 | 13 | def round_zero(input1, fixed_size, fixed_prec): 14 | ''' Truncates input1 to a fixed point 
number. 15 | Rounds towards zero.''' 16 | result = round_module_zero.round_zero(input1,fixed_size=fixed_size,fixed_prec=fixed_prec) 17 | return result 18 | @ops.RegisterGradient("RoundZero") 19 | def _round_zero_grad(op, grad): 20 | return [grad] 21 | 22 | def round_down(input1, fixed_size, fixed_prec): 23 | ''' Truncates input1 to a fixed point number. 24 | Rounds down towards negative numbers.''' 25 | result = round_module_down.round_down(input1,fixed_size=fixed_size,fixed_prec=fixed_prec) 26 | return result 27 | @ops.RegisterGradient("RoundDown") 28 | def _round_down_grad(op, grad): 29 | return [grad] 30 | 31 | def round_nearest(input1, fixed_size, fixed_prec): 32 | ''' Truncates input1 to a fixed point number. 33 | Rounds towards the nearest number.''' 34 | result = round_module_nearest.round_nearest(input1,fixed_size=fixed_size,fixed_prec=fixed_prec) 35 | return result 36 | @ops.RegisterGradient("RoundNearest") 37 | def _round_nearest_grad(op, grad): 38 | return [grad] 39 | 40 | def round_stochastic(input1, fixed_size, fixed_prec): 41 | ''' Truncates input1 to a fixed point number. 
42 | Rounds randomly, but the fractional part influences the probability.''' 43 | result = round_module_stochastic.round_stochastic(input1,fixed_size=fixed_size,fixed_prec=fixed_prec) 44 | return result 45 | @ops.RegisterGradient("RoundStochastic") 46 | def _round_stochastic_grad(op, grad): 47 | return [grad] 48 | 49 | 50 | # Fixed point functions, python implementation 51 | # ------------------------------------------------- 52 | # for debugging 53 | def toFixed(fp_number, fixed_size, fixed_prec): 54 | ''' Turns the elements of a floating point numpy matrix into fixed point equivalents with bitwidth fixed_size and fractional bits fixed_prec.''' 55 | fixed_max_signed = (2.0**(fixed_size-1)-1)/(2**fixed_prec) # maximum value of fixed representation 56 | fixed_min_signed = -(2.0**(fixed_size-fixed_prec-1)) # minimum (negative) value of fixed representation 57 | # adjust fractional part (round towards zero) 58 | fixed_number = np.multiply( ((np.absolute(fp_number)*2**fixed_prec)//1) /2**fixed_prec , (np.sign(fp_number)) ) 59 | # handle overflow (saturate number towards maximum or minimum) 60 | fixed_number = np.maximum(np.minimum(fixed_number,fixed_max_signed), fixed_min_signed) 61 | return fixed_number 62 | 63 | def fixTensor(tensor, session, fixed_size, fixed_prec): 64 | ''' Truncates elements of a tensor to fixed point representation. ''' 65 | tensor_a = session.run(tensor) 66 | tensor_a = toFixed(tensor_a, fixed_size, fixed_prec) 67 | tensor.load(tensor_a,session) 68 | return tensor 69 | 70 | ZERO_ROUND=False 71 | def FixedPointOp(tensor, fixed_size, fixed_prec): 72 | fixed_max_signed = (2**(fixed_size-1)-1)/(2**fixed_prec) 73 | fixed_min_signed = -(2**(fixed_size-fixed_prec-1)) 74 | 75 | if ZERO_ROUND: 76 | # rounds towards zero (e.g. -0.001 -> 0 ) 77 | tensor = tf.floor(tf.abs(tensor)*(2**fixed_prec)) / (2**fixed_prec) * tf.sign(tensor); 78 | else: 79 | # rounds towards negative (e.g. 
-0.001 -> -0.5 ) 80 | tensor = tf.floor(tensor*(2**fixed_prec)) / (2**fixed_prec); 81 | 82 | tensor = tf.maximum(tf.minimum(tensor,tf.ones(tensor.shape)*fixed_max_signed), tf.ones(tensor.shape)*fixed_min_signed); 83 | return tensor 84 | -------------------------------------------------------------------------------- /Quantize/QLayer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | ################################## 4 | ### Keras Layer implementation ### 5 | ################################## 6 | # tensorflow/python/keras/layers.py 7 | def create_qLayer(layer, intr_quantizer=None, extr_quantizer=None, weight_quantizer=None, name=None): 8 | class QLayer(tf.keras.layers.Layer): 9 | 10 | def __init__(self, name, config): 11 | super(self.__class__, self).__init__(**config) 12 | self.intr_quantizer = intr_quantizer 13 | self.extr_quantizer = extr_quantizer 14 | self.weight_quantizer = weight_quantizer 15 | 16 | #if name is None: 17 | # name = "Quantized_%s" % self.__class__.__base__.__name__ 18 | #self._name = name 19 | 20 | 21 | def build(self, input_shape): 22 | super(self.__class__, self).build(input_shape) 23 | # quantize weights 24 | if self.weight_quantizer is not None: 25 | weight_list = [w.name for w in self.weights] 26 | weight_list = [w.split("/")[-1][:-2] for w in weight_list] #[:-2] to remove ':0' 27 | for w in weight_list: 28 | with tf.name_scope(w+"/quant"): 29 | setattr(self, w, self.weight_quantizer(getattr(self, w))) 30 | #tf.add_to_collection("quantized_variables", getattr(self, w)) 31 | 32 | 33 | def call(self, inputs): 34 | output = super(self.__class__, self).call(inputs) 35 | if self.extr_quantizer is not None: 36 | with tf.name_scope("output/quant"): 37 | output = self.extr_quantizer(output) 38 | #tf.add_to_collection("quantized_outputs", output) 39 | return output 40 | 41 | cls = type(layer.__class__.__name__, (layer.__class__,), 42 | dict(QLayer.__dict__)) 43 | return 
cls(name, layer.get_config()) 44 | -------------------------------------------------------------------------------- /Quantize/QuantKernelWrapper.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.framework import ops 3 | import os 4 | 5 | local_dir = os.path.dirname(__file__) 6 | 7 | quant_module_log = tf.load_op_library(local_dir+'/../Kernels/QuantOp_log.so') 8 | quant_module_sparse = tf.load_op_library(local_dir+'/../Kernels/QuantOp_sparse.so') 9 | quant_module_halffp = tf.load_op_library(local_dir+'/../Kernels/QuantOp_halffp.so') 10 | quant_module_binary = tf.load_op_library(local_dir+'/../Kernels/QuantOp_binary.so') 11 | quant_module_ternary = tf.load_op_library(local_dir+'/../Kernels/QuantOp_ternary.so') 12 | 13 | def quant_log(input1): 14 | '''Takes closest value of input to the form +/- 2^i.''' 15 | result = quant_module_log.quant_log(input1) 16 | return result 17 | @ops.RegisterGradient("QuantLog") 18 | def _quant_log_grad(op, grad): 19 | return [grad] 20 | 21 | def quant_sparse(input1, threshold): 22 | '''Every element whose magnitude is below the threshold is set to 0.''' 23 | result = quant_module_sparse.quant_sparse(input1, threshold=threshold) 24 | return result 25 | @ops.RegisterGradient("QuantSparse") 26 | def _quant_sparse_grad(op, grad): 27 | return [grad] 28 | 29 | def quant_halffp(input1): 30 | '''Rounds to half-precision floating point''' 31 | result = quant_module_halffp.quant_halffp(input1) 32 | return result 33 | @ops.RegisterGradient("QuantHalffp") 34 | def _quant_halffp_grad(op, grad): 35 | return [grad] 36 | 37 | def quant_binary(input1, marginal): 38 | '''Binarizes to +/- marginal value''' 39 | result = quant_module_binary.quant_binary(input1, marginal) 40 | return result 41 | @ops.RegisterGradient("QuantBinary") 42 | def _quant_binary_grad(op, grad): 43 | return [grad] 44 | 45 | def quant_ternary(input1, marginal, auto_threshold=True, threshold=0.5): 
46 | '''Ternary quantization to +/- marginal and 0''' 47 | result = quant_module_ternary.quant_ternary(input1, marginal, auto_threshold, threshold) 48 | return result 49 | @ops.RegisterGradient("QuantTernary") 50 | def _quant_ternary_grad(op, grad): 51 | return [grad] 52 | -------------------------------------------------------------------------------- /Quantize/Quantizers.py: -------------------------------------------------------------------------------- 1 | from TensorQuant.Quantize import FixedPoint 2 | from TensorQuant.Quantize import QuantKernelWrapper as Wrapped 3 | import tensorflow as tf 4 | 5 | class Quantizer_if(): 6 | """Interface for quantizer classes""" 7 | def __str__(self): 8 | return self.__class__.__name__ 9 | def quantize(self,tensor): 10 | raise NotImplementedError 11 | def __call__(self, tensor): 12 | return self.quantize(tensor) 13 | 14 | ############################### 15 | ### Fixed Point 16 | ############################### 17 | 18 | class FixedPointQuantizer_zero(Quantizer_if): 19 | """Fixed point quantization with fixed_size bits and fixed_prec fractional bits. 20 | Uses c-kernel for quantization. 21 | """ 22 | def __init__(self, fixed_size, fixed_prec): 23 | self.fixed_size=fixed_size 24 | self.fixed_prec=fixed_prec 25 | self.fixed_max_signed = (1<<(fixed_size-1)-1)/(1<=0, or -marginal if input <0. 195 | """ 196 | def __init__(self, marginal): 197 | self.marginal=marginal 198 | def C_quantize(self,tensor): 199 | return Wrapped.quant_binary(tensor, self.marginal) 200 | def quantize(self, tensor): 201 | @tf.custom_gradient 202 | def op(tensor): 203 | def grad(dy): 204 | return dy 205 | out = (tf.dtypes.cast(tf.greater_equal(tensor,0),tensor.dtype)*2-1)*self.marginal 206 | out = tf.identity(out, name=str(self)+"_output") 207 | return out, grad 208 | return op(tensor) 209 | 210 | ############################### 211 | ### Ternary 212 | ############################### 213 | class TernaryQuantizer(Quantizer_if): 214 | """Ternary quantization. 
Rounds to +marginal if input is >threshold, -marginal if input <-threshold or 0. 215 | 216 | """ 217 | def __init__(self, marginal, auto_threshold=True, threshold=0.5): 218 | self.marginal=marginal 219 | self.auto_threshold=auto_threshold 220 | self.threshold=threshold 221 | def C_quantize(self,tensor): 222 | return Wrapped.quant_ternary(tensor, self.marginal, self.auto_threshold, self.threshold) 223 | def quantize(self, tensor): 224 | @tf.custom_gradient 225 | def op(tensor): 226 | def grad(dy): 227 | return dy 228 | if self.auto_threshold: 229 | threshold = -0.7 * tf.math.reduce_sum(tf.math.abs(tensor))/tf.dtypes.cast(tf.size(tensor),tensor.dtype) 230 | else: 231 | threshold = self.threshold 232 | out = tf.ones_like(tensor)*-1 233 | out += tf.dtypes.cast(tf.greater(tensor, -threshold),tensor.dtype) 234 | out += tf.dtypes.cast(tf.greater(tensor, threshold),tensor.dtype) 235 | out *= self.marginal 236 | out = tf.identity(out, name=str(self)+"_output") 237 | return out, grad 238 | return op(tensor) 239 | 240 | 241 | 242 | ############################### 243 | ### Other 244 | ############################### 245 | class NoQuantizer(Quantizer_if): 246 | """Applies no quantization to the tensor""" 247 | def quantize(self,tensor): 248 | return tensor 249 | -------------------------------------------------------------------------------- /Quantize/README.md: -------------------------------------------------------------------------------- 1 | # Quantize 2 | This directory contains the TensorQuant core mechanics. 3 | 4 | ## Structure 5 | 6 | **Quantizers.py** - Definition of the Quantizer objects. Use these in your quantizer maps. 7 | 8 | **utils.py** - Utilities to generate quantization maps. 9 | 10 | **override.py** - defines the active overrides and the global quantization dictionaries. 11 | 12 | **FixedPoint.py** - Python wrappers for the fixed point quantization kernels. 13 | 14 | **QuantKernelWrapper.py** - Python wrappers for non-fixed point quantization kernels. 
15 | 16 | **QLayer.py** - Defines a generic Keras Layer with enabled quantization. 17 | 18 | **override_functions.py** - Contains the "generic_keras_override" function, which decides which layers are to be hijacked. 19 | 20 | ## Remarks 21 | 22 | ### Quantizers.py 23 | - Every Quantizer implements the quantizer interface defined by 'Quantizer_if'. It takes a tensor as an argument and returns a quantized tensor. Use this interface to add additional quantizers. 24 | - There is a 'NoQuantizer' quantizer, which simply returns the unquantized tensor. This quantizer is used for debugging. 25 | - Most of the quantizers are implemented with TensorFlow operations. The C kernels can be invoked with the "C_quantize" method, but they are often not used. Consequently, it may be possible to work around a failed compilation of the C kernels; note, however, that FixedPoint.py and QuantKernelWrapper.py load the compiled kernel libraries when they are imported, and Quantizers.py imports both modules. 26 | - It is possible to define custom gradients for the Quantizers ("def grad(dy):"). The gradients in the TensorQuant quantizers are straight-through by default, but can be modified. 27 | 28 | ### utils.py 29 | - The function "quantizer_map" can be used to generate quantizer maps. The function takes either the path to a .json file, or a dictionary with the following structure: 30 | ```json 31 | { 32 | "Layer_name" : "Quantizer_shortcut_string" 33 | } 34 | ``` 35 | The quantizer shortcut strings are defined in the same file in the "quantizer_selector" function (e.g. "nearest,32,16" would create a fixed point quantization with 32 bits, 16 of which are fractional bits). 36 | The layer names do not need to match the real names exactly: every layer whose name contains a key of the map as a substring will be quantized with the corresponding quantizer. This allows quantizing entire blocks of layers. At the time of writing, there is an issue with using the "tf.name_scope" feature together with Keras layers, so name scopes are not a reliable way to structure your network. 37 | 38 | ### override.py 39 | - The available overrides do not cover all Keras layers.
However, the overrides can be easily extended for other Keras layers with: 40 | ``` python 41 | keras_SomeLayer = tf.keras.layers.SomeLayer 42 | keras_SomeLayer_override = generic_keras_override(keras_SomeLayer) 43 | # override the Keras layer 44 | tf.keras.layers.SomeLayer = keras_SomeLayer_override 45 | # optional: override for aliases 46 | tf.keras.layers.SomeLayer_alias = keras_SomeLayer_override 47 | ``` 48 | - The "intr_q_map" has no effect in this version of TensorQuant. 49 | - The "extr_q_map" (for activations) and "weight_q_map" (for weights and biases) dictionaries can be set to a dictionary defining the desired quantization setup. 50 | 51 | -------------------------------------------------------------------------------- /Quantize/__init__.py: -------------------------------------------------------------------------------- 1 | from TensorQuant.Quantize import utils 2 | #from TensorQuant.Quantize import override 3 | -------------------------------------------------------------------------------- /Quantize/override.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from TensorQuant.Quantize.override_functions import generic_keras_override 3 | 4 | # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 5 | # override these qmaps in your main application 6 | # Example: 7 | # intr_q_map = { "MyNetwork/Conv2D_1" : "nearest,32,16", 8 | # "MyNetwork/Conv2D_2" : "nearest,16,8"} 9 | 10 | intr_q_map=None 11 | extr_q_map=None 12 | weight_q_map=None 13 | # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
14 | 15 | # 'tensorflow.keras.layers.Convolution2D' override 16 | keras_conv2d = tf.keras.layers.Conv2D 17 | keras_conv2d_override = generic_keras_override(keras_conv2d) 18 | tf.keras.layers.Conv2D = keras_conv2d_override 19 | tf.keras.layers.Convolution2D = keras_conv2d_override 20 | 21 | # 'tf.keras.layers.Conv1D' override 22 | keras_conv1d = tf.keras.layers.Conv1D 23 | keras_conv1d_override = generic_keras_override(keras_conv1d) 24 | tf.keras.layers.Conv1D = keras_conv1d_override 25 | 26 | # 'tf.keras.layers.Dense' override 27 | keras_dense = tf.keras.layers.Dense 28 | keras_dense_override = generic_keras_override(keras_dense) 29 | tf.keras.layers.Dense = keras_dense_override 30 | 31 | # 'tf.keras.layers.MaxPooling2D' override 32 | keras_maxpool2d = tf.keras.layers.MaxPooling2D 33 | keras_maxpool2d_override = generic_keras_override(keras_maxpool2d) 34 | tf.keras.layers.MaxPooling2D = keras_maxpool2d_override 35 | 36 | # 'tf.keras.layers.MaxPool1D' override 37 | keras_maxpool1d = tf.keras.layers.MaxPool1D 38 | keras_maxpool1d_override = generic_keras_override(keras_maxpool1d) 39 | tf.keras.layers.MaxPool1D = keras_maxpool1d_override 40 | -------------------------------------------------------------------------------- /Quantize/override_functions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from TensorQuant.Quantize.QLayer import create_qLayer 3 | from TensorQuant.Quantize import utils as tq_utils 4 | from TensorQuant.Quantize import override 5 | 6 | # This function creates an override function for nearly all keras layers with weights. 7 | def generic_keras_override(Class): 8 | print("Override for Class %s is active."%(Class.__module__+ "." 
+ Class.__name__)) 9 | def override_func(*args, **kwargs): 10 | def find_quantizer(name, _map): 11 | _map = tq_utils.quantizer_map(_map) 12 | if _map is None: 13 | return None 14 | quantizer = None 15 | for name in _map.keys(): 16 | if(name in layer_ID): 17 | quantizer = _map[name] 18 | break 19 | return quantizer 20 | 21 | # if "name" in kwargs.keys(): 22 | # name = "/"+kwargs["name"] 23 | # else: 24 | # name = '' 25 | 26 | layer = Class(*args, **kwargs) 27 | 28 | # Tensorflow 1.X 29 | try: 30 | layer_scope = tf.get_default_graph().get_name_scope() 31 | except: 32 | layer_scope = "" 33 | 34 | if layer_scope !="": 35 | layer_ID = tf.get_default_graph().get_name_scope() + "/" + layer.name 36 | else: 37 | layer_ID = layer.name 38 | 39 | intr_quantizer = find_quantizer(layer_ID, override.intr_q_map) 40 | extr_quantizer = find_quantizer(layer_ID, override.extr_q_map) 41 | weight_quantizer = find_quantizer(layer_ID, override.weight_q_map) 42 | 43 | layer=create_qLayer(layer, 44 | intr_quantizer=intr_quantizer, 45 | extr_quantizer=extr_quantizer, 46 | weight_quantizer=weight_quantizer) 47 | # Info about quantized Layers 48 | if intr_quantizer is not None: 49 | print("%s: internally quantized with %s"%(layer.name, intr_quantizer)) 50 | if extr_quantizer is not None: 51 | print("%s output quantized with %s"%(layer.name, extr_quantizer)) 52 | if weight_quantizer is not None: 53 | print("%s weights quantized with %s"%(layer.name, weight_quantizer)) 54 | return layer 55 | return override_func 56 | -------------------------------------------------------------------------------- /Quantize/utils.py: -------------------------------------------------------------------------------- 1 | # Utilities for Quantizers used in the evaluation and training python scripts. 
2 | # 3 | # author: Dominik Loroch 4 | # date: August 2017 5 | 6 | import json 7 | import tensorflow as tf 8 | 9 | from TensorQuant.Quantize import Quantizers 10 | 11 | def quantizer_selector(selector_str, arg_list): 12 | """ Builds and returns the specified quantizer. 13 | Args: 14 | selector_str: The name of the quantizer. 15 | arg_list: Arguments which need to be passed to the constructor of the quantizer. 16 | Returns: 17 | Quantizer object. 18 | """ 19 | if selector_str=="none": 20 | quantizer = Quantizers.NoQuantizer() 21 | elif selector_str=="zero": 22 | quantizer = Quantizers.FixedPointQuantizer_zero( 23 | int(arg_list[0]), int(arg_list[1]) ) 24 | elif selector_str=="down": 25 | quantizer = Quantizers.FixedPointQuantizer_down( 26 | int(arg_list[0]), int(arg_list[1]) ) 27 | elif selector_str=="nearest": 28 | quantizer = Quantizers.FixedPointQuantizer_nearest( 29 | int(arg_list[0]), int(arg_list[1]) ) 30 | elif selector_str=="stochastic": 31 | quantizer = Quantizers.FixedPointQuantizer_stochastic( 32 | int(arg_list[0]), int(arg_list[1]) ) 33 | elif selector_str=="sparse": 34 | quantizer = Quantizers.SparseQuantizer( 35 | float(arg_list[0]) ) 36 | elif selector_str=="logarithmic": 37 | quantizer = Quantizers.LogarithmicQuantizer() 38 | elif selector_str=="fp16": 39 | quantizer = Quantizers.HalffpQuantizer() 40 | elif selector_str=="binary": 41 | if len(arg_list)==0: 42 | quantizer = Quantizers.BinaryQuantizer( 1 ) 43 | if len(arg_list)==1: 44 | quantizer = Quantizers.BinaryQuantizer( float(arg_list[0]) ) 45 | elif selector_str=="ternary": 46 | if len(arg_list)==0: 47 | quantizer = Quantizers.TernaryQuantizer( 1 ) 48 | if len(arg_list)==1: 49 | quantizer = Quantizers.TernaryQuantizer( float(arg_list[0]) ) 50 | elif len(arg_list)==2: 51 | quantizer = Quantizers.TernaryQuantizer( float(arg_list[0]), False, float(arg_list[1])) 52 | else: 53 | raise ValueError('Quantizer %s not recognized!'%(selector_str)) 54 | return quantizer 55 | 56 | 57 | def 
split_quantizer_str(quantizer_str): 58 | """ Splits a quantizer string into its components. 59 | Interprets the first entry as the quantizer name. 60 | Args: 61 | quantizer_str: String in the form: "quantizer_type,argument_1,argument_2,..." 62 | Returns: 63 | Tupel of strings in the form (quantizer_type, [argument_1, argument_2,...]) 64 | """ 65 | quantizer_type='' 66 | args=[] 67 | tokens = quantizer_str.split(',') 68 | if len(tokens) > 0: 69 | quantizer_type=tokens[0] 70 | if len(tokens) > 1: 71 | args=tokens[1:] 72 | return (quantizer_type, args) 73 | 74 | 75 | def get_quantizer(q_str): 76 | """ Get a quantizer instance based on string. 77 | If quantizer is empty string or None, None is returned. 78 | Args: 79 | q_str: quantizer string to be interpreted. 80 | Returns: 81 | Quantizer object or None. 82 | """ 83 | if q_str == "": 84 | q_str=None 85 | if q_str is None: 86 | return None 87 | qtype, qargs= split_quantizer_str(q_str) 88 | quantizer = quantizer_selector(qtype, qargs) 89 | return quantizer 90 | 91 | 92 | def quantizer_map(qmap): 93 | """ Creates a Quantizer map. All specified layers share the same quantizer type. 94 | Args: 95 | qmap: Location of the .json file, which specifies the mapping, or a dictionary with the same content. 96 | Returns: 97 | A dictionary containing the mapping from layers to quantizers. 98 | """ 99 | if qmap is None: 100 | return None 101 | elif type(qmap) == str: 102 | # load dictionary from json file. 
103 | # open file and parse data 104 | if qmap is '': 105 | return None 106 | try: 107 | with open(qmap,'r') as hfile: 108 | qmap = json.load(hfile) 109 | except IOError: 110 | qmap={"":qmap} 111 | 112 | # change strings in qmap into quantizer objects 113 | for key in qmap: 114 | if type(qmap[key]) is str: 115 | # generate quantizer object 116 | quantizer=get_quantizer(qmap[key]) 117 | #if quantizer is None: 118 | # raise ValueError("Invalid quantizer \""+qmap[key]+"\" for layer \""+key+"\"") 119 | qmap[key]=quantizer 120 | 121 | return qmap 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorQuant 2 | 3 | A TensorFlow toolbox for Deep Neural Network Quantization 4 | 5 | Original paper: https://arxiv.org/abs/1710.05758 6 | 7 | ## Getting Started 8 | 9 | ### Structure 10 | 11 | **Examples/** - Contains examples on how to use TensorQuant. 12 | 13 | **Kernels/** - Contains C-files of the kernels used in the quantizers. 14 | 15 | **Quantize/** - Contains the quantizers and override mechanic of the TensorQuant toolbox. 16 | 17 | ### Prerequisites 18 | 19 | - [TensorFlow](https://www.tensorflow.org/) 2.0 (Keras) 20 | - [Python](https://www.python.org/) 3.6 21 | 22 | ### Installing 23 | 24 | Add the directory that contains the TensorQuant directory to your PYTHONPATH environment variable, so that the "TensorQuant" package can be imported by your project. 25 | ``` shell 26 | export PYTHONPATH=${PYTHONPATH}:/path/to/parent/of/TensorQuant 27 | ``` 28 | 29 | Compile the kernels in the "Kernels/" directory. A Makefile is provided (run 'make all'). There might be issues with the -D_GLIBCXX_USE_CXX11_ABI=0 flag. See [this link](https://www.tensorflow.org/extend/adding_an_op) (under 'Build the op library') for more help on this topic. 30 | 31 | ## Quantizing a Neural Network 32 | 33 | TensorQuant hijacks the Keras layer identifiers in order to inline additional ops for the quantization.
34 | The override needs to be applied before the model is built. Therefore, TensorQuant cannot be used if the model is loaded from a container file (i.e. when the "tf.keras.layers" classes are never called). 35 | In order to apply the overrides, you must import the "override" module from TensorQuant in your main file: 36 | 37 | ``` python 38 | from TensorQuant.Quantize import override 39 | ``` 40 | 41 | The layers for quantization are selected via a dictionary, which maps layer names to quantizers. The available quantizers are in "TensorQuant.Quantize.Quantizers". 42 | 43 | For example, you can provide a dictionary in your Python code like this: 44 | ``` python 45 | override.extr_q_map={"Conv1" : TensorQuant.Quantize.Quantizers.FixedPointQuantizer_nearest(16,8)} 46 | ``` 47 | 48 | Alternatively, you can provide a .json file: 49 | ```json 50 | { 51 | "Layer_name" : "Quantizer_shortcut_string" 52 | } 53 | ``` 54 | The quantizer shortcut strings are defined in the "quantizer_selector" function in "TensorQuant/Quantize/utils.py" (e.g. "nearest,16,8" would create a fixed point quantization with 16 bits, 8 of which are fractional bits). 55 | 56 | The layer names do not need to match the real names exactly: every layer whose name contains a key of the map as a substring will be quantized with the corresponding quantizer. This allows quantizing entire blocks of layers. At the time of writing, there is an issue with using the "tf.name_scope" feature together with Keras layers, so name scopes are not a reliable way to structure your network. 57 | 58 | Load the JSON file with: 59 | ```python 60 | from TensorQuant.Quantize import utils 61 | 62 | override.extr_q_map = utils.quantizer_map(json_filename) 63 | ``` 64 | The available quantizer shortcut strings are listed in the function "TensorQuant.Quantize.utils.quantizer_selector". 65 | 66 | Currently, there is "extr_q_map" for layer activations and "weight_q_map" for the layer weights. "intr_q_map" for intrinsic quantization is not available in this version of TensorQuant.
67 | Set these dictionaries before the model is built (i.e. before any tensorflow.keras.layers class is called). If you do not want to use quantization, set the quantizer map to "None" (the default). 68 | 69 | No changes are required in your Keras model. However, the override mechanism is very sensitive to the exact identifiers of the classes, so it might be necessary to use the full identifiers (e.g. "tf.keras.layers.Conv2D"), or to create aliases (e.g. "Conv2D = tf.keras.layers.Conv2D"), as shown in the LeNet example. 70 | ``` python 71 | # Introducing an alias for a Keras layer 72 | Convolution2D = tf.keras.layers.Convolution2D 73 | 74 | model = tf.keras.models.Sequential() 75 | 76 | model.add(Convolution2D( 77 | filters = 20, 78 | kernel_size = (5, 5), 79 | padding = "same", 80 | input_shape = (28, 28, 1), 81 | activation="relu", 82 | name="Conv1")) 83 | ``` 84 | 85 | ## Overriding Layers 86 | The overrides provided out of the box do not cover every available Keras layer. However, any Keras layer can be hijacked, as in this example: 87 | ``` python 88 | import tensorflow as tf 89 | from TensorQuant.Quantize.override_functions import generic_keras_override 90 | 91 | keras_conv2d = tf.keras.layers.Conv2D 92 | keras_conv2d_override = generic_keras_override(keras_conv2d) 93 | # the override happens in this line 94 | tf.keras.layers.Conv2D = keras_conv2d_override 95 | # optionally, override any aliases 96 | tf.keras.layers.Convolution2D = keras_conv2d_override 97 | ``` 98 | The overrides must be applied before the model is built. The available overrides are in "TensorQuant.Quantize.override". Additional overrides can be placed in that file as well. 99 | 100 | ## Authors 101 | 102 | Dominik Loroch (Fraunhofer ITWM) 103 | 104 | Please cite [this](https://arxiv.org/abs/1710.05758) paper.
105 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from TensorQuant import Quantize --------------------------------------------------------------------------------