├── imgs ├── 0.bmp ├── 1.bmp ├── 2.bmp ├── 3.bmp ├── 4.bmp ├── 5.bmp ├── 6.bmp ├── 7.bmp ├── 8.bmp └── 9.bmp ├── src ├── LeNet.h ├── Caffe2Net.h ├── LeNet.cpp ├── main.cpp └── Caffe2Net.cpp ├── deploy_models ├── mnist_init_net.pbtxt └── mnist_predict_net.pbtxt ├── README.md ├── Makefile ├── test_plan.pbtxt └── train_plan.pbtxt /imgs/0.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/0.bmp -------------------------------------------------------------------------------- /imgs/1.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/1.bmp -------------------------------------------------------------------------------- /imgs/2.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/2.bmp -------------------------------------------------------------------------------- /imgs/3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/3.bmp -------------------------------------------------------------------------------- /imgs/4.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/4.bmp -------------------------------------------------------------------------------- /imgs/5.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/5.bmp -------------------------------------------------------------------------------- /imgs/6.bmp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/6.bmp -------------------------------------------------------------------------------- /imgs/7.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/7.bmp -------------------------------------------------------------------------------- /imgs/8.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/8.bmp -------------------------------------------------------------------------------- /imgs/9.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadbread1984/Caffe2-C-demo/HEAD/imgs/9.bmp -------------------------------------------------------------------------------- /src/LeNet.h: -------------------------------------------------------------------------------- 1 | #ifndef LENET_H 2 | #define LENET_H 3 | 4 | #include "Caffe2Net.h" 5 | 6 | using namespace std; 7 | using namespace cv; 8 | using namespace caffe2; 9 | //LeNet MNIST classifier: model loading and inference come from Caffe2Net; this class only converts a single-channel 28x28 image to the input tensor and the 10-class output tensor to a vector (override makes the compiler check both signatures against Caffe2Net). 10 | class LeNet : public Caffe2Net { 11 | public: 12 | LeNet(string initNet,string predictNet); 13 | virtual ~LeNet(); 14 | protected: 15 | TensorCPU preProcess(Mat img) override; 16 | vector postProcess(TensorCPU output) override; 17 | }; 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /deploy_models/mnist_init_net.pbtxt: -------------------------------------------------------------------------------- 1 | name: "mnist_init_net" 2 | #Note: the parameter fill operators of the training init net (XavierFill/ConstantFill in train_plan.pbtxt) are all replaced here by a single Load operator that reads the weights train_plan.pbtxt saves to LeNet_params 3 | op { 4 | type: "ConstantFill" 5 | output: "data" 6 | arg { 7 | name: "shape" 8 | ints: 1 9 | } 10 | } 11 | op { 12 | type: "Load" 13 | output: "conv1_w" 14 | output: "conv1_b" 15 | output: "conv2_w" 16 | output: "conv2_b" 17 | output: "fc3_w" 18 | output: "fc3_b" 19 | output: "pred_w" 20 | output:
"pred_b" 21 | arg { 22 | name: "db" 23 | s: "LeNet_params" 24 | } 25 | arg { 26 | name: "db_type" 27 | s: "lmdb" 28 | } 29 | } 30 | device_option { 31 | device_type: 1 32 | } 33 | -------------------------------------------------------------------------------- /src/Caffe2Net.h: -------------------------------------------------------------------------------- 1 | #ifndef CAFFE2NET_H 2 | #define CAFFE2NET_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | using namespace std; 14 | using namespace cv; 15 | using namespace caffe2; 16 | //Abstract Caffe2 predictor: loads an init/predict NetDef pair into workspace; subclasses supply preProcess (Mat to input tensor) and postProcess (output tensor to result vector). 17 | class Caffe2Net { 18 | public: 19 | Caffe2Net(string initNet,string predictNet); 20 | virtual ~Caffe2Net() = 0; 21 | vector predict(Mat img); 22 | protected: 23 | virtual TensorCPU preProcess(Mat img) = 0; 24 | virtual vector postProcess(TensorCPU output) = 0; 25 | //workspace holds the blobs created by the init net; predict_net is created in and runs against it 26 | Workspace workspace; 27 | unique_ptr predict_net; 28 | }; 29 | 30 | #endif 31 | -------------------------------------------------------------------------------- /src/LeNet.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "LeNet.h" 3 | 4 | using namespace std; 5 | 6 | LeNet::LeNet(string initNet,string predictNet) 7 | :Caffe2Net(initNet,predictNet) 8 | { 9 | } 10 | 11 | LeNet::~LeNet() 12 | { 13 | } 14 | 15 | //Converts a single-channel 28x28 image into a 1x1x28x28 (NCHW) float tensor scaled by 1/256, the same factor (0.00390625) the Scale op in the train/test plans applies. 16 | TensorCPU LeNet::preProcess(Mat img) 17 | { 18 | assert(img.channels() == 1); 19 | assert(img.rows == 28); 20 | assert(img.cols == 28); 21 | vector dims({1, img.channels(), img.rows, img.cols}); 22 | 23 | img.convertTo(img, CV_32FC1, 1.0/256,0); 24 | //convertTo keeps the existing buffer when img already has the target type and size, so an ROI can still be non-continuous here 25 | if(!img.isContinuous()) img = img.clone(); 26 | //copy exactly this image's pixels and size the buffer from the image itself: datastart/dataend may span a whole parent matrix for an ROI, 27 | //and with NDEBUG the asserts above are compiled out, so a fixed 28*28 buffer could be overrun 28 | vector data((float *)img.data, (float *)img.data + img.total() * img.channels()); 29 | 30 | return TensorCPU(dims, data, NULL); 31 | } 32 | 33 | //Copies the 1x10 output tensor (the "softmax" blob passed in by Caffe2Net::predict) into a 10-element vector. 34 | vector LeNet::postProcess(TensorCPU output) 35 | { 36 | const float * probs = output.data(); 37 | vector dims = output.dims(); 38 | //check that the output dims are as expected (1 x 10) 39 | assert(2 == output.ndim()); 40 | assert(1 == dims[0]); 41 | assert(10 == dims[1]); 42 | vector retVal(dims[1]); 43 | copy(probs,probs+dims[1],retVal.begin()); 44 | return retVal; 45 | } 46 |
-------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "LeNet.h" 7 | 8 | using namespace std; 9 | using namespace boost::program_options; 10 | using namespace cv; 11 | 12 | int main(int argc,char ** argv) 13 | { 14 | string img_path; 15 | options_description desc; 16 | desc.add_options() 17 | ("help,h","打印当前使用方法") 18 | ("input,i",value(&img_path),"输入图片路径"); 19 | variables_map vm; 20 | store(parse_command_line(argc,argv,desc),vm); 21 | notify(vm); 22 | 23 | if(1 == argc || 1 != vm.count("input") || 1 == vm.count("help")) { 24 | cout< result = lenet.predict(img); 36 | vector::iterator max_iter = max_element(result.begin(),result.end()); 37 | cout< tmp/train-images-idx3-ubyte 15 | curl --progress-bar http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz | gunzip > tmp/train-labels-idx1-ubyte 16 | make_mnist_db --image_file=tmp/train-images-idx3-ubyte --label_file=tmp/train-labels-idx1-ubyte --output_file=mnist-train-nchw-leveldb --channel_first --db leveldb 17 | curl --progress-bar http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz | gunzip > tmp/t10k-images-idx3-ubyte 18 | curl --progress-bar http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz | gunzip > tmp/t10k-labels-idx1-ubyte 19 | make_mnist_db --image_file=tmp/t10k-images-idx3-ubyte --label_file=tmp/t10k-labels-idx1-ubyte --output_file=mnist-test-nchw-leveldb --channel_first --db leveldb 20 | $(RM) -r tmp 21 | 22 | train: train_plan.pbtxt 23 | run_plan --plan $^ 24 | 25 | test: test_plan.pbtxt 26 | run_plan --plan $^ 27 | 28 | clean: 29 | $(RM) *.log *.summary 30 | $(RM) -r LeNet_params 31 | $(RM) predictor $(OBJS) 32 |
-------------------------------------------------------------------------------- /src/Caffe2Net.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "Caffe2Net.h" 3 | 4 | Caffe2Net::Caffe2Net(string initNet,string predictNet) 5 | :workspace(nullptr) 6 | { 7 | #ifdef WITH_CUDA 8 | DeviceOption option; 9 | option.set_device_type(CUDA); 10 | new CUDAContext(option); 11 | #endif 12 | //载入部署模型 13 | NetDef init_net_def, predict_net_def; 14 | CAFFE_ENFORCE(ReadProtoFromFile(initNet, &init_net_def)); 15 | CAFFE_ENFORCE(ReadProtoFromFile(predictNet, &predict_net_def)); 16 | #ifdef WITH_CUDA 17 | init_net_def.mutable_device_option()->set_device_type(CUDA); 18 | predict_net_def.mutable_device_option()->set_device_type(CUDA); 19 | #else 20 | init_net_def.mutable_device_option()->set_device_type(CPU); 21 | predict_net_def.mutable_device_option()->set_device_type(CPU); 22 | #endif 23 | //网络初始化 24 | workspace.RunNetOnce(init_net_def); 25 | //创建判别器 26 | predict_net = CreateNet(predict_net_def,&workspace); 27 | } 28 | 29 | Caffe2Net::~Caffe2Net() 30 | { 31 | } 32 | 33 | vector Caffe2Net::predict(Mat img) 34 | { 35 | //create input blob 36 | #ifdef WITH_CUDA 37 | TensorCUDA input = TensorCUDA(preProcess(img)); 38 | auto tensor = workspace.CreateBlob("data")->GetMutable(); 39 | #else 40 | TensorCPU input = preProcess(img); 41 | auto tensor = workspace.CreateBlob("data")->GetMutable(); 42 | #endif 43 | tensor->ResizeLike(input); 44 | tensor->ShareData(input); 45 | //predict 46 | predict_net->Run(); 47 | //get output blob 48 | #ifdef WITH_CUDA 49 | TensorCPU output = TensorCPU(workspace.GetBlob("softmax")->Get()); 50 | #else 51 | TensorCPU output = TensorCPU(workspace.GetBlob("softmax")->Get()); 52 | #endif 53 | return postProcess(output); 54 | } 55 | 56 | TensorCPU Caffe2Net::preProcess(Mat img) 57 | { 58 | } 59 | 60 | vector Caffe2Net::postProcess(TensorCPU output) 61 | { 62 | } 63 | 
-------------------------------------------------------------------------------- /deploy_models/mnist_predict_net.pbtxt: -------------------------------------------------------------------------------- 1 | name: "mnist_predict_net" 2 | op { 3 | input: "data" 4 | input: "conv1_w" 5 | input: "conv1_b" 6 | output: "conv1" 7 | type: "Conv" 8 | arg { 9 | name: "stride" 10 | i: 1 11 | } 12 | arg { 13 | name: "pad" 14 | i: 0 15 | } 16 | arg { 17 | name: "kernel" 18 | i: 5 19 | } 20 | } 21 | op { 22 | input: "conv1" 23 | output: "pool1" 24 | type: "MaxPool" 25 | arg { 26 | name: "stride" 27 | i: 2 28 | } 29 | arg { 30 | name: "pad" 31 | i: 0 32 | } 33 | arg { 34 | name: "kernel" 35 | i: 2 36 | } 37 | arg { 38 | name: "order" 39 | s: "NCHW" 40 | } 41 | arg { 42 | name: "legacy_pad" 43 | i: 3 44 | } 45 | } 46 | op { 47 | input: "pool1" 48 | input: "conv2_w" 49 | input: "conv2_b" 50 | output: "conv2" 51 | type: "Conv" 52 | arg { 53 | name: "stride" 54 | i: 1 55 | } 56 | arg { 57 | name: "pad" 58 | i: 0 59 | } 60 | arg { 61 | name: "kernel" 62 | i: 5 63 | } 64 | } 65 | op { 66 | input: "conv2" 67 | output: "pool2" 68 | type: "MaxPool" 69 | arg { 70 | name: "stride" 71 | i: 2 72 | } 73 | arg { 74 | name: "pad" 75 | i: 0 76 | } 77 | arg { 78 | name: "kernel" 79 | i: 2 80 | } 81 | arg { 82 | name: "order" 83 | s: "NCHW" 84 | } 85 | arg { 86 | name: "legacy_pad" 87 | i: 3 88 | } 89 | } 90 | op { 91 | input: "pool2" 92 | input: "fc3_w" 93 | input: "fc3_b" 94 | output: "fc3" 95 | type: "FC" 96 | } 97 | op { 98 | input: "fc3" 99 | output: "fc3" 100 | type: "Relu" 101 | } 102 | op { 103 | input: "fc3" 104 | input: "pred_w" 105 | input: "pred_b" 106 | output: "pred" 107 | type: "FC" 108 | } 109 | op { 110 | input: "pred" 111 | output: "softmax" 112 | type: "Softmax" 113 | } 114 | device_option { 115 | device_type: 1 116 | } 117 | external_input: "data" 118 | external_input: "conv1_w" 119 | external_input: "conv1_b" 120 | external_input: "conv2_w" 121 | external_input: "conv2_b" 122 | 
external_input: "fc3_w" 123 | external_input: "fc3_b" 124 | external_input: "pred_w" 125 | external_input: "pred_b" 126 | external_output: "softmax" 127 | -------------------------------------------------------------------------------- /test_plan.pbtxt: -------------------------------------------------------------------------------- 1 | name: "mnist_train_plan" 2 | network{ 3 | name: "mnist_init_net" 4 | op { 5 | output: "dbreader" 6 | type: "CreateDB" 7 | arg { 8 | name: "db_type" 9 | s: "leveldb" 10 | } 11 | arg { 12 | name: "db" 13 | s: "mnist-test-nchw-leveldb" 14 | } 15 | } 16 | op { 17 | type: "Load" 18 | output: "conv1_w" 19 | output: "conv1_b" 20 | output: "conv2_w" 21 | output: "conv2_b" 22 | output: "fc3_w" 23 | output: "fc3_b" 24 | output: "pred_w" 25 | output: "pred_b" 26 | arg { 27 | name: "db" 28 | s: "LeNet_params" 29 | } 30 | arg { 31 | name: "db_type" 32 | s: "lmdb" 33 | } 34 | } 35 | device_option { 36 | device_type: 1 37 | } 38 | } 39 | network{ 40 | name: "mnist_test_net" 41 | op { 42 | input: "dbreader" 43 | output: "data_uint8" 44 | output: "label" 45 | type: "TensorProtosDBInput" 46 | arg { 47 | name: "batch_size" 48 | i: 100 49 | } 50 | } 51 | op { 52 | input: "data_uint8" 53 | output: "data" 54 | type: "Cast" 55 | arg { 56 | name: "to" 57 | i: 1 58 | } 59 | } 60 | op { 61 | input: "data" 62 | output: "data" 63 | type: "Scale" 64 | arg { 65 | name: "scale" 66 | f: 0.00390625 67 | } 68 | } 69 | op { 70 | input: "data" 71 | output: "data" 72 | type: "StopGradient" 73 | } 74 | op { 75 | input: "data" 76 | input: "conv1_w" 77 | input: "conv1_b" 78 | output: "conv1" 79 | type: "Conv" 80 | arg { 81 | name: "stride" 82 | i: 1 83 | } 84 | arg { 85 | name: "pad" 86 | i: 0 87 | } 88 | arg { 89 | name: "kernel" 90 | i: 5 91 | } 92 | } 93 | op { 94 | input: "conv1" 95 | output: "pool1" 96 | type: "MaxPool" 97 | arg { 98 | name: "stride" 99 | i: 2 100 | } 101 | arg { 102 | name: "pad" 103 | i: 0 104 | } 105 | arg { 106 | name: "kernel" 107 | i: 2 108 | } 
109 | arg { 110 | name: "order" 111 | s: "NCHW" 112 | } 113 | arg { 114 | name: "legacy_pad" 115 | i: 3 116 | } 117 | } 118 | op { 119 | input: "pool1" 120 | input: "conv2_w" 121 | input: "conv2_b" 122 | output: "conv2" 123 | type: "Conv" 124 | arg { 125 | name: "stride" 126 | i: 1 127 | } 128 | arg { 129 | name: "pad" 130 | i: 0 131 | } 132 | arg { 133 | name: "kernel" 134 | i: 5 135 | } 136 | } 137 | op { 138 | input: "conv2" 139 | output: "pool2" 140 | type: "MaxPool" 141 | arg { 142 | name: "stride" 143 | i: 2 144 | } 145 | arg { 146 | name: "pad" 147 | i: 0 148 | } 149 | arg { 150 | name: "kernel" 151 | i: 2 152 | } 153 | arg { 154 | name: "order" 155 | s: "NCHW" 156 | } 157 | arg { 158 | name: "legacy_pad" 159 | i: 3 160 | } 161 | } 162 | op { 163 | input: "pool2" 164 | input: "fc3_w" 165 | input: "fc3_b" 166 | output: "fc3" 167 | type: "FC" 168 | } 169 | op { 170 | input: "fc3" 171 | output: "fc3" 172 | type: "Relu" 173 | } 174 | op { 175 | input: "fc3" 176 | input: "pred_w" 177 | input: "pred_b" 178 | output: "pred" 179 | type: "FC" 180 | } 181 | op { 182 | input: "pred" 183 | output: "softmax" 184 | type: "Softmax" 185 | } 186 | op { 187 | input: "softmax" 188 | input: "label" 189 | output: "accuracy" 190 | type: "Accuracy" 191 | } 192 | op { 193 | input: "accuracy" 194 | type: "Print" 195 | arg { 196 | name: "to_file" 197 | i: 1 198 | } 199 | } 200 | device_option { 201 | device_type: 1 202 | } 203 | external_input: "dbreader" 204 | external_input: "conv1_w" 205 | external_input: "conv1_b" 206 | external_input: "conv2_w" 207 | external_input: "conv2_b" 208 | external_input: "fc3_w" 209 | external_input: "fc3_b" 210 | external_input: "pred_w" 211 | external_input: "pred_b" 212 | } 213 | execution_step { 214 | substep { 215 | network: "mnist_init_net" 216 | num_iter: 1 217 | } 218 | substep { 219 | network: "mnist_test_net" 220 | num_iter: 100 221 | } 222 | } 223 | -------------------------------------------------------------------------------- 
/train_plan.pbtxt: -------------------------------------------------------------------------------- 1 | name: "mnist_train_plan" 2 | network{ 3 | name: "mnist_init_net" 4 | op { 5 | output: "dbreader" 6 | type: "CreateDB" 7 | arg { 8 | name: "db_type" 9 | s: "leveldb" 10 | } 11 | arg { 12 | name: "db" 13 | s: "mnist-train-nchw-leveldb" 14 | } 15 | } 16 | op { 17 | output: "conv1_w" 18 | type: "XavierFill" 19 | arg { 20 | name: "shape" 21 | ints: 20 22 | ints: 1 23 | ints: 5 24 | ints: 5 25 | } 26 | } 27 | op { 28 | output: "conv1_b" 29 | type: "ConstantFill" 30 | arg { 31 | name: "shape" 32 | ints: 20 33 | } 34 | } 35 | op { 36 | output: "conv2_w" 37 | type: "XavierFill" 38 | arg { 39 | name: "shape" 40 | ints: 50 41 | ints: 20 42 | ints: 5 43 | ints: 5 44 | } 45 | } 46 | op { 47 | output: "conv2_b" 48 | type: "ConstantFill" 49 | arg { 50 | name: "shape" 51 | ints: 50 52 | } 53 | } 54 | op { 55 | output: "fc3_w" 56 | type: "XavierFill" 57 | arg { 58 | name: "shape" 59 | ints: 500 60 | ints: 800 61 | } 62 | } 63 | op { 64 | output: "fc3_b" 65 | type: "ConstantFill" 66 | arg { 67 | name: "shape" 68 | ints: 500 69 | } 70 | } 71 | op { 72 | output: "pred_w" 73 | type: "XavierFill" 74 | arg { 75 | name: "shape" 76 | ints: 10 77 | ints: 500 78 | } 79 | } 80 | op { 81 | output: "pred_b" 82 | type: "ConstantFill" 83 | arg { 84 | name: "shape" 85 | ints: 10 86 | } 87 | } 88 | op { 89 | output: "iter" 90 | type: "ConstantFill" 91 | arg { 92 | name: "shape" 93 | ints: 1 94 | } 95 | arg { 96 | name: "value" 97 | i: 0 98 | } 99 | arg { 100 | name: "dtype" 101 | i: 10 102 | } 103 | device_option { 104 | device_type: 0 105 | } 106 | } 107 | op { 108 | output: "ONE" 109 | type: "ConstantFill" 110 | arg { 111 | name: "shape" 112 | ints: 1 113 | } 114 | arg { 115 | name: "value" 116 | f: 1 117 | } 118 | } 119 | device_option { 120 | device_type: 1 121 | } 122 | } 123 | network{ 124 | name: "mnist_train_net" 125 | 126 | # size channel dims 127 | #input 28 1 28x28x1 128 | #conv1 
(28-5+0)/1+1=24 20 24x24x20 129 | #pool1 (24-2+0)/2+1=12 20 12x12x20 130 | #conv2 (12-5+0)/1+1=8 50 8x8x50 131 | #pool2 (8-2+0)/2+1=4 50 4x4x50 132 | #fc3(relu) 1 500 1x1x500 133 | #pred 1 10 1x1x10 134 | 135 | op { 136 | input: "dbreader" 137 | output: "data_uint8" 138 | output: "label" 139 | type: "TensorProtosDBInput" 140 | arg { 141 | name: "batch_size" 142 | i: 100 143 | } 144 | } 145 | op { 146 | input: "data_uint8" 147 | output: "data" 148 | type: "Cast" 149 | arg { 150 | name: "to" 151 | i: 1 152 | } 153 | } 154 | op { 155 | input: "data" 156 | output: "data" 157 | type: "Scale" 158 | arg { 159 | name: "scale" 160 | f: 0.00390625 161 | } 162 | } 163 | op { 164 | input: "data" 165 | output: "data" 166 | type: "StopGradient" 167 | } 168 | op { 169 | input: "data" 170 | input: "conv1_w" 171 | input: "conv1_b" 172 | output: "conv1" 173 | type: "Conv" 174 | arg { 175 | name: "stride" 176 | i: 1 177 | } 178 | arg { 179 | name: "pad" 180 | i: 0 181 | } 182 | arg { 183 | name: "kernel" 184 | i: 5 185 | } 186 | } 187 | op { 188 | input: "conv1" 189 | output: "pool1" 190 | type: "MaxPool" 191 | arg { 192 | name: "stride" 193 | i: 2 194 | } 195 | arg { 196 | name: "pad" 197 | i: 0 198 | } 199 | arg { 200 | name: "kernel" 201 | i: 2 202 | } 203 | arg { 204 | name: "order" 205 | s: "NCHW" 206 | } 207 | arg { 208 | name: "legacy_pad" 209 | i: 3 210 | } 211 | } 212 | op { 213 | input: "pool1" 214 | input: "conv2_w" 215 | input: "conv2_b" 216 | output: "conv2" 217 | type: "Conv" 218 | arg { 219 | name: "stride" 220 | i: 1 221 | } 222 | arg { 223 | name: "pad" 224 | i: 0 225 | } 226 | arg { 227 | name: "kernel" 228 | i: 5 229 | } 230 | } 231 | op { 232 | input: "conv2" 233 | output: "pool2" 234 | type: "MaxPool" 235 | arg { 236 | name: "stride" 237 | i: 2 238 | } 239 | arg { 240 | name: "pad" 241 | i: 0 242 | } 243 | arg { 244 | name: "kernel" 245 | i: 2 246 | } 247 | arg { 248 | name: "order" 249 | s: "NCHW" 250 | } 251 | arg { 252 | name: "legacy_pad" 253 | i: 3 254 | }
255 | } 256 | op { 257 | input: "pool2" 258 | input: "fc3_w" 259 | input: "fc3_b" 260 | output: "fc3" 261 | type: "FC" 262 | } 263 | op { 264 | input: "fc3" 265 | output: "fc3" 266 | type: "Relu" 267 | } 268 | op { 269 | input: "fc3" 270 | input: "pred_w" 271 | input: "pred_b" 272 | output: "pred" 273 | type: "FC" 274 | } 275 | op { 276 | input: "pred" 277 | output: "softmax" 278 | type: "Softmax" 279 | } 280 | op { 281 | input: "softmax" 282 | input: "label" 283 | output: "xent" 284 | type: "LabelCrossEntropy" 285 | } 286 | op { 287 | input: "xent" 288 | output: "loss" 289 | type: "AveragedLoss" 290 | } 291 | op { 292 | input: "softmax" 293 | input: "label" 294 | output: "accuracy" 295 | type: "Accuracy" 296 | } 297 | op { 298 | input: "iter" 299 | output: "iter" 300 | type: "Iter" 301 | } 302 | op { 303 | input: "loss" 304 | output: "loss_grad" 305 | type: "ConstantFill" 306 | arg { 307 | name: "value" 308 | f: 1 309 | } 310 | } 311 | op { 312 | input: "xent" 313 | input: "loss_grad" 314 | output: "xent_grad" 315 | name: "" 316 | type: "AveragedLossGradient" 317 | is_gradient_op: true 318 | } 319 | op { 320 | input: "softmax" 321 | input: "label" 322 | input: "xent_grad" 323 | output: "softmax_grad" 324 | name: "" 325 | type: "LabelCrossEntropyGradient" 326 | is_gradient_op: true 327 | } 328 | op { 329 | input: "softmax" 330 | input: "softmax_grad" 331 | output: "pred_grad" 332 | name: "" 333 | type: "SoftmaxGradient" 334 | is_gradient_op: true 335 | } 336 | op { 337 | input: "fc3" 338 | input: "pred_w" 339 | input: "pred_grad" 340 | output: "pred_w_grad" 341 | output: "pred_b_grad" 342 | output: "fc3_grad" 343 | name: "" 344 | type: "FCGradient" 345 | is_gradient_op: true 346 | } 347 | op { 348 | input: "fc3" 349 | input: "fc3_grad" 350 | output: "fc3_grad" 351 | name: "" 352 | type: "ReluGradient" 353 | is_gradient_op: true 354 | } 355 | op { 356 | input: "pool2" 357 | input: "fc3_w" 358 | input: "fc3_grad" 359 | output: "fc3_w_grad" 360 | output: "fc3_b_grad" 
361 | output: "pool2_grad" 362 | name: "" 363 | type: "FCGradient" 364 | is_gradient_op: true 365 | } 366 | op { 367 | input: "conv2" 368 | input: "pool2" 369 | input: "pool2_grad" 370 | output: "conv2_grad" 371 | name: "" 372 | type: "MaxPoolGradient" 373 | arg { 374 | name: "stride" 375 | i: 2 376 | } 377 | arg { 378 | name: "pad" 379 | i: 0 380 | } 381 | arg { 382 | name: "kernel" 383 | i: 2 384 | } 385 | arg { 386 | name: "order" 387 | s: "NCHW" 388 | } 389 | arg { 390 | name: "legacy_pad" 391 | i: 3 392 | } 393 | is_gradient_op: true 394 | } 395 | op { 396 | input: "pool1" 397 | input: "conv2_w" 398 | input: "conv2_grad" 399 | output: "conv2_w_grad" 400 | output: "conv2_b_grad" 401 | output: "pool1_grad" 402 | name: "" 403 | type: "ConvGradient" 404 | arg { 405 | name: "stride" 406 | i: 1 407 | } 408 | arg { 409 | name: "pad" 410 | i: 0 411 | } 412 | arg { 413 | name: "kernel" 414 | i: 5 415 | } 416 | is_gradient_op: true 417 | } 418 | op { 419 | input: "conv1" 420 | input: "pool1" 421 | input: "pool1_grad" 422 | output: "conv1_grad" 423 | name: "" 424 | type: "MaxPoolGradient" 425 | arg { 426 | name: "stride" 427 | i: 2 428 | } 429 | arg { 430 | name: "pad" 431 | i: 0 432 | } 433 | arg { 434 | name: "kernel" 435 | i: 2 436 | } 437 | arg { 438 | name: "order" 439 | s: "NCHW" 440 | } 441 | arg { 442 | name: "legacy_pad" 443 | i: 3 444 | } 445 | is_gradient_op: true 446 | } 447 | op { 448 | input: "data" 449 | input: "conv1_w" 450 | input: "conv1_grad" 451 | output: "conv1_w_grad" 452 | output: "conv1_b_grad" 453 | output: "data_grad" 454 | name: "" 455 | type: "ConvGradient" 456 | arg { 457 | name: "stride" 458 | i: 1 459 | } 460 | arg { 461 | name: "pad" 462 | i: 0 463 | } 464 | arg { 465 | name: "kernel" 466 | i: 5 467 | } 468 | is_gradient_op: true 469 | } 470 | op { 471 | input: "iter" 472 | output: "LR" 473 | type: "LearningRate" 474 | arg { 475 | name: "policy" 476 | s: "step" 477 | } 478 | arg { 479 | name: "stepsize" 480 | i: 1 481 | } 482 | arg { 483 | 
name: "base_lr" 484 | f: -0.1 485 | } 486 | arg { 487 | name: "gamma" 488 | f: 0.999 489 | } 490 | } 491 | op { 492 | input: "conv1_w" 493 | input: "ONE" 494 | input: "conv1_w_grad" 495 | input: "LR" 496 | output: "conv1_w" 497 | type: "WeightedSum" 498 | } 499 | op { 500 | input: "conv1_b" 501 | input: "ONE" 502 | input: "conv1_b_grad" 503 | input: "LR" 504 | output: "conv1_b" 505 | type: "WeightedSum" 506 | } 507 | op { 508 | input: "conv2_w" 509 | input: "ONE" 510 | input: "conv2_w_grad" 511 | input: "LR" 512 | output: "conv2_w" 513 | type: "WeightedSum" 514 | } 515 | op { 516 | input: "conv2_b" 517 | input: "ONE" 518 | input: "conv2_b_grad" 519 | input: "LR" 520 | output: "conv2_b" 521 | type: "WeightedSum" 522 | } 523 | op { 524 | input: "fc3_w" 525 | input: "ONE" 526 | input: "fc3_w_grad" 527 | input: "LR" 528 | output: "fc3_w" 529 | type: "WeightedSum" 530 | } 531 | op { 532 | input: "fc3_b" 533 | input: "ONE" 534 | input: "fc3_b_grad" 535 | input: "LR" 536 | output: "fc3_b" 537 | type: "WeightedSum" 538 | } 539 | op { 540 | input: "pred_w" 541 | input: "ONE" 542 | input: "pred_w_grad" 543 | input: "LR" 544 | output: "pred_w" 545 | type: "WeightedSum" 546 | } 547 | op { 548 | input: "pred_b" 549 | input: "ONE" 550 | input: "pred_b_grad" 551 | input: "LR" 552 | output: "pred_b" 553 | type: "WeightedSum" 554 | } 555 | op { 556 | input: "accuracy" 557 | type: "Print" 558 | arg { 559 | name: "to_file" 560 | i: 1 561 | } 562 | } 563 | op { 564 | input: "loss" 565 | type: "Print" 566 | arg { 567 | name: "to_file" 568 | i: 1 569 | } 570 | } 571 | op { 572 | input: "conv1_w" 573 | type: "Summarize" 574 | arg { 575 | name: "to_file" 576 | i: 1 577 | } 578 | } 579 | op { 580 | input: "conv1_w_grad" 581 | type: "Summarize" 582 | arg { 583 | name: "to_file" 584 | i: 1 585 | } 586 | } 587 | op { 588 | input: "conv1_b" 589 | type: "Summarize" 590 | arg { 591 | name: "to_file" 592 | i: 1 593 | } 594 | } 595 | op { 596 | input: "conv1_b_grad" 597 | type: "Summarize" 598 | 
arg { 599 | name: "to_file" 600 | i: 1 601 | } 602 | } 603 | op { 604 | input: "conv2_w" 605 | type: "Summarize" 606 | arg { 607 | name: "to_file" 608 | i: 1 609 | } 610 | } 611 | op { 612 | input: "conv2_w_grad" 613 | type: "Summarize" 614 | arg { 615 | name: "to_file" 616 | i: 1 617 | } 618 | } 619 | op { 620 | input: "conv2_b" 621 | type: "Summarize" 622 | arg { 623 | name: "to_file" 624 | i: 1 625 | } 626 | } 627 | op { 628 | input: "conv2_b_grad" 629 | type: "Summarize" 630 | arg { 631 | name: "to_file" 632 | i: 1 633 | } 634 | } 635 | op { 636 | input: "fc3_w" 637 | type: "Summarize" 638 | arg { 639 | name: "to_file" 640 | i: 1 641 | } 642 | } 643 | op { 644 | input: "fc3_w_grad" 645 | type: "Summarize" 646 | arg { 647 | name: "to_file" 648 | i: 1 649 | } 650 | } 651 | op { 652 | input: "fc3_b" 653 | type: "Summarize" 654 | arg { 655 | name: "to_file" 656 | i: 1 657 | } 658 | } 659 | op { 660 | input: "fc3_b_grad" 661 | type: "Summarize" 662 | arg { 663 | name: "to_file" 664 | i: 1 665 | } 666 | } 667 | op { 668 | input: "pred_w" 669 | type: "Summarize" 670 | arg { 671 | name: "to_file" 672 | i: 1 673 | } 674 | } 675 | op { 676 | input: "pred_w_grad" 677 | type: "Summarize" 678 | arg { 679 | name: "to_file" 680 | i: 1 681 | } 682 | } 683 | op { 684 | input: "pred_b" 685 | type: "Summarize" 686 | arg { 687 | name: "to_file" 688 | i: 1 689 | } 690 | } 691 | op { 692 | input: "pred_b_grad" 693 | type: "Summarize" 694 | arg { 695 | name: "to_file" 696 | i: 1 697 | } 698 | } 699 | device_option { 700 | device_type: 1 701 | } 702 | external_input: "dbreader" 703 | external_input: "conv1_w" 704 | external_input: "conv1_b" 705 | external_input: "conv2_w" 706 | external_input: "conv2_b" 707 | external_input: "fc3_w" 708 | external_input: "fc3_b" 709 | external_input: "pred_w" 710 | external_input: "pred_b" 711 | external_input: "iter" 712 | external_input: "ONE" 713 | } 714 | network{ 715 | name: "mnist_save_net" 716 | op { 717 | type: "Save" 718 | input: "conv1_w" 
719 | input: "conv1_b" 720 | input: "conv2_w" 721 | input: "conv2_b" 722 | input: "fc3_w" 723 | input: "fc3_b" 724 | input: "pred_w" 725 | input: "pred_b" 726 | arg { 727 | name: "db" 728 | s: "LeNet_params" 729 | } 730 | arg { 731 | name: "db_type" 732 | s: "lmdb" 733 | } 734 | } 735 | external_input: "conv1_w" 736 | external_input: "conv1_b" 737 | external_input: "conv2_w" 738 | external_input: "conv2_b" 739 | external_input: "fc3_w" 740 | external_input: "fc3_b" 741 | external_input: "pred_w" 742 | external_input: "pred_b" 743 | device_option { 744 | device_type: 1 745 | } 746 | } 747 | execution_step { 748 | substep { 749 | network: "mnist_init_net" 750 | num_iter: 1 751 | } 752 | substep { 753 | network: "mnist_train_net" 754 | num_iter: 600 755 | } 756 | substep { 757 | network: "mnist_save_net" 758 | num_iter: 1 759 | } 760 | } 761 | --------------------------------------------------------------------------------