├── main.cpp ├── README.md ├── sys_test_config.xml ├── Pruner.h └── Pruner.cpp /main.cpp: -------------------------------------------------------------------------------- 1 | #include "Pruner.h" 2 | #include 3 | void main(){ 4 | std::string xml_path = "D:\\MINE\\c\\compression\\compression\\compress\\sys_test_config.xml"; 5 | Pruner t = Pruner(xml_path); 6 | t.start(); 7 | /*int i = 7 / 2; 8 | std::cout << i;*/ 9 | system("pause"); 10 | 11 | } 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Caffemodel_Compress# <> 2 | 3 | Song Han, Huizi Mao, William J. Dally 4 | (Submitted on 1 Oct 2015 (v1), last revised 15 Feb 2016 (this version, v5)) 5 | 6 | This C++ project implements CNN channel pruning, an idea I've been considering for a while. You can prune channels in your trained Caffe model using this tool. First, ensure you have the XML file relocated, regardless of whether you’re working with depthwise or convolutional layers. When setting up the layers for pruning, the corresponding channel pruning arrangement is implicitly configured. You’ll also need to make a minor change in main.cpp and build the project with your local IDE. 7 | 8 | Updated 18/6/25: Eltwise pruning on Mac is now supported. 9 | -------------------------------------------------------------------------------- /sys_test_config.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 0 6 | 7 | 8 | 9 | 10 | 11 | 11 12 | 13 | 14 | 17 | 18 | 19 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | D:/MINE/c/compression/compression/car_type/_iter_10000.caffemodel 45 | D:/MINE/c/compression/compression/car_type/ResNet101_test.prototxt 46 | D:/MINE/c/compression/compression/car_type/rate10.caffemodel 47 | D:/MINE/c/compression/compression/car_type/rate10.prototxt 48 | D:/MINE/c/compression/compression/car_type/rate10.txt 49 | -------------------------------------------------------------------------------- /Pruner.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "caffe/proto/caffe.pb.h" 10 | #include "caffe/blob.hpp" 11 | 12 | typedef google::protobuf::int64 int_64; 13 | typedef std::pair atom; 14 | typedef atom* Patom; 15 | 16 | 17 | 18 | class Utility{ 19 | public: 20 | Utility() = default; 21 | ~Utility(){}; 22 | inline std::string doubleToString(double num) 23 | { 24 | char str[256]; 25 | sprintf(str, "%lf", num); 26 | std::string result = str; 27 | return result; 28 | }; 29 | inline std::string intToString(int_64 i){ 30 | std::stringstream stream; 31 | stream << i; 32 | return stream.str(); 33 | }; 34 | inline std::vector split(const std::string& str, const std::string& devide){ 35 | std::vector res; 36 | if ("" == str) return res; 37 | char * strs = new char[str.length() + 1]; 38 | strcpy(strs, str.c_str()); 39 | char * d = new char[devide.length() + 1]; 40 | strcpy(d, devide.c_str()); 41 | char *p = strtok(strs, d); 42 | while (p) { 43 | std::string s = p; 44 | res.push_back(s); 45 | p = strtok(NULL, d); 46 | } 47 | return res; 48 | }; 49 | void hS(std::vector* a, int l, int r); 50 | void fixUp(std::vector* a, int k); 51 | void fixDown(std::vector* a, int k, int N); 52 | }; 53 | 54 | class Pruner 55 | { 56 | public: 57 | #define _RELU_ "ReLU" 58 | #define _PRELU_ "PReLU" 59 | #define _SIGMOID_ "Sigmoid" 60 | #define _TANH_ "Tanh" 61 | #define _CONVOLUTION_ "Convolution" 62 | #define _POOLING_ "Pooling" 63 | 64 | typedef std::pair param; 65 | typedef std::pair convParam; 66 | typedef std::vector convParams; 67 | typedef std::pair record; 68 | typedef std::pair eltwiseRecord; 69 | typedef record* precord; 70 | typedef eltwiseRecord* peltwiserecord; 71 | typedef ::google::protobuf::RepeatedField< double > caffe_double_; 72 | typedef const ::google::protobuf::RepeatedField< double >& caffe_double_data_; 73 | 74 | Pruner() = default; 75 | Pruner(const Pruner&); 76 | Pruner(const std::string xml_path); 77 | Pruner& operator=(const Pruner&); 78 | void start(void); 79 | void read_XML(const std::string xml_path); 80 | void import(void); 81 | inline void pruning(void){ 82 | switch (pruningMode){ 83 | case ratio: 84 | pruningByratio(); 85 | break; 86 | case size: 87 | pruningBySize(); 88 | break; 89 | default: 90 | break; 91 | } 92 | }; 93 | inline bool isNonLinear(std::string layerType){ 94 | return layerType == _RELU_ || layerType == _PRELU_ || layerType == _SIGMOID_ || layerType == _TANH_ ? true : false; 95 | } 96 | std::pair, std::vector> eltwiseTravel(const std::string eltwiseName); 97 | std::vectorfindUpChannels(const std::vector* eltwiseLayers, const std::vector* splitLayers); 98 | std::vectorfindUpFilters(const std::vector* eltwiseLayers, const std::vector* splitLayers); 99 | std::string findDown(const std::string layerName, std::vector* eltwiseLayers, std::vector* splitLayers); 100 | std::string findUp(const std::string layerName, std::vector* eltwiseLayers, std::vector* splitLayers); 101 | bool CheckIsEltwiseFilter(const std::string layerName); 102 | bool CheckIsEltwiseChannel(const std::string layerName); 103 | void eltwiseCaculate(const peltwiserecord r, std::vector* channelNeedPrune); 104 | bool checkIsConv(const std::string layerName); 105 | std::string hasBottom(const std::string layerName); 106 | std::string hasTop(const std::string layerName); 107 | void pruningByratio(void); 108 | void pruningEltwiseByratio(\ 109 | const peltwiserecord r, \ 110 | std::vector* channelNeedPrune); \ 111 | void pruningConvByratio(\ 112 | const precord r, \ 113 | std::vector* channelNeedPrune); \ 114 | void pruningBottomByratio(\ 115 | const precord r, \ 116 | std::vector* channelNeedPrune); \ 117 | int writePrototxt(\ 118 | const std::string prototxt1, \ 119 | const std::string prototxt2); \ 120 | 121 | void filterPruning(\ 122 | ::google::protobuf::RepeatedPtrField< caffe::LayerParameter >::iterator iter_, \ 123 | std::vector* channelNeedPrune) const; \ 124 | void channelPruning(\ 125 | ::google::protobuf::RepeatedPtrField< caffe::LayerParameter >::iterator iter_, \ 126 | std::vector* channelNeedPrune) const; \ 127 | void pruningBySize(); 128 | void writeModel(); 129 | virtual ~Pruner(){}; 130 | 131 | 132 | 133 | private: 134 | std::string xml_Path; 135 | std::string pruning_caffemodel_path; 136 | std::string pruning_proto_path; 137 | std::string pruned_caffemodel_path; 138 | std::string pruned_proto_path; 139 | std::string txt_proto_path; 140 | 141 | 142 | 143 | const enum ConvCalculateMode 144 | { 145 | Norm = 8, L1 = 11, L2 = 12,Variance = 16 146 | }; 147 | const enum PruningMode 148 | { 149 | ratio = 0, size = 1 150 | }; 151 | int convCalculateMode; 152 | int pruningMode; 153 | boost::shared_ptr utility_; 154 | 155 | std::vector pruning_ratio; 156 | boost::property_tree::ptree configure; 157 | caffe::NetParameter proto; 158 | std::vector conv; 159 | std::vector eltwiseConv; 160 | convParams convNeedRewriteOnPrototxt; 161 | ::google::protobuf::RepeatedPtrField< caffe::LayerParameter >* layer; 162 | mutable ::google::protobuf::RepeatedPtrField< caffe::LayerParameter >::iterator it; 163 | }; 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /Pruner.cpp: -------------------------------------------------------------------------------- 1 | #include "Pruner.h" 2 | #include "caffe/util/io.hpp" 3 | #include 4 | #include 5 | #include 6 | #include "caffe/util/math_functions.hpp" 7 | 8 | using namespace caffe; 9 | using namespace std; 10 | 11 | 12 | Pruner::Pruner(const Pruner &p) : 13 | xml_Path(p.xml_Path) 14 | { } 15 | 16 | Pruner& Pruner::operator=(const Pruner &rhs){ 17 | xml_Path = rhs.xml_Path; 18 | return *this; 19 | } 20 | 21 | Pruner::Pruner(const string xml_path){ 22 | xml_Path = xml_path; 23 | utility_ = boost::shared_ptr (new Utility()); 24 | } 25 | 26 | 27 | void Pruner::start(){ 28 | read_XML(xml_Path); 29 | import(); 30 | pruning(); 31 | writePrototxt(pruning_proto_path, pruned_proto_path); 32 | writeModel(); 33 | 34 | } 35 | 36 | void Pruner::read_XML(const string xml_path){ 37 | read_xml(xml_path, configure); 38 | pruning_caffemodel_path = configure.get("caffemodelpath"); 39 | pruning_proto_path = configure.get("protopath"); 40 | pruned_caffemodel_path = configure.get("prunedcaffemodelpath"); 41 | pruned_proto_path = configure.get("prunedprotopath"); 42 | txt_proto_path = configure.get("txtprotopath"); 43 | pruningMode = atoi(configure.get("PruningMode.mode").c_str()); 44 | convCalculateMode = atoi(configure.get("ConvCalculateMode.mode").c_str()); 45 | 46 | ReadProtoFromBinaryFile(pruning_caffemodel_path, &proto); 47 | layer = proto.mutable_layer(); 48 | //importing vanilla convolution layers'parameters from xml 49 | boost::property_tree::ptree layers = configure.get_child("filterpruning"); 50 | for (auto it1 = layers.begin(); it1 != layers.end(); it1++){ 51 | auto clayers = it1->second; 52 | if (clayers.empty()){ 53 | continue; 54 | } 55 | string name = clayers.get(".name"); 56 | if (!checkIsConv(name)){ 57 | cout << "input incorrect: " + name + " is not a convolution layer , ignoring..." << endl; 58 | Sleep(1000); 59 | continue; 60 | } 61 | double ratio = atof(clayers.get(".cut").c_str()); 62 | for (it = layer->begin(); it != layer->end(); it++){ 63 | if (name == it->name()){ 64 | pruning_ratio.push_back(convParam(name, param(ratio, it->blobs(0).shape().dim(0)))); 65 | convNeedRewriteOnPrototxt.push_back(convParam(it->name(), param(ratio, it->blobs(0).shape().dim(0)))); 66 | break; 67 | } 68 | } 69 | } 70 | 71 | //Read Eltwise Prop from XML file 72 | layers = configure.get_child("eltwise"); 73 | for (auto it1 = layers.begin(); it1 != layers.end(); it1++){ 74 | convParams conv_channels_temps; 75 | convParams conv_filters_temps; 76 | auto clayers = it1->second; 77 | if (clayers.empty()){ 78 | continue; 79 | } 80 | string name = clayers.get(".name"); 81 | double ratio = atof(clayers.get(".cut").c_str()); 82 | 83 | //Getting eltwise layers 'parameters 84 | //Getting channels'paramters 85 | pair, vector> pair = eltwiseTravel(name); 86 | vector channels = pair.first; 87 | vector filters = pair.second; 88 | for (auto it2 = channels.begin(); it2 != channels.end(); it2++){ 89 | convParam conv_temp; 90 | for (it = layer->begin(); it != layer->end(); it++){ 91 | if (*it2 == it->name()){ 92 | conv_temp.first = *it2; 93 | conv_temp.second.first = ratio; 94 | conv_temp.second.second = it->blobs(0).shape().dim(0); 95 | conv_channels_temps.push_back(conv_temp); 96 | break; 97 | } 98 | } 99 | } 100 | //Getting filters' parameters 101 | for (auto it3 = filters.begin(); it3 != filters.end(); it3++){ 102 | convParam conv_temp; 103 | for (it = layer->begin(); it != layer->end(); it++){ 104 | if (*it3 == it->name()){ 105 | conv_temp.first = *it3; 106 | conv_temp.second.first = ratio; 107 | conv_temp.second.second = it->blobs(0).shape().dim(0); 108 | conv_filters_temps.push_back(conv_temp); 109 | convNeedRewriteOnPrototxt.push_back(convParam(it->name(), param(ratio, it->blobs(0).shape().dim(0)))); 110 | break; 111 | } 112 | } 113 | } 114 | eltwiseConv.push_back(make_pair(conv_channels_temps, conv_filters_temps)); 115 | } 116 | } 117 | 118 | void Pruner::import(){ 119 | /*for (it = layer->begin(); it->name() != "conv1_bn"; ++it); 120 | it->blobs(0).num(); 121 | string x1 = it->name(); 122 | int x = it->blobs(0).shape().dim(0);*/ 123 | 124 | auto iter1 = pruning_ratio.begin(); 125 | while (iter1 != pruning_ratio.end()){ 126 | it = layer->begin(); 127 | double ratio = iter1->second.first; 128 | convParams b1; 129 | string prunedConvName = iter1->first; 130 | string poolName = "konglusen"; 131 | for (; it != layer->end(); it++){ 132 | string n = it->name(); 133 | if (it->bottom_size() != 0){ 134 | for (int i = 0; i < it->bottom_size(); i++){ 135 | if (prunedConvName == it->bottom(i)){ 136 | 137 | if(it->type()=="ReLU"){ 138 | ++it; 139 | cout<name() <type() == "Convolution"){ 142 | if (prunedConvName == it->bottom(0)){ 143 | b1.push_back(convParam(it->name(), param(ratio, it->blobs(0).shape().dim(0)))); 144 | break; 145 | } 146 | } 147 | else if (it->type() == "ConvolutionDepthwise"){ 148 | 149 | if (prunedConvName == it->bottom(0)){ 150 | b1.push_back(convParam(it->name(), param(ratio, it->blobs(0).shape().dim(0)))); 151 | convNeedRewriteOnPrototxt.push_back(convParam(it->name(), param(ratio, it->blobs(0).shape().dim(0)))); 152 | break; 153 | } 154 | } 155 | 156 | else if (it->type() == "Pooling"){ 157 | it++; 158 | vector top_names_; 159 | if (it->type() == "Split"){ 160 | for (size_t i = 0; i < it->top_size(); i++){ 161 | top_names_.push_back(it->top(i)); 162 | } 163 | } 164 | else{ 165 | top_names_.push_back(it->top(i)); 166 | } 167 | 168 | for (auto it1 = it; it1 != layer->end(); it1++){ 169 | if (it1->type() == "Convolution" || it1->type() == "ConvolutionDepthwise"){ 170 | if (find(top_names_.begin(), top_names_.end(), it1->name()) != top_names_.end()){ 171 | b1.push_back(convParam(it1->name(), param(ratio, it1->blobs(0).shape().dim(0)))); 172 | } 173 | } 174 | } 175 | } 176 | } 177 | } 178 | } 179 | else{ 180 | continue; 181 | } 182 | 183 | } 184 | conv.push_back(record(*iter1, b1)); 185 | iter1++; 186 | } 187 | } 188 | 189 | void Pruner::pruningByratio(){ 190 | for (string::size_type i = 0; i < conv.size(); i++){ 191 | vector channelNeedPrune; 192 | pruningConvByratio(&conv.at(i), &channelNeedPrune); 193 | pruningBottomByratio(&conv.at(i), &channelNeedPrune); 194 | } 195 | for (string::size_type i = 0; i < eltwiseConv.size(); i++){ 196 | vector channelNeedPrune; 197 | eltwiseCaculate(&eltwiseConv.at(i), &channelNeedPrune); 198 | pruningEltwiseByratio(&eltwiseConv.at(i), &channelNeedPrune); 199 | } 200 | } 201 | 202 | void Pruner::pruningBySize(){ 203 | 204 | } 205 | 206 | void Utility::hS(vector* a, int l, int r){ 207 | int k; 208 | int N = r - l + 1; 209 | for (k = N / 2; k >= 1; k--){ 210 | fixDown(a, k, N); 211 | } 212 | while (N > 1){ 213 | swap(a->at(1), a->at(N)); 214 | fixDown(a, 1, --N); 215 | } 216 | } 217 | 218 | void Utility::fixUp(vector* a, int k){ 219 | while (k > 1 && ([a, k]() -> bool {return (a->at(k / 2).second) < (a->at(k).second); })()){ 220 | swap(a->at(k), a->at(k / 2)); 221 | k = k / 2; 222 | } 223 | } 224 | 225 | void Utility::fixDown(vector* a, int k, int N){ 226 | int j; 227 | while (2 * k <= N){ 228 | j = 2 * k; 229 | if (j < N && ([a, j]() -> bool {return (a->at(j).second) < (a->at(j + 1).second); })()){ 230 | j++; 231 | } 232 | if (([a, j, k]() -> bool {return (a->at(k).second) > (a->at(j).second); })()){ 233 | break; 234 | } 235 | swap(a->at(k), a->at(j)); 236 | k = j; 237 | } 238 | } 239 | 240 | void Pruner::writeModel(){ 241 | WriteProtoToTextFile(proto, txt_proto_path); 242 | WriteProtoToBinaryFile(proto, pruned_caffemodel_path); 243 | } 244 | 245 | bool Pruner::checkIsConv(const string layerName){ 246 | int count = 0; 247 | for (auto it1 = layer->begin(); it1 != layer->end(); it1++){ 248 | if (it1->name() == layerName) 249 | if (it1->type() == "Convolution"){ 250 | it1++; 251 | if (it1->type() == "BatchNorm"){ 252 | it1++; 253 | it1++; 254 | if (it1->type() == "Split" || it1->type() == "Eltwise") 255 | { 256 | break; 257 | } 258 | else if (isNonLinear(it1->type())){ 259 | it1++; 260 | if (it1->type() == "Split" || it1->type() == "Eltwise") 261 | { 262 | break; 263 | } 264 | else{ 265 | count++; 266 | break; 267 | } 268 | } 269 | else 270 | { 271 | count++; 272 | break; 273 | } 274 | } 275 | else if (it1->type() == "Split" || it1->type() == "Eltwise"){ 276 | 277 | break; 278 | } 279 | else{ 280 | count++; 281 | break; 282 | } 283 | 284 | 285 | } 286 | 287 | } 288 | return (count == 1) ? true : false; 289 | } 290 | 291 | void Pruner::pruningConvByratio(const precord r, vector* pchannelNeedPrune){ 292 | 293 | for (it = layer->begin(); it != layer->end(); it++){ 294 | if (r->first.first == it->name()){ 295 | vector convlayervalue; 296 | convlayervalue.push_back(make_pair(-1, 1)); 297 | int num = it->blobs(0).shape().dim(0); 298 | int channels = it->blobs(0).shape().dim(1); 299 | int height = it->blobs(0).shape().dim(2); 300 | int width = it->blobs(0).shape().dim(3); 301 | int spatial_dim = channels * width * height; 302 | int data_count = num * spatial_dim; 303 | int cutNum = (r->first.second.second)*(r->first.second.first); 304 | 305 | // oblas calculate 306 | Blob mean_, variance_, temp_; 307 | Blob spatial_sum_multiplier_, filter_data_; 308 | Blob num_temp_, num_temp_1, num_temp_2; 309 | //BlobProto 310 | 311 | vector sz; 312 | sz.push_back(spatial_dim); 313 | spatial_sum_multiplier_.Reshape(sz); 314 | 315 | double* multiplier_data = spatial_sum_multiplier_.mutable_cpu_data(); 316 | caffe_set(spatial_sum_multiplier_.count(), double(1), multiplier_data); 317 | //Blob filter_data_ = 318 | 319 | //Modifying the kernel heap by traveling through the computed-average-kernel's -size then sort 320 | BlobProto blobData = it->blobs(0); 321 | BlobProto blobData1; 322 | double maxData = 0.0; 323 | int k = blobData.data_size(); 324 | filter_data_.FromProto(blobData, true); 325 | sz[0] = num; 326 | mean_.Reshape(sz); 327 | temp_.Reshape(sz); 328 | variance_.Reshape(sz); 329 | sz[0] = num*spatial_dim; 330 | num_temp_.Reshape(sz); 331 | num_temp_1.FromProto(blobData, true); 332 | num_temp_2.Reshape(sz); 333 | switch (convCalculateMode) 334 | { 335 | case Pruner::Variance: 336 | caffe_cpu_gemv(CblasNoTrans, num, spatial_dim, 1. / spatial_dim, filter_data_.cpu_data(), 337 | spatial_sum_multiplier_.cpu_data(), 0., mean_.mutable_cpu_data()); 338 | caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, spatial_dim, num, 1, -1, 339 | spatial_sum_multiplier_.cpu_data(), mean_.cpu_data(), 1, num_temp_1.mutable_cpu_data()); 340 | caffe_powx(num_temp_1.count(), num_temp_1.cpu_data(), double(2), num_temp_2.mutable_cpu_data()); 341 | caffe_cpu_gemv(CblasNoTrans, num, spatial_dim, 1. / spatial_dim, num_temp_2.cpu_data(), 342 | spatial_sum_multiplier_.cpu_data(), 0., variance_.mutable_cpu_data()); 343 | variance_.ToProto(&blobData1, false); 344 | for (size_t i = 0; i < variance_.count(); i++){ 345 | atom a = make_pair(i, blobData1.double_data(i)); 346 | convlayervalue.push_back(a); 347 | } 348 | break; 349 | case Pruner::Norm: 350 | for (size_t i = 0; i < blobData.data_size(); i++){ 351 | if (maxData < abs(blobData.data(i))){ 352 | maxData = abs(blobData.data(i)); 353 | } 354 | } 355 | for (int i = 0; i < num; i++){ 356 | double value = 0.0; 357 | for (int j = 0; j < spatial_dim; j++){ 358 | value += abs(blobData.data(i*spatial_dim + j)) / maxData; 359 | } 360 | atom a = make_pair(i, value / spatial_dim); 361 | convlayervalue.push_back(a); 362 | } 363 | 364 | break; 365 | case Pruner::L1: 366 | caffe_abs(filter_data_.count(), filter_data_.cpu_data(), filter_data_.mutable_cpu_data()); 367 | caffe_cpu_gemv(CblasNoTrans, num, spatial_dim, 1. / spatial_dim, filter_data_.cpu_data(), 368 | spatial_sum_multiplier_.cpu_data(), 0., temp_.mutable_cpu_data()); 369 | temp_.ToProto(&blobData1, false); 370 | for (size_t i = 0; i < temp_.count(); i++){ 371 | atom a = make_pair(i, blobData1.double_data(i)); 372 | convlayervalue.push_back(a); 373 | } 374 | break; 375 | case Pruner::L2: 376 | caffe_powx(filter_data_.count(), filter_data_.cpu_data(), double(2), filter_data_.mutable_cpu_data()); 377 | caffe_cpu_gemv(CblasNoTrans, num, spatial_dim, 1. / spatial_dim, filter_data_.cpu_data(), 378 | spatial_sum_multiplier_.cpu_data(), 0., temp_.mutable_cpu_data()); 379 | temp_.ToProto(&blobData1, false); 380 | for (size_t i = 0; i < temp_.count(); i++){ 381 | atom a = make_pair(i, blobData1.double_data(i)); 382 | convlayervalue.push_back(a); 383 | } 384 | break; 385 | default: 386 | break; 387 | } 388 | utility_->hS(&convlayervalue, 1, num); 389 | for (int i = 0; i < cutNum; i++){ 390 | pchannelNeedPrune->push_back(convlayervalue.at(i + 1).first); 391 | } 392 | //start prune 393 | this->filterPruning(it, pchannelNeedPrune); 394 | 395 | } 396 | } 397 | } 398 | 399 | void Pruner::pruningBottomByratio(const precord r, vector* pchannelNeedPrune){ 400 | //preform pruning on next layer 401 | int num = r->first.second.second; 402 | int cutNum = (r->first.second.second)*(r->first.second.first); 403 | string::size_type i1 = r->second.size(); 404 | for (string::size_type k = 0; k < r->second.size(); k++){ 405 | convParam conv1 = r->second[k]; 406 | string n = conv1.first; 407 | for (it = layer->begin(); it != layer->end(); it++){ 408 | if (it->name() == conv1.first){ 409 | if (it->type() == "Convolution"){ 410 | this->channelPruning(it, pchannelNeedPrune); 411 | break; 412 | } 413 | else if (it->type() == "ConvolutionDepthwise"){ 414 | this->filterPruning(it, pchannelNeedPrune); 415 | it++; 416 | 417 | while (it->type() != "Convolution"){ 418 | it++; 419 | } 420 | 421 | //start prune pointwise conv layer which subsequent to depthwiseConv 422 | string name1 = it->name(); 423 | if (it->type() == "Convolution"){ 424 | this->channelPruning(it, pchannelNeedPrune); 425 | } 426 | 427 | break; 428 | } 429 | } 430 | 431 | } 432 | 433 | } 434 | } 435 | 436 | void Pruner::filterPruning(::google::protobuf::RepeatedPtrField< caffe::LayerParameter >::iterator iter_, vector* pchannelNeedPrune) const{ 437 | int_64 filter_count = iter_->blobs(0).shape().dim(0); 438 | int_64 channels = iter_->blobs(0).shape().dim(1); 439 | int_64 height = iter_->blobs(0).shape().dim(2); 440 | int_64 width = iter_->blobs(0).shape().dim(3); 441 | int count = channels * width * height; 442 | int cutNum = pchannelNeedPrune->size(); 443 | BlobProto *blob_ = iter_->mutable_blobs(0); 444 | BlobProto blob = iter_->blobs(0); 445 | blob_->clear_data(); 446 | vector::const_iterator beg = pchannelNeedPrune->cbegin(); 447 | vector::const_iterator end = pchannelNeedPrune->cend(); 448 | for (int j = 0; j < filter_count; j++){ 449 | if (find(beg, end, j) == pchannelNeedPrune->cend()){ 450 | for (int g = 0; g < count; g++){ 451 | blob_->add_data(blob.data(j*count + g)); 452 | } 453 | } 454 | } 455 | BlobShape shape; 456 | shape.add_dim(filter_count - cutNum); 457 | shape.add_dim(channels); 458 | shape.add_dim(height); 459 | shape.add_dim(width); 460 | blob_->mutable_shape()->CopyFrom(shape); 461 | 462 | // We will perform bias update based if the bias of conv existed. 463 | if (iter_->blobs_size() > 1){ 464 | BlobProto *blob_ = iter_->mutable_blobs(1); 465 | BlobProto blob = iter_->blobs(1); 466 | blob_->clear_data(); 467 | for (int j = 0; j < filter_count; j++){ 468 | if (find(beg, end, j) == pchannelNeedPrune->cend()){ 469 | blob_->add_data(blob.data(j)); 470 | } 471 | } 472 | BlobShape shape; 473 | shape.add_dim(filter_count - cutNum); 474 | blob_->mutable_shape()->CopyFrom(shape); 475 | } 476 | 477 | iter_->mutable_convolution_param()->set_num_output(filter_count - cutNum); 478 | 479 | if ((++iter_)->type() == "BatchNorm"){ 480 | int_64 bn_count = iter_->blobs(0).shape().dim(0); 481 | BlobProto *bnBlob0_ = iter_->mutable_blobs(0); 482 | BlobProto bnBlob0 = iter_->blobs(0); 483 | bnBlob0_->clear_data(); 484 | for (int j = 0; j < bn_count; j++){ 485 | if (find(beg, end, j) == pchannelNeedPrune->cend()){ 486 | bnBlob0_->add_data(bnBlob0.data(j)); 487 | } 488 | } 489 | BlobShape shape0; 490 | shape0.add_dim(bn_count - cutNum); 491 | bnBlob0_->mutable_shape()->CopyFrom(shape0); 492 | 493 | BlobProto *bnBlob1_ = iter_->mutable_blobs(1); 494 | BlobProto bnBlob1 = iter_->blobs(1); 495 | bnBlob1_->clear_data(); 496 | for (int j = 0; j < bn_count; j++){ 497 | if (find(beg, end, j) == pchannelNeedPrune->cend()){ 498 | bnBlob1_->add_data(bnBlob1.data(j)); 499 | } 500 | } 501 | BlobShape shape1; 502 | shape1.add_dim(bn_count - cutNum); 503 | bnBlob1_->mutable_shape()->CopyFrom(shape1); 504 | iter_++; 505 | 506 | BlobProto *sBlob0_ = iter_->mutable_blobs(0); 507 | BlobProto sBlob0 = iter_->blobs(0); 508 | sBlob0_->clear_data(); 509 | for (int j = 0; j < bn_count; j++){ 510 | if (find(beg, end, j) == pchannelNeedPrune->cend()){ 511 | sBlob0_->add_data(sBlob0.data(j)); 512 | } 513 | } 514 | BlobShape shape2; 515 | shape2.add_dim(bn_count - cutNum); 516 | sBlob0_->mutable_shape()->CopyFrom(shape2); 517 | 518 | BlobProto *sBlob1_ = iter_->mutable_blobs(1); 519 | BlobProto sBlob1 = iter_->blobs(1); 520 | sBlob1_->clear_data(); 521 | for (int j = 0; j < bn_count; j++){ 522 | if (find(beg, end, j) == pchannelNeedPrune->cend()){ 523 | sBlob1_->add_data(sBlob1.data(j)); 524 | } 525 | } 526 | BlobShape shape3; 527 | shape3.add_dim(bn_count - cutNum); 528 | sBlob1_->mutable_shape()->CopyFrom(shape3); 529 | } 530 | } 531 | 532 | void Pruner::channelPruning(::google::protobuf::RepeatedPtrField< caffe::LayerParameter >::iterator iter_, vector* pchannelNeedPrune) const{ 533 | 534 | int_64 nextLayKerNum = iter_->blobs(0).shape().dim(0); 535 | int_64 nextLayChannel = iter_->blobs(0).shape().dim(1); 536 | int_64 nextLayKerH = iter_->blobs(0).shape().dim(2); 537 | int_64 nextLayKerW = iter_->blobs(0).shape().dim(3); 538 | int cutNum = pchannelNeedPrune->size(); 539 | int counts = nextLayChannel * nextLayKerH * nextLayKerW; 540 | int dimSize = nextLayKerH * nextLayKerW; 541 | BlobProto *blob1_ = iter_->mutable_blobs(0); 542 | BlobProto blob1 = iter_->blobs(0); 543 | blob1_->clear_data(); 544 | vector::const_iterator beg = pchannelNeedPrune->cbegin(); 545 | vector::const_iterator end = pchannelNeedPrune->cend(); 546 | for (int j = 0; j < nextLayKerNum; j++){ 547 | for (int g = 0; g < nextLayChannel; g++){ 548 | if (find(beg, end, g) == pchannelNeedPrune->cend()){ 549 | for (int m = 0; m < dimSize; m++){ 550 | blob1_->add_data(blob1.data(j * counts + g * dimSize + m)); 551 | } 552 | } 553 | } 554 | } 555 | BlobShape shape1; 556 | shape1.add_dim(nextLayKerNum); 557 | shape1.add_dim(nextLayChannel - cutNum); 558 | shape1.add_dim(nextLayKerH); 559 | shape1.add_dim(nextLayKerW); 560 | blob1_->mutable_shape()->CopyFrom(shape1); 561 | 562 | } 563 | 564 | int Pruner::writePrototxt(const string prototxt1, const string prototxt2){ 565 | fstream fin_in(pruning_proto_path, ios::in | ios::binary); 566 | fstream fin_out(pruned_proto_path, ios::out | ios::binary); 567 | if (!fin_in || !fin_out) 568 | return 0; 569 | string str1 = "name"; 570 | string str2 = "num_output"; 571 | string str3 = "type"; 572 | string str; 573 | string nametemp; 574 | bool final_flag = false; 575 | bool nor_flag = false; 576 | int prunedNum; 577 | while (getline(fin_in, str)){ 578 | if (str.find("prob") != -1){ 579 | final_flag = true; 580 | } 581 | if (final_flag == true){ 582 | fin_out << str << '\n'; 583 | continue; 584 | } 585 | int index = -1; 586 | if (str.find(str1) != -1){ 587 | for (auto& r : convNeedRewriteOnPrototxt){ 588 | string s = '"' + r.first + '"'; 589 | index = str.find(s); 590 | if (index != -1){ 591 | int num = r.second.second; 592 | int cut = r.second.first*r.second.second; 593 | prunedNum = num - cut; 594 | nor_flag = true; 595 | break; 596 | } 597 | } 598 | } 599 | if (str.find(str2) != -1){ 600 | if (!nor_flag){ 601 | fin_out << str << '\n'; 602 | } 603 | else{ 604 | fin_out << " num_output: " + to_string(prunedNum) << '\n'; 605 | nor_flag = false; 606 | } 607 | } 608 | else{ 609 | fin_out << str << '\n'; 610 | } 611 | } 612 | return 1; 613 | } 614 | 615 | void Pruner::eltwiseCaculate(const peltwiserecord r, vector* channelNeedPrune){ 616 | 617 | unsigned blob_size = r->second.at(0).second.second; 618 | unsigned cutNum = (r->second.at(0).second.second) * (r->second.at(0).second.first); 619 | vector convlayervalue; 620 | convlayervalue.push_back(make_pair(-1, 1)); 621 | double *p_arr = new double[blob_size]; 622 | for (int i = 0; i < blob_size; i++){ 623 | p_arr[i] = 0; 624 | } 625 | for (string::size_type i = 0; i < r->second.size(); i++){ 626 | for (it = layer->begin(); it != layer->end(); it++){ 627 | int_64 num, channels, height, width; 628 | int count, cutNum; 629 | if (it->name() == r->second.at(i).first){ 630 | num = it->blobs(0).shape().dim(0); 631 | channels = it->blobs(0).shape().dim(1); 632 | height = it->blobs(0).shape().dim(2); 633 | width = it->blobs(0).shape().dim(3); 634 | count = channels * width * height; 635 | cutNum = r->second.at(0).second.first * r->second.at(0).second.second; 636 | BlobProto blobData = it->blobs(0); 637 | for (int_64 j = 0; j < num; j++){ 638 | double value = 0.0; 639 | for (int k = 0; k < count; k++){ 640 | value += abs(blobData.data(j*count + k)); 641 | } 642 | p_arr[i] = p_arr[i] + value / count; 643 | } 644 | } 645 | } 646 | } 647 | for (int i = 0; i < blob_size; i++){ 648 | atom a = make_pair(i, p_arr[i]); 649 | convlayervalue.push_back(a); 650 | } 651 | utility_->hS(&convlayervalue, 1, blob_size); 652 | for (int i = 0; i < cutNum; i++){ 653 | channelNeedPrune->push_back(convlayervalue.at(i + 1).first); 654 | } 655 | } 656 | 657 | void Pruner::pruningEltwiseByratio(const peltwiserecord r, vector* channelNeedPrune){ 658 | for (auto iter = r->second.begin(); iter != r->second.end(); iter++){ 659 | for (it = layer->begin(); it != layer->end(); it++){ 660 | if (it->name() == iter->first){ 661 | this->filterPruning(it, channelNeedPrune); 662 | } 663 | } 664 | } 665 | for (auto iter = r->first.begin(); iter != r->first.end(); iter++){ 666 | for (it = layer->begin(); it != layer->end(); it++){ 667 | if (it->name() == iter->first){ 668 | this->channelPruning(it, channelNeedPrune); 669 | break; 670 | } 671 | } 672 | } 673 | } 674 | 675 | pair, vector> Pruner::eltwiseTravel(const string eltwiseName){ 676 | //Check 677 | auto it = layer->begin(); 678 | for (; it != layer->end(); it++){ 679 | if (it->name() == eltwiseName && it->type() == "Eltwise"){ 680 | break; 681 | } 682 | else if (it->name() == eltwiseName && it->type() != "Eltwise") 683 | { 684 | cout << eltwiseName << " is not an eltwise layer" << endl; 685 | system("pause"); 686 | } 687 | } 688 | pair, vector> p; 689 | vector eltwiseLayers = { eltwiseName }; 690 | vector splitLayers; 691 | vector filters; 692 | vector channels; 693 | this->findDown(eltwiseName, &eltwiseLayers, &splitLayers); 694 | string conv_temp = this->findUp(eltwiseName, &eltwiseLayers, &splitLayers); 695 | filters = this->findUpFilters(&eltwiseLayers, &splitLayers); 696 | channels = this->findUpChannels(&eltwiseLayers, &splitLayers); 697 | if ("stop" != conv_temp){ 698 | channels.push_back(conv_temp); 699 | } 700 | p.first = channels; 701 | p.second = filters; 702 | return p; 703 | } 704 | 705 | vector Pruner::findUpChannels(const vector* eltwiseLayers, const vector* splitLayers){ 706 | vector Channels; 707 | for (string::size_type k = 0; k < splitLayers->size(); k++){ 708 | auto it = layer->begin(); 709 | while (it->name() != splitLayers->at(k))it++; 710 | if (it->top_size() != 0){ 711 | for (int i = 0; i < it->top_size(); i++){ 712 | string temp = hasBottom(it->top(i)); 713 | if (temp != ""){ 714 | Channels.push_back(temp); 715 | } 716 | } 717 | } 718 | } 719 | return Channels; 720 | } 721 | 722 | vector Pruner::findUpFilters(const vector* eltwiseLayers, const vector* splitLayers){ 723 | vector Filters; 724 | for (string::size_type k = 0; k < eltwiseLayers->size(); k++){ 725 | auto it = layer->begin(); 726 | while (it->name() != eltwiseLayers->at(k))it++; 727 | if (it->bottom_size() != 0){ 728 | for (int i = 0; i < it->bottom_size(); i++){ 729 | if (it->bottom(i).find("split") != string::npos){ 730 | continue; 731 | } 732 | Filters.push_back(it->bottom(i)); 733 | } 734 | } 735 | } 736 | for (string::size_type k = 0; k < splitLayers->size(); k++){ 737 | auto it = layer->begin(); 738 | while (it->name() != splitLayers->at(k))it++; 739 | it--; 740 | if (it->type() == "Eltwise" || it->type() == "Pooling"){ 741 | continue; 742 | } 743 | else if (isNonLinear(it->type())){ 744 | it--; 745 | if (it->type() == "Eltwise"){ 746 | continue; 747 | } 748 | } 749 | while (it->type() != "Convolution") 750 | { 751 | it--; 752 | } 753 | Filters.push_back(it->name()); 754 | } 755 | return Filters; 756 | } 757 | 758 | string Pruner::findDown(const string layerName, vector* eltwiseLayers, vector* splitLayers){ 759 | auto it = layer->begin(); 760 | for (; it != layer->end(); it++){ 761 | if (it->name() == layerName){ 762 | string n = it->name(); 763 | break; 764 | } 765 | } 766 | if (it->bottom_size() != 0){ 767 | for (int i = 0; i < it->bottom_size(); i++){ 768 | if (it->bottom(i).find("split") != string::npos){ 769 | string splitLayerTopName = it->bottom(i); 770 | string splitLayerName = splitLayerTopName.substr(0, splitLayerTopName.find_last_of("_")); 771 | splitLayers->push_back(splitLayerName); 772 | it--; 773 | while (it->name() != splitLayerName){ 774 | it--; 775 | } 776 | it--; 777 | if (isNonLinear(it->type())){ 778 | it--; 779 | if (it->type() == "Eltwise"){ 780 | eltwiseLayers->push_back(it->name()); 781 | return this->findDown(it->name(), eltwiseLayers, splitLayers); 782 | } 783 | } 784 | else if (it->type() == "Eltwise"){ 785 | eltwiseLayers->push_back(it->name()); 786 | return this->findDown(it->name(), eltwiseLayers, splitLayers); 787 | } 788 | else 789 | { 790 | break; 791 | } 792 | } 793 | } 794 | } 795 | return "dytto"; 796 | } 797 | 798 | string Pruner::findUp(const string layerName, vector* eltwiseLayers, vector* splitLayers){ 799 | auto it = layer->begin(); 800 | for (; it != layer->end(); it++){ 801 | if (it->name() == layerName){ 802 | break; 803 | } 804 | } 805 | while (it->type() != "Split" && it->type() != "Convolution" && it->type() != "Pooling"){ 806 | it++; 807 | } 808 | if (it->type() == "Pooling"){ 809 | it++; 810 | if (it->type() == "Convolution"){ 811 | return it->name(); 812 | } 813 | else if (it->type() == "Split") 814 | { 815 | splitLayers->push_back(it->name()); 816 | return "stop"; 817 | } 818 | } 819 | else if (it->type() == "Split"){ 820 | splitLayers->push_back(it->name()); 821 | return "stop"; 822 | } 823 | else if (it->type() == "Convolution"){ 824 | return it->name(); 825 | } 826 | } 827 | 828 | string Pruner::hasBottom(const string layerName){ 829 | auto it = layer->begin(); 830 | for (; it != layer->end(); it++){ 831 | if (it->type() == "Convolution"){ 832 | if (it->bottom_size() > 0){ 833 | for (int k = 0; k < it->bottom_size(); k++){ 834 | if (it->bottom(k) == layerName){ 835 | return it->name(); 836 | } 837 | } 838 | } 839 | } 840 | } 841 | return ""; 842 | } 843 | 844 | string Pruner::hasTop(const string layerName){ 845 | auto it = layer->begin(); 846 | for (; it != layer->end(); it++){ 847 | if (it->type() == "Convolution"){ 848 | if (it->top_size() > 0){ 849 | for (int k = 0; k < it->top_size(); k++){ 850 | if (it->top(k) == layerName){ 851 | return it->name(); 852 | } 853 | } 854 | } 855 | } 856 | } 857 | return ""; 858 | } 859 | --------------------------------------------------------------------------------