├── LICENSE ├── README.md ├── ternary_conv_layer.cpp ├── ternary_conv_layer.cu ├── ternary_conv_layer.hpp ├── ternary_inner_product_layer.cpp ├── ternary_inner_product_layer.cu └── ternary_inner_product_layer.hpp /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2018, shuan 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimal Ternary Weights Approximation 2 | 3 | Caffe implementation of the optimal ternary weights approximation from "Two-Step Quantization for Low-bit Neural Networks" (CVPR 2018). 4 | 5 | ### Objective Function 6 | ![equation](http://latex.codecogs.com/gif.latex?\min_{\alpha,\hat{w}}||w-\alpha\hat{w}||_2^2) 7 | 8 | where ![equation](http://latex.codecogs.com/gif.latex?\alpha>0) and ![equation](http://latex.codecogs.com/gif.latex?\hat{w}\in\\{-1,0,+1\\}^m). 9 | 10 | ### Weight Blob 11 | We use a temporary memory block to store ![equation](http://latex.codecogs.com/gif.latex?\alpha\hat{w}) and keep the full-precision ![equation](http://latex.codecogs.com/gif.latex?w) in **this->blobs_[0]**. 12 | During backpropagation, the weight gradients are accumulated into ![equation](http://latex.codecogs.com/gif.latex?w), while ![equation](http://latex.codecogs.com/gif.latex?\alpha\hat{w}) is used to compute the bottom gradients. 13 | 14 | ### How to use 15 | Change **type: "Convolution"** to **type: "TernaryConvolution"** (likewise, **type: "InnerProduct"** to **type: "TernaryInnerProduct"**), e.g. 16 | 17 | ```prototxt 18 | layer { 19 | bottom: "pool1" 20 | top: "res2a_branch1" 21 | name: "res2a_branch1" 22 | type: "TernaryConvolution" 23 | convolution_param { 24 | num_output: 64 25 | kernel_size: 1 26 | pad: 0 27 | stride: 1 28 | weight_filler { 29 | type: "msra" 30 | } 31 | bias_term: false 32 | } 33 | } 34 | ``` 35 | The ternary quantization is currently implemented for **GPU only**. 36 | 37 | ### 2-bit Activation Quantization 38 | Please refer to [wps712](https://github.com/wps712/Two-Step-Quantization-AlexNet).
39 | -------------------------------------------------------------------------------- /ternary_conv_layer.cpp: -------------------------------------------------------------------------------- 1 | #include <vector> 2 | 3 | #include "caffe/layers/ternary_conv_layer.hpp" 4 | 5 | namespace caffe { 6 | // Computes the spatial output shape of a regular (non-transposed) convolution, then sizes the ternary-quantization scratch blobs from blobs_[0]. 7 | template <typename Dtype> 8 | void TernaryConvolutionLayer<Dtype>::compute_output_shape() { 9 | const int* kernel_shape_data = this->kernel_shape_.cpu_data(); 10 | const int* stride_data = this->stride_.cpu_data(); 11 | const int* pad_data = this->pad_.cpu_data(); 12 | const int* dilation_data = this->dilation_.cpu_data(); 13 | this->output_shape_.clear(); 14 | for (int i = 0; i < this->num_spatial_axes_; ++i) { 15 | // i + 1 to skip channel axis 16 | const int input_dim = this->input_shape(i + 1); 17 | const int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1; 18 | const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent) 19 | / stride_data[i] + 1; 20 | this->output_shape_.push_back(output_dim); 21 | } 22 | // Scratch for alpha * w_hat (same shape as the real weights), one alpha and one threshold per output channel, and a length-count(1) vector used for per-row sums. 23 | ternary_weights_.ReshapeLike(*this->blobs_[0]); 24 | alphas_.Reshape(this->num_output_,1,1,1); 25 | weight_sum_multiplier_.Reshape(this->blobs_[0]->count(1),1,1,1); 26 | threshold_.Reshape(this->num_output_,1,1,1); 27 | // NOTE(review): reset on every reshape, so Forward_gpu re-quantizes the weights after each reshape, even in TEST phase. 28 | skip_quantization_ = false; 29 | } 30 | // No CPU implementation exists: fail loudly instead of silently leaving the top blobs unwritten. 31 | template <typename Dtype> 32 | void TernaryConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, 33 | const vector<Blob<Dtype>*>& top) { 34 | LOG(FATAL) << "TernaryConvolution has no CPU implementation; use GPU mode."; } 35 | // No CPU implementation exists: fail loudly instead of silently leaving the bottom/parameter diffs unwritten. 36 | template <typename Dtype> 37 | void TernaryConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, 38 | const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { 39 | LOG(FATAL) << "TernaryConvolution has no CPU implementation; use GPU mode."; 40 | } 41 | 42 | #ifdef CPU_ONLY 43 | STUB_GPU(TernaryConvolutionLayer); 44 | #endif 45 | 46 | INSTANTIATE_CLASS(TernaryConvolutionLayer); 47 | REGISTER_LAYER_CLASS(TernaryConvolution); 48 | } // namespace caffe 49 | -------------------------------------------------------------------------------- /ternary_conv_layer.cu: -------------------------------------------------------------------------------- 1 | #include <algorithm> 2 | #include <vector> 3 |
4 | #include "caffe/layers/ternary_conv_layer.hpp" 5 | 6 | namespace caffe { 7 | // Elementwise ternarization: code is +1 for weight > 0 and -1 otherwise, kept where |weight| >= threshold[row] and set to 0 elsewhere; row = index / weight_dim. 8 | template <typename Dtype> 9 | __global__ void TernaryWeightQuant(const int n, const int weight_dim, const Dtype* weight, 10 | const Dtype* threshold, Dtype* ternary_weight) { 11 | CUDA_KERNEL_LOOP(index, n) { 12 | int i = index/weight_dim; 13 | Dtype ternary_code = weight[index] > Dtype(0) ? Dtype(1) : Dtype(-1); 14 | ternary_weight[index] = fabs(weight[index]) >= threshold[i] ? ternary_code : Dtype(0); 15 | } 16 | } 17 | // Elementwise scaling: ternary_weight = weight * alpha[row]; row = index / weight_dim. Each element is read and written at the same index, so it may run in place. 18 | template <typename Dtype> 19 | __global__ void TernaryWeightForward(const int n, const int weight_dim, const Dtype* weight, 20 | const Dtype* alpha, Dtype* ternary_weight) { 21 | CUDA_KERNEL_LOOP(index, n) { 22 | int i = index/weight_dim; 23 | ternary_weight[index] = weight[index] * alpha[i]; 24 | } 25 | } 26 | 27 | template <typename Dtype> 28 | void TernaryConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, 29 | const vector<Blob<Dtype>*>& top) { 30 | // Ternarize blobs_[0] per output channel into ternary_weights_ (= alpha * w_hat, w_hat in {-1,0,+1}), then convolve with it. 31 | const Dtype* weight = this->blobs_[0]->gpu_data(); 32 | const int weight_dim = this->blobs_[0]->count(1); 33 | 34 | if (skip_quantization_ == false) { 35 | caffe_gpu_abs(this->blobs_[0]->count(), weight, ternary_weights_.mutable_gpu_data()); 36 | caffe_gpu_set(weight_sum_multiplier_.count(),Dtype(1),weight_sum_multiplier_.mutable_gpu_data()); 37 | const int nthreads = this->blobs_[0]->count(); 38 | Dtype* threshold_ptr = threshold_.mutable_cpu_data(); 39 | // Host-side threshold search: sort each row's |w| ascending, then keep the top-r magnitudes maximizing (sum of kept |w|)^2 / r (ties favor keeping more). 40 | for (int i = 0; i < this->blobs_[0]->num(); i++) { 41 | Dtype* kernel_mutable_cpu_data = ternary_weights_.mutable_cpu_data()+i*weight_dim; 42 | std::sort(kernel_mutable_cpu_data, kernel_mutable_cpu_data+weight_dim); 43 | int r = 0; 44 | Dtype s = 0; 45 | Dtype loss_max = Dtype(1e-5); 46 | int idx = weight_dim > 1 ? 1 : 0; // fallback when no candidate reaches loss_max; a fixed 1 would read past a single-element row 47 | for (int j = weight_dim-1; j >=0; j--) { 48 | s += kernel_mutable_cpu_data[j]; r++; 49 | const Dtype loss = s*s/r; 50 | if (loss >= loss_max) { 51 | loss_max = loss; 52 | idx = j; 53 | } 54 | } 55 | threshold_ptr[i] = kernel_mutable_cpu_data[idx]; 56 | } 57 | 58 | TernaryWeightQuant<Dtype><<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>( 59 | nthreads, weight_dim, weight, threshold_.gpu_data(), ternary_weights_.mutable_gpu_data()); 60 | 61 | const int output_channel_num = this->num_output_; 62 | const int kernel_dim = this->kernel_dim_; 63 | // alpha_i = sum_j(w_ij * w_hat_ij) / sum_j(w_hat_ij^2): both elementwise products are row-summed by gemv against the ones vector, with the diff buffers as scratch. 64 | caffe_gpu_mul(output_channel_num*kernel_dim, weight, ternary_weights_.gpu_data(), 65 | ternary_weights_.mutable_gpu_diff()); 66 | caffe_gpu_gemv(CblasNoTrans, output_channel_num, kernel_dim, (Dtype)1., 67 | ternary_weights_.gpu_diff(), weight_sum_multiplier_.gpu_data(), 68 | (Dtype)0., alphas_.mutable_gpu_data()); 69 | caffe_gpu_mul(output_channel_num*kernel_dim, ternary_weights_.gpu_data(), 70 | ternary_weights_.gpu_data(), ternary_weights_.mutable_gpu_diff()); 71 | caffe_gpu_gemv(CblasNoTrans, output_channel_num, kernel_dim, 72 | (Dtype)1., ternary_weights_.gpu_diff(), weight_sum_multiplier_.gpu_data(), 73 | (Dtype)0., alphas_.mutable_gpu_diff()); 74 | caffe_gpu_div(output_channel_num, alphas_.gpu_data(), alphas_.gpu_diff(), alphas_.mutable_gpu_data()); 75 | 76 | TernaryWeightForward<Dtype><<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>( 77 | nthreads, weight_dim, ternary_weights_.gpu_data(), alphas_.gpu_data(), ternary_weights_.mutable_gpu_data()); 78 | } 79 | skip_quantization_ = this->phase_ == TEST; 80 | // NOTE(review): compute_output_shape() resets skip_quantization_ to false, so this TEST-phase cache only takes effect when no reshape happens between forward passes. 81 | const Dtype* ternary_weights = ternary_weights_.gpu_data(); 82 | for (int i = 0; i < bottom.size(); ++i) { 83 | const Dtype* bottom_data = bottom[i]->gpu_data(); 84 | Dtype* top_data = top[i]->mutable_gpu_data(); 85 | for (int n = 0; n < this->num_; ++n) { 86 | this->forward_gpu_gemm(bottom_data + n * this->bottom_dim_, ternary_weights, 87 | top_data + n * this->top_dim_); 88 | if (this->bias_term_) { 89 | const Dtype* bias = this->blobs_[1]->gpu_data(); 90 | this->forward_gpu_bias(top_data + n * this->top_dim_, bias); 91 | } 92 | } 93 | } 94 | } 95 | // Weight gradients accumulate into the diff of the full-precision blobs_[0]; bottom gradients are computed with ternary_weights_ from the preceding forward pass. 96 | template <typename Dtype> 97 | void TernaryConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, 98 | const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { 99 | const Dtype* ternary_weights = ternary_weights_.gpu_data(); 100 | Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); 101 | for (int i = 0; i < top.size(); ++i) { 102 | const Dtype* top_diff = top[i]->gpu_diff(); 103 | // Bias gradient, if necessary. 104 | if (this->bias_term_ && this->param_propagate_down_[1]) { 105 | Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff(); 106 | for (int n = 0; n < this->num_; ++n) { 107 | this->backward_gpu_bias(bias_diff, top_diff + n * this->top_dim_); 108 | } 109 | } 110 | if (this->param_propagate_down_[0] || propagate_down[i]) { 111 | const Dtype* bottom_data = bottom[i]->gpu_data(); 112 | Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); 113 | for (int n = 0; n < this->num_; ++n) { 114 | // gradient w.r.t. weight. Note that we will accumulate diffs. 115 | if (this->param_propagate_down_[0]) { 116 | this->weight_gpu_gemm(bottom_data + n * this->bottom_dim_, 117 | top_diff + n * this->top_dim_, weight_diff); 118 | } 119 | // gradient w.r.t. bottom data, if necessary (uses the ternary weights). 120 | if (propagate_down[i]) { 121 | this->backward_gpu_gemm(top_diff + n * this->top_dim_, ternary_weights, 122 | bottom_diff + n * this->bottom_dim_); 123 | } 124 | } 125 | } 126 | } 127 | } 128 | 129 | INSTANTIATE_LAYER_GPU_FUNCS(TernaryConvolutionLayer); 130 | 131 | } // namespace caffe 132 | -------------------------------------------------------------------------------- /ternary_conv_layer.hpp: -------------------------------------------------------------------------------- 1 | #ifndef CAFFE_TERNARY_CONV_LAYER_HPP_ 2 | #define CAFFE_TERNARY_CONV_LAYER_HPP_ 3 | 4 | #include <vector> 5 | 6 | #include "caffe/blob.hpp" 7 | #include "caffe/layer.hpp" 8 | #include "caffe/proto/caffe.pb.h" 9 | 10 | #include "caffe/layers/base_conv_layer.hpp" 11 | 12 | namespace caffe { 13 | // Convolution whose GPU forward pass (and GPU bottom gradient) uses a per-output-channel ternary approximation alpha * w_hat of blobs_[0]. 14 | template <typename Dtype> 15 | class TernaryConvolutionLayer : public BaseConvolutionLayer<Dtype> { 16 | public: 17 | 18 | explicit TernaryConvolutionLayer(const LayerParameter& param) 19 | : BaseConvolutionLayer<Dtype>(param), skip_quantization_(false) {} 20 | 21 | virtual inline const char* type() const { return "TernaryConvolution"; } 22 | 23 | protected: 24 | virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, 25 | const vector<Blob<Dtype>*>& top); 26 | virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, 27 | const vector<Blob<Dtype>*>& top); 28 | virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, 29 | const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); 30 | virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, 31 | const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); 32 | virtual inline bool reverse_dimensions() { return false; } 33 | virtual void compute_output_shape(); 34 | // ternary_weights_: alpha * w_hat, same shape as blobs_[0]; alphas_/threshold_: one value per output channel; weight_sum_multiplier_: ones vector for per-row sums. 35 | Blob<Dtype> ternary_weights_; 36 | Blob<Dtype> alphas_; 37 | Blob<Dtype> weight_sum_multiplier_; 38 | Blob<Dtype> threshold_; 39 | // When true, Forward_gpu reuses the previously quantized ternary_weights_. 40 | bool skip_quantization_; 41 | 42 | }; 43 | 44 | } // namespace caffe 45 | 46 | #endif // CAFFE_TERNARY_CONV_LAYER_HPP_ 47 | -------------------------------------------------------------------------------- /ternary_inner_product_layer.cpp:
-------------------------------------------------------------------------------- 1 | #include <vector> 2 | 3 | #include "caffe/filler.hpp" 4 | #include "caffe/layers/ternary_inner_product_layer.hpp" 5 | #include "caffe/util/math_functions.hpp" 6 | 7 | namespace caffe { 8 | // Reads inner_product_param (num_output, bias_term, transpose, axis) and, unless blobs already exist, creates and fills the weights (N_ x K_, or K_ x N_ when transpose_) and the optional bias. 9 | template <typename Dtype> 10 | void TernaryInnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, 11 | const vector<Blob<Dtype>*>& top) { 12 | const int num_output = this->layer_param_.inner_product_param().num_output(); 13 | bias_term_ = this->layer_param_.inner_product_param().bias_term(); 14 | transpose_ = this->layer_param_.inner_product_param().transpose(); 15 | N_ = num_output; 16 | const int axis = bottom[0]->CanonicalAxisIndex( 17 | this->layer_param_.inner_product_param().axis()); 18 | // Dimensions starting from "axis" are "flattened" into a single 19 | // length K_ vector. For example, if bottom[0]'s shape is (N, C, H, W), 20 | // and axis == 1, N inner products with dimension CHW are performed. 21 | K_ = bottom[0]->count(axis); 22 | // Check if we need to set up the weights 23 | if (this->blobs_.size() > 0) { 24 | LOG(INFO) << "Skipping parameter initialization"; 25 | } else { 26 | if (bias_term_) { 27 | this->blobs_.resize(2); 28 | } else { 29 | this->blobs_.resize(1); 30 | } 31 | // Initialize the weights 32 | vector<int> weight_shape(2); 33 | if (transpose_) { 34 | weight_shape[0] = K_; 35 | weight_shape[1] = N_; 36 | } else { 37 | weight_shape[0] = N_; 38 | weight_shape[1] = K_; 39 | } 40 | this->blobs_[0].reset(new Blob<Dtype>(weight_shape)); 41 | // fill the weights 42 | shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>( 43 | this->layer_param_.inner_product_param().weight_filler())); 44 | weight_filler->Fill(this->blobs_[0].get()); 45 | // If necessary, initialize and fill the bias term 46 | if (bias_term_) { 47 | vector<int> bias_shape(1, N_); 48 | this->blobs_[1].reset(new Blob<Dtype>(bias_shape)); 49 | shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>( 50 | this->layer_param_.inner_product_param().bias_filler())); 51 | bias_filler->Fill(this->blobs_[1].get());
52 | } 53 | } // parameter initialization 54 | this->param_propagate_down_.resize(this->blobs_.size(), true); 55 | } 56 | // Computes M_ and the top shape, sets up the bias multiplier, and sizes the quantization scratch blobs (one alpha/threshold per row of blobs_[0]). 57 | template <typename Dtype> 58 | void TernaryInnerProductLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, 59 | const vector<Blob<Dtype>*>& top) { 60 | // Figure out the dimensions 61 | const int axis = bottom[0]->CanonicalAxisIndex( 62 | this->layer_param_.inner_product_param().axis()); 63 | const int new_K = bottom[0]->count(axis); 64 | CHECK_EQ(K_, new_K) 65 | << "Input size incompatible with inner product parameters."; 66 | // The first "axis" dimensions are independent inner products; the total 67 | // number of these is M_, the product over these dimensions. 68 | M_ = bottom[0]->count(0, axis); 69 | // The top shape will be the bottom shape with the flattened axes dropped, 70 | // and replaced by a single axis with dimension num_output (N_). 71 | vector<int> top_shape = bottom[0]->shape(); 72 | top_shape.resize(axis + 1); 73 | top_shape[axis] = N_; 74 | top[0]->Reshape(top_shape); 75 | // Set up the bias multiplier 76 | if (bias_term_) { 77 | vector<int> bias_shape(1, M_); 78 | bias_multiplier_.Reshape(bias_shape); 79 | caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data()); 80 | } 81 | // NOTE(review): rows of blobs_[0] are outputs (N_) normally but inputs (K_) when transpose_ is set, so the per-row alphas/thresholds are then per input feature. 82 | ternary_weights_.ReshapeLike(*this->blobs_[0]); 83 | alphas_.Reshape(this->blobs_[0]->num(),1,1,1); 84 | weight_sum_multiplier_.Reshape(this->blobs_[0]->count(1),1,1,1); 85 | threshold_.Reshape(this->blobs_[0]->num(),1,1,1); 86 | // Reset on every reshape, so Forward_gpu re-quantizes the weights after each reshape, even in TEST phase. 87 | skip_quantization_ = false; 88 | } 89 | // NOTE(review): this CPU path multiplies by the full-precision blobs_[0], not the ternary weights used by Forward_gpu, so CPU and GPU outputs differ. 90 | template <typename Dtype> 91 | void TernaryInnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, 92 | const vector<Blob<Dtype>*>& top) { 93 | const Dtype* bottom_data = bottom[0]->cpu_data(); 94 | Dtype* top_data = top[0]->mutable_cpu_data(); 95 | const Dtype* weight = this->blobs_[0]->cpu_data(); 96 | caffe_cpu_gemm(CblasNoTrans, transpose_ ? CblasNoTrans : CblasTrans,
97 | M_, N_, K_, (Dtype)1., 98 | bottom_data, weight, (Dtype)0., top_data); 99 | if (bias_term_) { 100 | caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1., 101 | bias_multiplier_.cpu_data(), 102 | this->blobs_[1]->cpu_data(), (Dtype)1., top_data); 103 | } 104 | } 105 | // CPU backward: weight/bias gradients accumulate into the blob diffs; the bottom gradient uses the full-precision blobs_[0]. 106 | template <typename Dtype> 107 | void TernaryInnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, 108 | const vector<bool>& propagate_down, 109 | const vector<Blob<Dtype>*>& bottom) { 110 | if (this->param_propagate_down_[0]) { 111 | const Dtype* top_diff = top[0]->cpu_diff(); 112 | const Dtype* bottom_data = bottom[0]->cpu_data(); 113 | // Gradient with respect to weight 114 | if (transpose_) { 115 | caffe_cpu_gemm(CblasTrans, CblasNoTrans, 116 | K_, N_, M_, 117 | (Dtype)1., bottom_data, top_diff, 118 | (Dtype)1., this->blobs_[0]->mutable_cpu_diff()); 119 | } else { 120 | caffe_cpu_gemm(CblasTrans, CblasNoTrans, 121 | N_, K_, M_, 122 | (Dtype)1., top_diff, bottom_data, 123 | (Dtype)1., this->blobs_[0]->mutable_cpu_diff()); 124 | } 125 | } 126 | if (bias_term_ && this->param_propagate_down_[1]) { 127 | const Dtype* top_diff = top[0]->cpu_diff(); 128 | // Gradient with respect to bias 129 | caffe_cpu_gemv(CblasTrans, M_, N_, (Dtype)1., top_diff, 130 | bias_multiplier_.cpu_data(), (Dtype)1., 131 | this->blobs_[1]->mutable_cpu_diff()); 132 | } 133 | if (propagate_down[0]) { 134 | const Dtype* top_diff = top[0]->cpu_diff(); 135 | // Gradient with respect to bottom data 136 | if (transpose_) { 137 | caffe_cpu_gemm(CblasNoTrans, CblasTrans, 138 | M_, K_, N_, 139 | (Dtype)1., top_diff, this->blobs_[0]->cpu_data(), 140 | (Dtype)0., bottom[0]->mutable_cpu_diff()); 141 | } else { 142 | caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, 143 | M_, K_, N_, 144 | (Dtype)1., top_diff, this->blobs_[0]->cpu_data(), 145 | (Dtype)0., bottom[0]->mutable_cpu_diff()); 146 | } 147 | } 148 | } 149 | 150 | #ifdef CPU_ONLY 151 | STUB_GPU(TernaryInnerProductLayer); 152 | #endif 153 | 154 |
INSTANTIATE_CLASS(TernaryInnerProductLayer); 155 | REGISTER_LAYER_CLASS(TernaryInnerProduct); 156 | 157 | } // namespace caffe 158 | -------------------------------------------------------------------------------- /ternary_inner_product_layer.cu: -------------------------------------------------------------------------------- 1 | #include <algorithm> 2 | #include <vector> 3 | #include "caffe/filler.hpp" 4 | #include "caffe/layers/ternary_inner_product_layer.hpp" 5 | #include "caffe/util/math_functions.hpp" 6 | 7 | namespace caffe { 8 | // Elementwise ternarization: code is +1 for weight > 0 and -1 otherwise, kept where |weight| >= threshold[row] and set to 0 elsewhere; row = index / weight_dim. 9 | template <typename Dtype> 10 | __global__ void TernaryWeightQuant(const int n, const int weight_dim, const Dtype* weight, 11 | const Dtype* threshold, Dtype* ternary_weight) { 12 | CUDA_KERNEL_LOOP(index, n) { 13 | int i = index/weight_dim; 14 | Dtype ternary_code = weight[index] > Dtype(0) ? Dtype(1) : Dtype(-1); 15 | ternary_weight[index] = fabs(weight[index]) >= threshold[i] ? ternary_code : Dtype(0); 16 | } 17 | } 18 | // Elementwise scaling: ternary_weight = weight * alpha[row]; row = index / weight_dim. Each element is read and written at the same index, so it may run in place. 19 | template <typename Dtype> 20 | __global__ void TernaryWeightForward(const int n, const int weight_dim, const Dtype* weight, 21 | const Dtype* alpha, Dtype* ternary_weight) { 22 | CUDA_KERNEL_LOOP(index, n) { 23 | int i = index/weight_dim; 24 | ternary_weight[index] = weight[index] * alpha[i]; 25 | } 26 | } 27 | // Ternarize blobs_[0] per row into ternary_weights_ (= alpha * w_hat, w_hat in {-1,0,+1}), then compute top = bottom * W^T (+ bias) with it. 28 | template <typename Dtype> 29 | void TernaryInnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, 30 | const vector<Blob<Dtype>*>& top) { 31 | const Dtype* bottom_data = bottom[0]->gpu_data(); 32 | Dtype* top_data = top[0]->mutable_gpu_data(); 33 | const Dtype* weight = this->blobs_[0]->gpu_data(); 34 | const int weight_dim = this->blobs_[0]->count(1); 35 | 36 | if (skip_quantization_ == false) { 37 | caffe_gpu_abs(this->blobs_[0]->count(), weight, ternary_weights_.mutable_gpu_data()); 38 | caffe_gpu_set(weight_sum_multiplier_.count(),Dtype(1),weight_sum_multiplier_.mutable_gpu_data()); 39 | const int nthreads = this->blobs_[0]->count(); 40 | Dtype* threshold_ptr = threshold_.mutable_cpu_data(); 41 | // Host-side threshold search: sort each row's |w| ascending, then keep the top-r magnitudes maximizing (sum of kept |w|)^2 / r (ties favor keeping more). 42 | for (int i = 0; i < this->blobs_[0]->num(); i++) { 43 | Dtype* kernel_mutable_cpu_data = ternary_weights_.mutable_cpu_data()+i*weight_dim; 44 | std::sort(kernel_mutable_cpu_data, kernel_mutable_cpu_data+weight_dim); 45 | int r = 0; 46 | Dtype s = 0; 47 | 48 | Dtype loss_max = Dtype(1e-5); 49 | int idx = weight_dim > 1 ? 1 : 0; // fallback when no candidate reaches loss_max; a fixed 1 would read past a single-element row 50 | for (int j = weight_dim-1; j >=0; j--) { 51 | s += kernel_mutable_cpu_data[j]; r++; 52 | const Dtype loss = s*s/r; 53 | if (loss >= loss_max) { 54 | loss_max = loss; 55 | idx = j; 56 | } 57 | } 58 | threshold_ptr[i] = kernel_mutable_cpu_data[idx]; 59 | } 60 | 61 | TernaryWeightQuant<Dtype><<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>( 62 | nthreads, weight_dim, weight, threshold_.gpu_data(), ternary_weights_.mutable_gpu_data()); 63 | 64 | const int output_channel_num = this->blobs_[0]->num(); 65 | const int kernel_dim = this->blobs_[0]->count(1); 66 | // alpha_i = sum_j(w_ij * w_hat_ij) / sum_j(w_hat_ij^2): both elementwise products are row-summed by gemv against the ones vector, with the diff buffers as scratch. 67 | caffe_gpu_mul(output_channel_num*kernel_dim, weight, ternary_weights_.gpu_data(), 68 | ternary_weights_.mutable_gpu_diff()); 69 | caffe_gpu_gemv(CblasNoTrans, output_channel_num, kernel_dim, (Dtype)1., 70 | ternary_weights_.gpu_diff(), weight_sum_multiplier_.gpu_data(), 71 | (Dtype)0., alphas_.mutable_gpu_data()); 72 | caffe_gpu_mul(output_channel_num*kernel_dim, ternary_weights_.gpu_data(), 73 | ternary_weights_.gpu_data(), ternary_weights_.mutable_gpu_diff()); 74 | caffe_gpu_gemv(CblasNoTrans, output_channel_num, kernel_dim, 75 | (Dtype)1., ternary_weights_.gpu_diff(), weight_sum_multiplier_.gpu_data(), 76 | (Dtype)0., alphas_.mutable_gpu_diff()); 77 | caffe_gpu_div(output_channel_num, alphas_.gpu_data(), alphas_.gpu_diff(), alphas_.mutable_gpu_data()); 78 | 79 | TernaryWeightForward<Dtype><<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>( 80 | nthreads, weight_dim, ternary_weights_.gpu_data(), alphas_.gpu_data(), ternary_weights_.mutable_gpu_data()); 81 | } 82 | skip_quantization_ = this->phase_ == TEST; 83 | // NOTE(review): Reshape() resets skip_quantization_ to false, so this TEST-phase cache only takes effect when no reshape happens between forward passes. 84 | const Dtype* ternary_weights = ternary_weights_.gpu_data(); 85 | // Single-row fast path uses gemv; like the gemm path below, it must honor the K_ x N_ weight layout when transpose_ is set. 86 | if (M_ == 1) { 87 | if (transpose_) { caffe_gpu_gemv(CblasTrans, K_, N_, (Dtype)1., ternary_weights, bottom_data, (Dtype)0., top_data); } else { caffe_gpu_gemv(CblasNoTrans, N_, K_, (Dtype)1., 88 | ternary_weights, bottom_data, (Dtype)0., top_data); } 89 | if (bias_term_) 90 | caffe_gpu_axpy(N_, bias_multiplier_.cpu_data()[0], 91 | this->blobs_[1]->gpu_data(), top_data); 92 | } else { 93 | caffe_gpu_gemm(CblasNoTrans, 94 | transpose_ ? CblasNoTrans : CblasTrans, 95 | M_, N_, K_, (Dtype)1., 96 | bottom_data, ternary_weights, (Dtype)0., top_data); 97 | if (bias_term_) 98 | caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1., 99 | bias_multiplier_.gpu_data(), 100 | this->blobs_[1]->gpu_data(), (Dtype)1., top_data); 101 | } 102 | } 103 | // Weight/bias gradients accumulate into the full-precision blob diffs; the bottom gradient uses ternary_weights_ from the preceding forward pass. 104 | template <typename Dtype> 105 | void TernaryInnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top, 106 | const vector<bool>& propagate_down, 107 | const vector<Blob<Dtype>*>& bottom) { 108 | if (this->param_propagate_down_[0]) { 109 | const Dtype* top_diff = top[0]->gpu_diff(); 110 | const Dtype* bottom_data = bottom[0]->gpu_data(); 111 | // Gradient with respect to weight 112 | if (transpose_) { 113 | caffe_gpu_gemm(CblasTrans, CblasNoTrans, 114 | K_, N_, M_, 115 | (Dtype)1., bottom_data, top_diff, 116 | (Dtype)1., this->blobs_[0]->mutable_gpu_diff()); 117 | } else { 118 | caffe_gpu_gemm(CblasTrans, CblasNoTrans, 119 | N_, K_, M_, 120 | (Dtype)1., top_diff, bottom_data, 121 | (Dtype)1., this->blobs_[0]->mutable_gpu_diff()); 122 | } 123 | } 124 | if (bias_term_ && this->param_propagate_down_[1]) { 125 | const Dtype* top_diff = top[0]->gpu_diff(); 126 | // Gradient with respect to bias 127 | caffe_gpu_gemv(CblasTrans, M_, N_, (Dtype)1., top_diff, 128 | bias_multiplier_.gpu_data(), (Dtype)1., 129 | this->blobs_[1]->mutable_gpu_diff()); 130 | } 131 | if (propagate_down[0]) { 132 | const Dtype* top_diff = top[0]->gpu_diff(); 133 | // Gradient with respect to bottom data (uses the ternary weights) 134 | if (transpose_) { 135 | caffe_gpu_gemm(CblasNoTrans, CblasTrans, 136 | M_, K_, N_, 137 | (Dtype)1., top_diff, ternary_weights_.gpu_data(), 138 | (Dtype)0., bottom[0]->mutable_gpu_diff()); 139 | } else { 140 | caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, 141 | M_, K_, N_, 142 | (Dtype)1., top_diff, ternary_weights_.gpu_data(), 143 | (Dtype)0., bottom[0]->mutable_gpu_diff()); 144 | } 145 | } 146 | } 147 | 148 | INSTANTIATE_LAYER_GPU_FUNCS(TernaryInnerProductLayer); 149 | 150 | } // namespace caffe 151 | -------------------------------------------------------------------------------- /ternary_inner_product_layer.hpp: -------------------------------------------------------------------------------- 1 | #ifndef CAFFE_TERNARY_INNER_PRODUCT_LAYER_HPP_ 2 | #define CAFFE_TERNARY_INNER_PRODUCT_LAYER_HPP_ 3 | 4 | #include <vector> 5 | 6 | #include "caffe/blob.hpp" 7 | #include "caffe/layer.hpp" 8 | #include "caffe/proto/caffe.pb.h" 9 | 10 | namespace caffe { 11 | // Inner-product layer whose GPU forward pass (and GPU bottom gradient) uses a per-row ternary approximation alpha * w_hat of blobs_[0]. 12 | template <typename Dtype> 13 | class TernaryInnerProductLayer : public Layer<Dtype> { 14 | public: 15 | explicit TernaryInnerProductLayer(const LayerParameter& param) 16 | : Layer<Dtype>(param), skip_quantization_(false) {} 17 | virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, 18 | const vector<Blob<Dtype>*>& top); 19 | virtual void Reshape(const vector<Blob<Dtype>*>& bottom, 20 | const vector<Blob<Dtype>*>& top); 21 | 22 | virtual inline const char* type() const { return "TernaryInnerProduct"; } 23 | virtual inline int ExactNumBottomBlobs() const { return 1; } 24 | virtual inline int ExactNumTopBlobs() const { return 1; } 25 | 26 | protected: 27 | virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, 28 | const vector<Blob<Dtype>*>& top); 29 | virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, 30 | const vector<Blob<Dtype>*>& top); 31 | virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, 32 | const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); 33 | virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, 34 | const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); 35 | 36 | int M_; 37 | int K_; 38 | int N_; 39 | bool bias_term_; 40 | Blob<Dtype> bias_multiplier_; 41 | bool transpose_; ///< if true, assume transposed weights 42 | 43 | // Scratch state for ternary quantization: alpha * w_hat, per-row alphas and thresholds, and a ones vector for per-row sums. 44 | Blob<Dtype> ternary_weights_; 45 | Blob<Dtype> alphas_; 46 | Blob<Dtype> threshold_; 47 | Blob<Dtype> weight_sum_multiplier_; 48 | // When true, Forward_gpu reuses the previously quantized ternary_weights_. 49 | bool skip_quantization_; 50 | }; 51 | 52 | } // namespace caffe 53 | 54 | #endif // CAFFE_TERNARY_INNER_PRODUCT_LAYER_HPP_ 55 | --------------------------------------------------------------------------------