├── .appveyor.yml ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── .travis.yml ├── README.md ├── data ├── README.md └── dog.jpg ├── examples ├── README.md ├── cifar_parser.h ├── makefile ├── mnist_parser.h ├── mojo_vc2010.sln ├── test.cpp ├── test.vcxproj ├── train_cifar.cpp ├── train_cifar.vcxproj ├── train_mnist.cpp ├── train_mnist.vcxproj ├── vgg.cpp └── vgg.vcxproj ├── license.txt ├── models ├── README.md ├── cifar_deepcnet.mojo ├── mnist_deepcnet.mojo ├── mnist_quickstart.txt └── snapshots │ └── README.md └── mojo ├── activation.h ├── core_math.h ├── cost.h ├── layer.h ├── mojo.h ├── network.h ├── solver.h └── util.h /.appveyor.yml: -------------------------------------------------------------------------------- 1 | version: 1.0.{build} 2 | image: Visual Studio 2015 3 | configuration: Release 2015 4 | platform: x64 5 | build: 6 | project: examples\mojo_vc2010.sln 7 | verbosity: detailed 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | 5 | --- 6 | 7 | **Describe the bug** 8 | A clear and concise description of what the bug is. 9 | 10 | **To Reproduce** 11 | Steps to reproduce the behavior: 12 | 13 | **Expected behavior** 14 | A clear and concise description of what you expected to happen. 15 | 16 | **Additional context** 17 | Add any other context about the problem here. 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | 5 | --- 6 | 7 | **What is the use case?** 8 | 9 | **Is your feature request related to a problem? 
Please describe.** 10 | 11 | **Is there a publication reference?** 12 | 13 | **Additional context** 14 | Add any other context or screenshots about the feature request here. 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | models/snapshots/ 2 | data/ 3 | 4 | ################# 5 | ## Eclipse 6 | ################# 7 | 8 | *.pydevproject 9 | .project 10 | .metadata 11 | bin/ 12 | tmp/ 13 | *.tmp 14 | *.bak 15 | *.swp 16 | *~.nib 17 | local.properties 18 | .classpath 19 | .settings/ 20 | .loadpath 21 | 22 | # External tool builders 23 | .externalToolBuilders/ 24 | 25 | # Locally stored "Eclipse launch configurations" 26 | *.launch 27 | 28 | # CDT-specific 29 | .cproject 30 | 31 | # PDT-specific 32 | .buildpath 33 | 34 | 35 | ################# 36 | ## Visual Studio 37 | ################# 38 | 39 | ## Ignore Visual Studio temporary files, build results, and 40 | ## files generated by popular Visual Studio add-ons. 
41 | 42 | # User-specific files 43 | *.suo 44 | *.user 45 | *.sln.docstates 46 | 47 | # Build results 48 | 49 | [Dd]ebug/ 50 | [Rr]elease/ 51 | x64/ 52 | build/ 53 | [Bb]in/ 54 | [Oo]bj/ 55 | 56 | # MSTest test Results 57 | [Tt]est[Rr]esult*/ 58 | [Bb]uild[Ll]og.* 59 | 60 | *_i.c 61 | *_p.c 62 | *.ilk 63 | *.meta 64 | *.obj 65 | *.pch 66 | *.pdb 67 | *.pgc 68 | *.pgd 69 | *.rsp 70 | *.sbr 71 | *.tlb 72 | *.tli 73 | *.tlh 74 | *.tmp 75 | *.tmp_proj 76 | *.log 77 | *.vspscc 78 | *.vssscc 79 | .builds 80 | *.pidb 81 | *.log 82 | *.scc 83 | 84 | # Visual C++ cache files 85 | ipch/ 86 | *.aps 87 | *.ncb 88 | *.opensdf 89 | *.sdf 90 | *.cachefile 91 | 92 | # Visual Studio profiler 93 | *.psess 94 | *.vsp 95 | *.vspx 96 | 97 | # Guidance Automation Toolkit 98 | *.gpState 99 | 100 | # ReSharper is a .NET coding add-in 101 | _ReSharper*/ 102 | *.[Rr]e[Ss]harper 103 | 104 | # TeamCity is a build add-in 105 | _TeamCity* 106 | 107 | # DotCover is a Code Coverage Tool 108 | *.dotCover 109 | 110 | # NCrunch 111 | *.ncrunch* 112 | .*crunch*.local.xml 113 | 114 | # Installshield output folder 115 | [Ee]xpress/ 116 | 117 | # DocProject is a documentation generator add-in 118 | DocProject/buildhelp/ 119 | DocProject/Help/*.HxT 120 | DocProject/Help/*.HxC 121 | DocProject/Help/*.hhc 122 | DocProject/Help/*.hhk 123 | DocProject/Help/*.hhp 124 | DocProject/Help/Html2 125 | DocProject/Help/html 126 | 127 | # Click-Once directory 128 | publish/ 129 | 130 | # Publish Web Output 131 | *.Publish.xml 132 | *.pubxml 133 | *.publishproj 134 | 135 | # NuGet Packages Directory 136 | ## TODO: If you have NuGet Package Restore enabled, uncomment the next line 137 | #packages/ 138 | 139 | # Windows Azure Build Output 140 | csx 141 | *.build.csdef 142 | 143 | # Windows Store app package directory 144 | AppPackages/ 145 | 146 | # Others 147 | sql/ 148 | *.Cache 149 | ClientBin/ 150 | [Ss]tyle[Cc]op.* 151 | ~$* 152 | *~ 153 | *.dbmdl 154 | *.[Pp]ublish.xml 155 | *.pfx 156 | *.publishsettings 157 | 158 
| # RIA/Silverlight projects 159 | Generated_Code/ 160 | 161 | # Backup & report files from converting an old project file to a newer 162 | # Visual Studio version. Backup files are not needed, because we have git ;-) 163 | _UpgradeReport_Files/ 164 | Backup*/ 165 | UpgradeLog*.XML 166 | UpgradeLog*.htm 167 | 168 | # SQL Server files 169 | App_Data/*.mdf 170 | App_Data/*.ldf 171 | 172 | ############# 173 | ## Windows detritus 174 | ############# 175 | 176 | # Windows image file caches 177 | Thumbs.db 178 | ehthumbs.db 179 | 180 | # Folder config file 181 | Desktop.ini 182 | 183 | # Recycle Bin used on file shares 184 | $RECYCLE.BIN/ 185 | 186 | # Mac crap 187 | .DS_Store 188 | 189 | 190 | ############# 191 | ## Python 192 | ############# 193 | 194 | *.py[cod] 195 | 196 | # Packages 197 | *.egg 198 | *.egg-info 199 | dist/ 200 | build/ 201 | eggs/ 202 | parts/ 203 | var/ 204 | sdist/ 205 | develop-eggs/ 206 | .installed.cfg 207 | 208 | # Installer logs 209 | pip-log.txt 210 | 211 | # Unit test / coverage reports 212 | .coverage 213 | .tox 214 | 215 | #Translations 216 | *.mo 217 | 218 | #Mr Developer 219 | .mr.developer.cfg 220 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | language: cpp 3 | addons: 4 | apt: 5 | sources: 6 | - ubuntu-toolchain-r-test 7 | packages: 8 | - gcc-7 9 | - g++-7 10 | before_install: cd examples 11 | script: 12 | - export CC=gcc-7 13 | - export CXX=g++-7 14 | - make -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | |MSVC Win64|GCC Linux64| 2 | |:---|:---| 3 | |[![Build status](https://ci.appveyor.com/api/projects/status/github/gnawice/mojo-cnn?svg=true)](https://ci.appveyor.com/api/projects/status/github/gnawice/mojo-cnn?svg=true)|[![Build 
Status](https://travis-ci.org/gnawice/mojo-cnn.svg?branch=master)](https://travis-ci.org/gnawice/mojo-cnn)| 4 | 5 | 6 | # mojo cnn 7 | #### the *fast* and *easy* header-only c++ convolutional neural network package 8 | 9 | mojo is an efficient C++ CNN / DNN implementation that was built with the goal of balancing usability, functionality, and speed. It is ideal for use in real-world applications. 10 | 11 | See the [mojo cnn wiki](https://github.com/gnawice/mojo-cnn/wiki) for updates on what's new. 12 | 13 | Consisting of only a handful of header files, mojo is written in portable C++ with old-fashioned C tricks for optimization. If built with OpenMP and SSE3, its speed is competitive with other CPU-based CNN frameworks. Being a minimal CPU solution, it is not designed to scale over a cluster to train very deep models (for that, go with GPUs and Caffe, TensorFlow, CNTK, Torch, etc…) 14 | 15 | The mojo cnn API provides a 'smart training' option which abstracts the management of the training process but still provides the flexibility to handle the threading and input data as you'd like (enabling real-time data augmentation). Just make a loop and pass in training samples until mojo cnn says stop. You are therefore not limited by the RAM required to hold your data. On the standard MNIST handwritten digit database, mojo's 'smart training' gives 99% accuracy in less than a minute and, using only random-shift data augmentation, 99.71% accuracy (0.29% error) in about an hour. After a couple of hours, 99.75% accuracy (0.25% error) is achieved with a DeepCNet-type network. 16 | 17 | 18 | | | mojo specs | 19 | | ---------------------------- |--- | 20 | | Layers | Input, Fully Connected, Convolution, Grouped Convolution, Depth Wise Convolution, Max Pool, Semi-Stochastic Pool, Dropout, Max Feature Map, Resize, Shuffle, DeepCNet, Concatenation.
[Read more on the wiki](https://github.com/gnawice/mojo-cnn/wiki/Layers) | 21 | | Activations | Identity, Hyperbolic Tangent (tanh), Exponential Linear Unit (ELU), Rectified Linear Unit (ReLU), Leaky Rectified Linear Unit (LReLU), Very Leaky Rectified Linear Unit (VLReLU), Sigmoid, Softmax | 22 | | Solvers | Stochastic Gradient Descent, RMSProp, AdaGrad, Adam | 23 | | Loss Functions | Mean Squared Error, Cross Entropy | 24 | | Padding | Zero, Edge, Median Border Value | 25 | | Training speed (1st epoch time, MNIST 2 layer) | about 10 sec with smart training on CPU | 26 | | Required external dependencies | none | 27 | | Native Windows Support | yes | 28 | | Multi-core support | yes (OpenMP) | 29 | | g++ 5.3.0/MSVC 2010/2013/2015 | yes/yes/yes/yes | 30 | | Branching | yes | 31 | | Multiple Inputs | yes | 32 | | Real-time Data Augmentation | yes, random shift, rotate/scale available if linking OpenCV | 33 | | Automatic training | yes | 34 | | HTML Training Log and Graphing | yes | 35 | | GPU Support | no | 36 | | Model Zoo | DeepCNet MNIST, DeepCNet CIFAR-10, VGG | 37 | 38 | API Example: 39 | Load model and perform prediction: 40 | ``` 41 | #include 42 | 43 | mojo::network cnn; 44 | cnn.read("../models/mnist_deepcnet.mojo"); 45 | const int predicted_class=cnn.predict_class(float_image.data()); 46 | 47 | ``` 48 | 49 | API Example: Construction of a new CNN for MNIST, and train records with OpenMP threading: 50 | ``` 51 | #define MOJO_OMP 52 | #include 53 | 54 | mojo::network cnn("adam"); 55 | cnn.set_smart_train(true); 56 | cnn.enable_omp(); 57 | cnn.set_mini_batch_size(24); 58 | 59 | // add layer definitions. format : "layer_name", "layer_type params" 60 | cnn.push_back("I1","input 28 28 1"); // MNIST is 28x28x1 61 | cnn.push_back("C1","convolution 5 20 1 elu"); // 5x5 kernel, 20 maps, stride 1. out size is 28-5+1=24 62 | cnn.push_back("P1","semi_stochastic_pool 4 4"); // pool 4x4 blocks, stride 4.
out size is 6 63 | cnn.push_back("C2","convolution 5 200 1 elu"); // 5x5 kernel, 200 maps. out size is 6-5+1=2 64 | cnn.push_back("P2","semi_stochastic_pool 2 2"); // pool 2x2 blocks. out size is 2/2=1 65 | cnn.push_back("FC1","fully_connected 100 identity");// fully connected 100 nodes 66 | cnn.push_back("FC2","softmax 10"); 67 | 68 | cnn.connect_all(); // connect layers automatically (no branches) 69 | 70 | while(1) 71 | { 72 | // train with OpenMP threading 73 | cnn.start_epoch("cross_entropy"); 74 | 75 | #pragma omp parallel 76 | #pragma omp for schedule(dynamic) 77 | for(int k=0; k // cout 34 | #include 35 | #include 36 | #include //setw 37 | #include 38 | 39 | #include 40 | //#include 41 | 42 | namespace cifar 43 | { 44 | 45 | std::string data_name() {return std::string("CIFAR-10");} 46 | // Appends the 10,000 records of one CIFAR-10 binary batch file to *images and *labels (pixels scaled to [scale_min, scale_max], colour planes stored in B,G,R order). Returns false if the file cannot be opened or ends before 10,000 complete records. 47 | bool parse_cifar_data(const std::string& cifar_file, 48 | std::vector> *images, 49 | std::vector *labels, 50 | float scale_min = -1.0, float scale_max = 1.0, 51 | int x_padding = 0, int y_padding = 0) 52 | { 53 | std::ifstream ifs(cifar_file.c_str(), std::ios::in | std::ios::binary); 54 | 55 | if (ifs.bad() || ifs.fail()) 56 | { 57 | //std::cout << "failed to open file:" + cifar_file; 58 | return false; 59 | } 60 | 61 | // format is 1byte class, 1024b (32x32) R, 1024b (32x32) G, 1024b (32x32) B 62 | // 10,000 items in each file 63 | 64 | for (size_t i = 0; i < 10000; i++) 65 | { 66 | // read label 67 | unsigned char label; 68 | if (!ifs.read((char*) &label, 1)) return false; // short/truncated file: stop instead of pushing an uninitialized label 69 | labels->push_back((int) label); 70 | 71 | // read image 72 | std::vector image_c(32*32*3); 73 | if (!ifs.read((char*) &image_c[0], 32*32*3)) { labels->pop_back(); return false; } // short read: drop the unmatched label so labels and images stay the same length 74 | int width = 32+2*x_padding; 75 | int height = 32+2*y_padding; 76 | std::vector image(height*width*3); 77 | 78 | // convert from RGB to BGR 79 | for (size_t c = 0; c < 3; c++) 80 | for (size_t y = 0; y < 32; y++) 81 | for (size_t x = 0; x < 32; x++) 82 | image[width * (y + y_padding) + x + x_padding + (3-c-1)*width*height] = 83 | (image_c[y * 32 + x+c*32*32] / 255.0f) *
(scale_max - scale_min) + scale_min; 84 | 85 | images->push_back(image); 86 | 87 | } 88 | return true; 89 | } 90 | // Loads test_batch.bin. data_path is concatenated as-is, so it must end with a path separator. 91 | bool parse_test_data(std::string &data_path, std::vector> &test_images, std::vector &test_labels, 92 | float min_val=-1.f, float max_val=1.f, int padx=0, int pady=0) 93 | { 94 | return parse_cifar_data(data_path+"test_batch.bin", &test_images, &test_labels, min_val, max_val, padx, pady); 95 | } 96 | // Loads data_batch_1.bin .. data_batch_5.bin in order (data_path must end with a path separator); stops at the first file that fails. 97 | bool parse_train_data(std::string &data_path, std::vector> &train_images, std::vector &train_labels, 98 | float min_val=-1.f, float max_val=1.f, int padx=0, int pady=0) 99 | { 100 | if(!parse_cifar_data(data_path+"data_batch_1.bin", &train_images, &train_labels, min_val, max_val, padx, pady)) return false; 101 | if(!parse_cifar_data(data_path+"data_batch_2.bin", &train_images, &train_labels, min_val, max_val, padx, pady)) return false; 102 | if(!parse_cifar_data(data_path+"data_batch_3.bin", &train_images, &train_labels, min_val, max_val, padx, pady)) return false; 103 | if(!parse_cifar_data(data_path+"data_batch_4.bin", &train_images, &train_labels, min_val, max_val, padx, pady)) return false; 104 | if(!parse_cifar_data(data_path+"data_batch_5.bin", &train_images, &train_labels, min_val, max_val, padx, pady)) return false; 105 | return true; 106 | } 107 | 108 | } // namespace 109 | 110 | -------------------------------------------------------------------------------- /examples/makefile: -------------------------------------------------------------------------------- 1 | CC=g++ 2 | CFLAGS_OMP= -I../mojo/ -std=c++11 -fopenmp -O3 -DMOJO_OMP -DMOJO_AVX -msse4 -mavx 3 | 4 | all: test train_mnist train_cifar 5 | 6 | test: test.cpp 7 | $(CC) $(CFLAGS_OMP) test.cpp -o test 8 | 9 | train_mnist: train_mnist.cpp 10 | $(CC) $(CFLAGS_OMP) train_mnist.cpp -o train_mnist 11 | 12 | train_cifar: train_cifar.cpp 13 | $(CC) $(CFLAGS_OMP) train_cifar.cpp -o train_cifar 14 | 15 | clean: 16 | -rm -f test 17 | -rm -f train_mnist 18 | -rm -f train_cifar 19 | -------------------------------------------------------------------------------- /examples/mnist_parser.h: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // mnist_parser.h: prepares MNIST data for testing/training 4 | // 5 | // This code was modified from tiny_cnn https://github.com/nyanp/tiny-cnn 6 | // It can parse MNIST data which you need to download and unzip locally on 7 | // your machine. 8 | // You can get it from: http://yann.lecun.com/exdb/mnist/index.html 9 | // 10 | // ==================================================================== mojo == 11 | 12 | /* 13 | Copyright (c) 2013, Taiga Nomi 14 | All rights reserved. 15 | 16 | Redistribution and use in source and binary forms, with or without 17 | modification, are permitted provided that the following conditions are met: 18 | * Redistributions of source code must retain the above copyright 19 | notice, this list of conditions and the following disclaimer. 20 | * Redistributions in binary form must reproduce the above copyright 21 | notice, this list of conditions and the following disclaimer in the 22 | documentation and/or other materials provided with the distribution. 23 | * Neither the name of the nor the 24 | names of its contributors may be used to endorse or promote products 25 | derived from this software without specific prior written permission. 26 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY 27 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 28 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 29 | DISCLAIMED.
-------------------------------------------------------------------------------- /examples/mnist_parser.h: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // mnist_parser.h: prepares MNIST data for testing/training 4 | // 5 | // This code was modified from tiny_cnn https://github.com/nyanp/tiny-cnn 6 | // It can parse MNIST data which you need to download and unzip locally on 7 | // your machine. 8 | // You can get it from: http://yann.lecun.com/exdb/mnist/index.html 9 | // 10 | // ==================================================================== mojo == 11 | 12 | /* 13 | Copyright (c) 2013, Taiga Nomi 14 | All rights reserved. 15 | 16 | Redistribution and use in source and binary forms, with or without 17 | modification, are permitted provided that the following conditions are met: 18 | * Redistributions of source code must retain the above copyright 19 | notice, this list of conditions and the following disclaimer. 20 | * Redistributions in binary form must reproduce the above copyright 21 | notice, this list of conditions and the following disclaimer in the 22 | documentation and/or other materials provided with the distribution. 23 | * Neither the name of the nor the 24 | names of its contributors may be used to endorse or promote products 25 | derived from this software without specific prior written permission. 26 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY 27 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 28 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 29 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY 30 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 31 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 32 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 33 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 34 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 35 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 36 | */ 37 | 38 | 39 | #pragma once 40 | 41 | 42 | #include // cout 43 | #include 44 | #include 45 | #include //setw 46 | #include 47 | #include 48 | 49 | 50 | namespace mnist 51 | { 52 | std::string data_name() {return std::string("MNIST");} 53 | 54 | // from tiny_cnn 55 | template 56 | T* reverse_endian(T* p) { 57 | std::reverse(reinterpret_cast(p), reinterpret_cast(p) + sizeof(T)); 58 | return p; 59 | } 60 | 61 | // from tiny_cnn (kinda). Appends every label of an IDX1 label file to *labels; returns false if the file cannot be opened, is not an IDX1 label file, or ends early. 62 | bool parse_mnist_labels(const std::string& label_file, std::vector *labels) { 63 | std::ifstream ifs(label_file.c_str(), std::ios::in | std::ios::binary); 64 | 65 | if (ifs.bad() || ifs.fail()) 66 | { 67 | return false; 68 | } 69 | int magic_number, num_items; 70 | 71 | ifs.read((char*) &magic_number, 4); 72 | ifs.read((char*) &num_items, 4); 73 | if (!ifs) return false; // header truncated 74 | reverse_endian(&magic_number); 75 | reverse_endian(&num_items); 76 | if (magic_number != 2049) return false; // 0x00000801 = IDX1 unsigned-byte label file 77 | for (int i = 0; i < num_items; i++) { // int counter matches num_items: no signed/unsigned compare, and a corrupt negative count reads nothing 78 | unsigned char label; 79 | if (!ifs.read((char*) &label, 1)) return false; // short file: never push an uninitialized label 80 | labels->push_back((int) label); 81 | } 82 | return true; 83 | } 84 | 85 | // from tiny_cnn 86 | struct mnist_header { 87 | int magic_number; 88 | int num_items; 89 | int num_rows; 90 | int num_cols; 91 | }; 92 | 93 | // from tiny_cnn (kinda). Appends every image of an IDX3 image file to *images, scaled to [scale_min, scale_max] and padded with scale_min; returns false if the file cannot be opened, has a bad header, or ends early. 94 | bool parse_mnist_images(const std::string& image_file, 95 | std::vector> *images, 96 | float scale_min = -1.0, float scale_max = 1.0, 97 | int x_padding = 0, int y_padding = 0) 98 | { 99 | std::ifstream
ifs(image_file.c_str(), std::ios::in | std::ios::binary); 100 | 101 | if (ifs.bad() || ifs.fail()) 102 | { 103 | return false; 104 | } 105 | mnist_header header; 106 | 107 | // read header 108 | ifs.read((char*) &header.magic_number, 4); 109 | ifs.read((char*) &header.num_items, 4); 110 | ifs.read((char*) &header.num_rows, 4); 111 | ifs.read((char*) &header.num_cols, 4); 112 | 113 | reverse_endian(&header.magic_number); 114 | reverse_endian(&header.num_items); 115 | reverse_endian(&header.num_rows); 116 | reverse_endian(&header.num_cols); 117 | if (!ifs || header.magic_number != 2051) return false; // truncated header, or not an IDX3 (0x00000803) image file 118 | if (header.num_rows <= 0 || header.num_cols <= 0) return false; // corrupt dimensions: a negative size would wrap to a huge vector length below 119 | const int width = header.num_cols + 2 * x_padding; 120 | const int height = header.num_rows + 2 * y_padding; 121 | 122 | // read each image 123 | for (int i = 0; i < header.num_items; i++) 124 | { 125 | std::vector image; 126 | std::vector image_vec(header.num_rows * header.num_cols); 127 | 128 | if (!ifs.read((char*) &image_vec[0], header.num_rows * header.num_cols)) return false; // short file: never push a partially read image 129 | image.resize(width * height, scale_min); 130 | 131 | for (int y = 0; y < header.num_rows; y++) 132 | { 133 | for (int x = 0; x < header.num_cols; x++) 134 | image[width * (y + y_padding) + x + x_padding] = 135 | (image_vec[y * header.num_cols + x] / 255.0f) * (scale_max - scale_min) + scale_min; 136 | } 137 | 138 | images->push_back(image); 139 | } 140 | return true; 141 | } 142 | 143 | // == load data (MNIST-28x28x1 size, no padding, pixel range -1 to 1) 144 | bool parse_test_data(std::string &data_path, std::vector> &test_images, std::vector &test_labels, 145 | float min_val=-1.f, float max_val=1.f, int padx=0, int pady=0) 146 | { 147 | if(!parse_mnist_images(data_path+"/t10k-images.idx3-ubyte", &test_images, min_val, max_val, padx, pady)) 148 | if (!parse_mnist_images(data_path + "/t10k-images-idx3-ubyte", &test_images, min_val, max_val, padx, pady)) 149 | return false; 150 | if(!parse_mnist_labels(data_path+"/t10k-labels.idx1-ubyte", &test_labels)) 151 | if (!parse_mnist_labels(data_path + "/t10k-labels-idx1-ubyte", &test_labels))
return false; 152 | return true; 153 | } 154 | bool parse_train_data(std::string &data_path, std::vector> &train_images, std::vector &train_labels, 155 | float min_val=-1.f, float max_val=1.f, int padx=0, int pady=0) 156 | { 157 | if(!parse_mnist_images(data_path+"/train-images.idx3-ubyte", &train_images, min_val, max_val, padx, pady)) 158 | if (!parse_mnist_images(data_path + "/train-images-idx3-ubyte", &train_images, min_val, max_val, padx, pady)) 159 | return false; 160 | if(!parse_mnist_labels(data_path+"/train-labels.idx1-ubyte", &train_labels)) 161 | if (!parse_mnist_labels(data_path + "/train-labels-idx1-ubyte", &train_labels)) return false; 162 | return true; 163 | } 164 | } 165 | 166 | 167 | -------------------------------------------------------------------------------- /examples/mojo_vc2010.sln: -------------------------------------------------------------------------------- 1 | ﻿ 2 | Microsoft Visual Studio Solution File, Format Version 11.00 3 | # Visual Studio 2010 4 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "test", "test.vcxproj", "{855AB941-7417-414E-8FBA-8581CAFB03EA}" 5 | EndProject 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "vgg", "vgg.vcxproj", "{855AB941-7417-414E-8FBA-8581CAFB03EB}" 7 | EndProject 8 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "train_cifar", "train_cifar.vcxproj", "{DA6A4565-A674-4210-B93B-7870E038FE69}" 9 | EndProject 10 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "train_mnist", "train_mnist.vcxproj", "{DA6A4565-A674-4210-B93B-7870E038FE65}" 11 | EndProject 12 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "mojo", "mojo", "{2ADCCC4F-A972-4662-A5EE-2FFF8360AEA0}" 13 | ProjectSection(SolutionItems) = preProject 14 | ..\mojo\activation.h = ..\mojo\activation.h 15 | ..\mojo\core_math.h = ..\mojo\core_math.h 16 | ..\mojo\cost.h = ..\mojo\cost.h 17 | ..\mojo\layer.h = ..\mojo\layer.h 18 | ..\mojo\mojo.h = ..\mojo\mojo.h 19 | ..\mojo\network.h = ..\mojo\network.h 20 |
..\mojo\solver.h = ..\mojo\solver.h 21 | ..\mojo\util.h = ..\mojo\util.h 22 | EndProjectSection 23 | EndProject 24 | Global 25 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 26 | Release 2010|x64 = Release 2010|x64 27 | Release 2015|x64 = Release 2015|x64 28 | EndGlobalSection 29 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 30 | {855AB941-7417-414E-8FBA-8581CAFB03EA}.Release 2010|x64.ActiveCfg = Release 2010|x64 31 | {855AB941-7417-414E-8FBA-8581CAFB03EA}.Release 2010|x64.Build.0 = Release 2010|x64 32 | {855AB941-7417-414E-8FBA-8581CAFB03EA}.Release 2015|x64.ActiveCfg = Release 2015|x64 33 | {855AB941-7417-414E-8FBA-8581CAFB03EA}.Release 2015|x64.Build.0 = Release 2015|x64 34 | {855AB941-7417-414E-8FBA-8581CAFB03EB}.Release 2010|x64.ActiveCfg = Release 2010|x64 35 | {855AB941-7417-414E-8FBA-8581CAFB03EB}.Release 2015|x64.ActiveCfg = Release 2015|x64 36 | {DA6A4565-A674-4210-B93B-7870E038FE69}.Release 2010|x64.ActiveCfg = Release 2010|x64 37 | {DA6A4565-A674-4210-B93B-7870E038FE69}.Release 2010|x64.Build.0 = Release 2010|x64 38 | {DA6A4565-A674-4210-B93B-7870E038FE69}.Release 2015|x64.ActiveCfg = Release 2015|x64 39 | {DA6A4565-A674-4210-B93B-7870E038FE69}.Release 2015|x64.Build.0 = Release 2015|x64 40 | {DA6A4565-A674-4210-B93B-7870E038FE65}.Release 2010|x64.ActiveCfg = Release 2010|x64 41 | {DA6A4565-A674-4210-B93B-7870E038FE65}.Release 2010|x64.Build.0 = Release 2010|x64 42 | {DA6A4565-A674-4210-B93B-7870E038FE65}.Release 2015|x64.ActiveCfg = Release 2015|x64 43 | {DA6A4565-A674-4210-B93B-7870E038FE65}.Release 2015|x64.Build.0 = Release 2015|x64 44 | EndGlobalSection 45 | GlobalSection(SolutionProperties) = preSolution 46 | HideSolutionNode = FALSE 47 | EndGlobalSection 48 | EndGlobal 49 | -------------------------------------------------------------------------------- /examples/test.cpp: -------------------------------------------------------------------------------- 1 | // == mojo 
==================================================================== 2 | // 3 | // Copyright (c) gnawice@gnawice.com. All rights reserved. 4 | // See LICENSE in root folder 5 | // 6 | // Permission is hereby granted, free of charge, to any person obtaining a 7 | // copy of this software and associated documentation files(the "Software"), 8 | // to deal in the Software without restriction, including without 9 | // limitation the rights to use, copy, modify, merge, publish, distribute, 10 | // sublicense, and/or sell copies of the Software, and to permit persons to 11 | // whom the Software is furnished to do so, subject to the following 12 | // conditions : 13 | // 14 | // The above copyright notice and this permission notice shall be included 15 | // in all copies or substantial portions of the Software. 16 | // 17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 21 | // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 22 | // OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 23 | // THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | // 25 | // ============================================================================ 26 | // test.cpp: Simple example using pre-trained model to test mojo cnn 27 | // 28 | // Instructions: 29 | // Add the "mojo" folder in your include path. 30 | // Download MNIST data and unzip locally on your machine: 31 | // (http://yann.lecun.com/exdb/mnist/index.html) 32 | // Download CIFAR-10 data and unzip locally on your machine: 33 | // (http://www.cs.toronto.edu/~kriz/cifar.html) 34 | // Set the data_path variable in the code to point to your data location. 
35 | // ==================================================================== mojo == 36 | 37 | #include // cout 38 | #include 39 | #include 40 | #include 41 | #include 42 | //#include 43 | 44 | #include 45 | // dataset switch: MNIST is active as written; delete one '/' from the "//*" line below to build against CIFAR-10 instead 46 | //* 47 | #include "mnist_parser.h" 48 | using namespace mnist; 49 | std::string data_path="../data/mnist/"; 50 | std::string model_file="../models/mnist_deepcnet.mojo"; 51 | 52 | /*/ 53 | #include "cifar_parser.h" 54 | using namespace cifar; 55 | std::string data_path="../data/cifar-10-batches-bin/"; 56 | std::string model_file="../models/cifar_deepcnet.mojo"; 57 | //*/ 58 | 59 | void test(mojo::network &cnn, const std::vector> &test_images, const std::vector &test_labels) 60 | { 61 | int out_size=cnn.out_size(); // we know this to be 10 for MNIST and CIFAR 62 | int correct_predictions=0; 63 | 64 | // use progress object for simple timing and status updating 65 | mojo::progress progress((int)test_images.size(), " testing : "); 66 | 67 | const int record_cnt= (int)test_images.size(); 68 | 69 | // when MOJO_OMP is defined, we use standard "omp parallel for" loop, 70 | // the number of threads determined by network.enable_external_threads() call 71 | #pragma omp parallel for reduction(+:correct_predictions) schedule(dynamic) // dynamic schedule just helps the progress class to work correctly 72 | for(int k=0; k> test_images; 92 | // array to hold image labels 93 | std::vector test_labels; 94 | // calls mnist::parse_test_data or cifar::parse_test_data depending on 'using' 95 | if(!parse_test_data(data_path, test_images, test_labels)) {std::cerr << "error: could not parse data.\n"; return 1;} 96 | 97 | // == setup the network 98 | mojo::network cnn; 99 | 100 | // here we need to prepare mojo cnn to store data from multiple threads 101 | // !! enable_external_threads must be set prior to loading or creating a model !!
102 | cnn.enable_external_threads(); 103 | 104 | // load model 105 | if(!cnn.read(model_file)) {std::cerr << "error: could not read model.\n"; return 1;} 106 | std::cout << "Mojo CNN Configuration:" << std::endl; 107 | std::cout << cnn.get_configuration() << std::endl << std::endl; 108 | 109 | // == run the test 110 | std::cout << "Testing " << data_name() << ":" << std::endl; 111 | // this function will loop through all images, call predict, and print out stats 112 | test(cnn, test_images, test_labels); 113 | 114 | std::cout << std::endl; 115 | return 0; 116 | } 117 | -------------------------------------------------------------------------------- /examples/test.vcxproj: -------------------------------------------------------------------------------- 1 | ﻿ 2 | 3 | 4 | 5 | Release 2010 6 | x64 7 | 8 | 9 | Release 2013 10 | x64 11 | 12 | 13 | Release 2015 14 | x64 15 | 16 | 17 | 18 | {855AB941-7417-414E-8FBA-8581CAFB03EA} 19 | Win32Proj 20 | test 21 | 8.1 22 | 23 | 24 | 25 | Application 26 | false 27 | true 28 | MultiByte 29 | v100 30 | 31 | 32 | Application 33 | false 34 | true 35 | MultiByte 36 | v120 37 | 38 | 39 | Application 40 | false 41 | true 42 | MultiByte 43 | v140 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | false 60 | ..\bin\ 61 | $(Platform)\$(Configuration)\$(ProjectName)\ 62 | 63 | 64 | false 65 | ..\bin\ 66 | $(Platform)\$(Configuration)\$(ProjectName)\ 67 | 68 | 69 | false 70 | ..\bin\ 71 | $(Platform)\$(Configuration)\$(ProjectName)\ 72 | 73 | 74 | 75 | Level3 76 | NotUsing 77 | MaxSpeed 78 | true 79 | true 80 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 81 | ../mojo 82 | false 83 | 84 | 85 | Console 86 | true 87 | true 88 | true 89 | $(OutDir)$(TargetName)$(TargetExt) 90 | 91 | 92 | 93 | 94 | Level3 95 | NotUsing 96 | MaxSpeed 97 | true 98 | true 99 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 100 | ../mojo 101 | false 102 | 103 | 104 | Console 105 | true 106 | true 107 | true 108 |
$(OutDir)$(TargetName)$(TargetExt) 109 | 110 | 111 | 112 | 113 | Level3 114 | NotUsing 115 | MaxSpeed 116 | true 117 | true 118 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 119 | ../mojo 120 | false 121 | 122 | 123 | Console 124 | true 125 | true 126 | true 127 | $(OutDir)$(TargetName)$(TargetExt) 128 | 129 | 130 | 131 | 132 | Level3 133 | NotUsing 134 | MaxSpeed 135 | true 136 | true 137 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 138 | ../mojo 139 | false 140 | 141 | 142 | Console 143 | true 144 | true 145 | true 146 | $(OutDir)$(TargetName)$(TargetExt) 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /examples/train_cifar.cpp: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // Copyright (c) gnawice@gnawice.com. All rights reserved. 4 | // See LICENSE in root folder 5 | // 6 | // Permission is hereby granted, free of charge, to any person obtaining a 7 | // copy of this software and associated documentation files(the "Software"), 8 | // to deal in the Software without restriction, including without 9 | // limitation the rights to use, copy, modify, merge, publish, distribute, 10 | // sublicense, and/or sell copies of the Software, and to permit persons to 11 | // whom the Software is furnished to do so, subject to the following 12 | // conditions : 13 | // 14 | // The above copyright notice and this permission notice shall be included 15 | // in all copies or substantial portions of the Software. 16 | // 17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
20 | // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 21 | // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 22 | // OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 23 | // THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | // 25 | // ============================================================================ 26 | // train_cifar.cpp: train cifar-10 classifier 27 | // 28 | // Instructions: 29 | // Add the "mojo" folder in your include path. 30 | // Download MNIST data and unzip locally on your machine: 31 | // (http://yann.lecun.com/exdb/mnist/index.html) 32 | // Set the data_path variable in the code to point to your data location. 33 | // ==================================================================== mojo == 34 | 35 | #include // cout 36 | #include 37 | #include 38 | #include 39 | #include 40 | //#include 41 | 42 | //#define MOJO_CV3 43 | #include 44 | #include 45 | #include "cifar_parser.h" 46 | 47 | const int mini_batch_size = 16; 48 | const float initial_learning_rate = 0.05f; 49 | std::string solver = "adam"; 50 | std::string data_path = "../data/cifar-10-batches-bin/"; 51 | using namespace cifar; 52 | 53 | 54 | float test(mojo::network &cnn, const std::vector> &test_images, const std::vector &test_labels) 55 | { 56 | // use progress object for simple timing and status updating 57 | mojo::progress progress((int)test_images.size(), " testing:\t\t"); 58 | 59 | int out_size = cnn.out_size(); // we know this to be 10 for MNIST 60 | int correct_predictions = 0; 61 | const int record_cnt = (int)test_images.size(); 62 | 63 | #pragma omp parallel for reduction(+:correct_predictions) schedule(dynamic) 64 | for (int k = 0; k> &train_images, std::vector> &test_images) 76 | { 77 | // calculate the mean for every pixel position 78 | mojo::matrix mean(32, 32, 3); 79 | mean.fill(0); 80 | for (int i = 0; i < train_images.size(); i++) mean += mojo::matrix(32, 32, 3, train_images[i].data()); 81 | 
mean *= (float)(1.f / train_images.size()); 82 | 83 | // remove mean from data 84 | for (int i = 0; i < train_images.size(); i++) 85 | { 86 | mojo::matrix img(32, 32, 3, train_images[i].data()); 87 | img -= mean; 88 | memcpy(train_images[i].data(), img.x, sizeof(float)*img.size()); 89 | } 90 | for (int i = 0; i < test_images.size(); i++) 91 | { 92 | mojo::matrix img(32, 32, 3, test_images[i].data()); 93 | img -= mean; 94 | memcpy(test_images[i].data(), img.x, sizeof(float)*img.size()); 95 | } 96 | } 97 | 98 | int main() 99 | { 100 | // == parse data 101 | // array to hold image data (note that mojo does not require use of std::vector) 102 | std::vector> test_images; 103 | std::vector test_labels; 104 | std::vector> train_images; 105 | std::vector train_labels; 106 | 107 | // calls MNIST::parse_test_data or CIFAR10::parse_test_data depending on 'using' 108 | if (!parse_test_data(data_path, test_images, test_labels)) { std::cerr << "error: could not parse data.\n"; return 1; } 109 | if (!parse_train_data(data_path, train_images, train_labels)) { std::cerr << "error: could not parse data.\n"; return 1; } 110 | 111 | //remove_cifar_mean(train_images, test_images); 112 | 113 | // == setup the network - when you train you must specify an optimizer ("sgd", "rmsprop", "adagrad", "adam") 114 | mojo::network cnn(solver.c_str()); 115 | // !! the threading must be enabled with thread count prior to loading or creating a model !! 
116 | cnn.enable_external_threads(); 117 | cnn.set_mini_batch_size(mini_batch_size); 118 | cnn.set_smart_training(true); // automate training 119 | cnn.set_learning_rate(initial_learning_rate); 120 | // augment data random shifts only +/-2 pix 121 | cnn.set_random_augmentation(2,2,0,0,mojo::edge); 122 | 123 | // configure network 124 | cnn.push_back("I1", "input 32 32 3"); // CIFAR is 32x32x3 125 | cnn.push_back("C1", "convolution 3 16 1 elu"); // 32-3+1=30 126 | cnn.push_back("P1", "semi_stochastic_pool 3 3"); // 10x10 out 127 | cnn.push_back("C2", "convolution 3 64 1 elu"); // 8x8 out 128 | cnn.push_back("P2", "semi_stochastic_pool 4 4"); // 2x2 out 129 | cnn.push_back("FC2", "softmax 10"); 130 | 131 | // connect all the layers. Call connect() manually for all layer connections if you need more exotic networks. 132 | cnn.connect_all(); 133 | std::cout << "== Network Configuration ====================================================" << std::endl; 134 | std::cout << cnn.get_configuration() << std::endl; 135 | 136 | 137 | // add headers for table of values we want to log out 138 | mojo::html_log log; 139 | log.set_table_header("epoch\ttest accuracy(%)\testimated accuracy(%)\tepoch time(s)\ttotal time(s)\tlearn rate\tmodel"); 140 | log.set_note(cnn.get_configuration()); 141 | 142 | // setup timer/progress for overall training 143 | mojo::progress overall_progress(-1, " overall:\t\t"); 144 | const int train_samples = (int)train_images.size(); 145 | while (1) 146 | { 147 | overall_progress.draw_header(data_name() + " Epoch " + std::to_string((long long)cnn.get_epoch() + 1), true); 148 | // setup timer / progress for this one epoch 149 | mojo::progress progress(train_samples, " training:\t\t"); 150 | 151 | cnn.start_epoch("cross_entropy"); 152 | 153 | #pragma omp parallel for schedule(dynamic) 154 | for (int k = 0; k 2 | 3 | 4 | 5 | Release 2010 6 | x64 7 | 8 | 9 | Release 2013 10 | x64 11 | 12 | 13 | Release 2015 14 | x64 15 | 16 | 17 | 18 | 
{DA6A4565-A674-4210-B93B-7870E038FE69} 19 | Win32Proj 20 | train_cv 21 | train_cifar 22 | 8.1 23 | 24 | 25 | 26 | Application 27 | false 28 | true 29 | MultiByte 30 | v140 31 | 32 | 33 | Application 34 | false 35 | true 36 | MultiByte 37 | v100 38 | 39 | 40 | Application 41 | false 42 | true 43 | MultiByte 44 | v120 45 | 46 | 47 | Application 48 | false 49 | true 50 | MultiByte 51 | v140 52 | 53 | 54 | Application 55 | false 56 | true 57 | MultiByte 58 | v140 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | false 78 | ..\bin\ 79 | $(Platform)\$(Configuration)\$(ProjectName)\ 80 | 81 | 82 | false 83 | ..\bin\ 84 | $(Platform)\$(Configuration)\$(ProjectName)\ 85 | 86 | 87 | false 88 | ..\bin\ 89 | $(Platform)\$(Configuration)\$(ProjectName)\ 90 | 91 | 92 | false 93 | ..\bin\ 94 | $(Platform)\$(Configuration)\$(ProjectName)\ 95 | 96 | 97 | 98 | Level3 99 | NotUsing 100 | true 101 | true 102 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 103 | ../mojo;C:/opencv/build/include 104 | true 105 | MaxSpeed 106 | 16Bytes 107 | AdvancedVectorExtensions 108 | true 109 | 110 | 111 | Console 112 | true 113 | true 114 | true 115 | $(OutDir)$(TargetName)$(TargetExt) 116 | UseLinkTimeCodeGeneration 117 | c:/opencv/build/lib 118 | 119 | 120 | 121 | 122 | Level3 123 | NotUsing 124 | true 125 | true 126 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 127 | ../mojo;C:/opencv/build/include 128 | true 129 | MaxSpeed 130 | 16Bytes 131 | AdvancedVectorExtensions 132 | true 133 | 134 | 135 | Console 136 | true 137 | true 138 | true 139 | $(OutDir)$(TargetName)$(TargetExt) 140 | UseLinkTimeCodeGeneration 141 | c:/opencv/build/lib 142 | 143 | 144 | 145 | 146 | Level3 147 | NotUsing 148 | true 149 | true 150 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 151 | ../mojo;C:/opencv/build/include 152 | true 153 | MaxSpeed 154 | 16Bytes 155 | AdvancedVectorExtensions 156 | true 157 | 158 | 159 | Console 160 | true 161 | true 162 | true 163 | 
$(OutDir)$(TargetName)$(TargetExt) 164 | UseLinkTimeCodeGeneration 165 | c:/opencv/build/lib 166 | 167 | 168 | 169 | 170 | Level3 171 | NotUsing 172 | true 173 | true 174 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 175 | ../mojo;C:\opencv\build\include 176 | true 177 | Disabled 178 | 179 | 180 | Console 181 | true 182 | true 183 | true 184 | $(OutDir)$(TargetName)$(TargetExt) 185 | c:/opencv/build/lib 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | -------------------------------------------------------------------------------- /examples/train_mnist.cpp: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // Copyright (c) gnawice@gnawice.com. All rights reserved. 4 | // See LICENSE in root folder 5 | // 6 | // Permission is hereby granted, free of charge, to any person obtaining a 7 | // copy of this software and associated documentation files(the "Software"), 8 | // to deal in the Software without restriction, including without 9 | // limitation the rights to use, copy, modify, merge, publish, distribute, 10 | // sublicense, and/or sell copies of the Software, and to permit persons to 11 | // whom the Software is furnished to do so, subject to the following 12 | // conditions : 13 | // 14 | // The above copyright notice and this permission notice shall be included 15 | // in all copies or substantial portions of the Software. 16 | // 17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
20 | // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 21 | // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 22 | // OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 23 | // THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | // 25 | // ============================================================================ 26 | // train_mnist.cpp: train MNIST classifier 27 | // 28 | // Instructions: 29 | // Add the "mojo" folder in your include path. 30 | // Download MNIST data and unzip locally on your machine: 31 | // (http://yann.lecun.com/exdb/mnist/index.html) 32 | // Set the data_path variable in the code to point to your data location. 33 | // ==================================================================== mojo == 34 | 35 | #include // cout 36 | #include 37 | #include 38 | #include 39 | #include 40 | //#include 41 | 42 | //#define MOJO_CV3 43 | 44 | #include 45 | #include 46 | #include "mnist_parser.h" 47 | 48 | const int mini_batch_size = 24; 49 | const float initial_learning_rate = 0.04f; 50 | std::string solver = "adam"; 51 | std::string data_path="../data/mnist/"; 52 | using namespace mnist; 53 | 54 | // performs validation testing 55 | float test(mojo::network &cnn, const std::vector> &test_images, const std::vector &test_labels) 56 | { 57 | // use progress object for simple timing and status updating 58 | mojo::progress progress((int)test_images.size(), " testing:\t\t"); 59 | 60 | int out_size = cnn.out_size(); // we know this to be 10 for MNIST 61 | int correct_predictions = 0; 62 | const int record_cnt = (int)test_images.size(); 63 | 64 | #pragma omp parallel for reduction(+:correct_predictions) schedule(dynamic) 65 | for (int k = 0; k> test_images; 82 | std::vector test_labels; 83 | std::vector> train_images; 84 | std::vector train_labels; 85 | 86 | // calls MNIST::parse_test_data or CIFAR10::parse_test_data depending on 'using' 87 | if (!parse_test_data(data_path, test_images, 
test_labels)) { std::cerr << "error: could not parse data.\n"; return 1; } 88 | if (!parse_train_data(data_path, train_images, train_labels)) { std::cerr << "error: could not parse data.\n"; return 1; } 89 | 90 | // ==== setup the network - when you train you must specify an optimizer ("sgd", "rmsprop", "adagrad", "adam") 91 | mojo::network cnn(solver.c_str()); 92 | // !! the threading must be enabled with thread count prior to loading or creating a model !! 93 | cnn.enable_external_threads(); 94 | cnn.set_mini_batch_size(mini_batch_size); 95 | cnn.set_smart_training(true); // automate training 96 | cnn.set_learning_rate(initial_learning_rate); 97 | 98 | // Note, network descriptions can be read from a text file with similar format to the API 99 | cnn.read("../models/mnist_quickstart.txt"); 100 | 101 | /* 102 | // to construct the model through API calls... 103 | cnn.push_back("I1", "input 28 28 1"); // MNIST is 28x28x1 104 | cnn.push_back("C1", "convolution 5 8 1 elu"); // 5x5 kernel, 20 maps. stride 1. out size is 28-5+1=24 105 | cnn.push_back("P1", "semi_stochastic_pool 3 3"); // pool 3x3 blocks. stride 3. outsize is 8 106 | cnn.push_back("C2i", "convolution 1 16 1 elu"); // 1x1 'inceptoin' layer 107 | cnn.push_back("C2", "convolution 5 48 1 elu"); // 5x5 kernel, 200 maps. out size is 8-5+1=4 108 | cnn.push_back("P2", "semi_stochastic_pool 2 2"); // pool 2x2 blocks. stride 2. outsize is 2x2 109 | cnn.push_back("FC2", "softmax 10"); // 'flatten' of 2x2 input is inferred 110 | // connect all the layers. Call connect() manually for all layer connections if you need more exotic networks. 
111 | cnn.connect_all(); 112 | // */ 113 | 114 | std::cout << "== Network Configuration ====================================================" << std::endl; 115 | std::cout << cnn.get_configuration() << std::endl; 116 | 117 | // add headers for table of values we want to log out 118 | mojo::html_log log; 119 | log.set_table_header("epoch\ttest accuracy(%)\testimated accuracy(%)\tepoch time(s)\ttotal time(s)\tlearn rate\tmodel"); 120 | log.set_note(cnn.get_configuration()); 121 | 122 | // augment data random shifts only 123 | cnn.set_random_augmentation(1,1,0,0,mojo::edge); 124 | 125 | // setup timer/progress for overall training 126 | mojo::progress overall_progress(-1, " overall:\t\t"); 127 | const int train_samples = (int)train_images.size(); 128 | float old_accuracy = 0; 129 | while (1) 130 | { 131 | overall_progress.draw_header(data_name() + " Epoch " + std::to_string((long long)cnn.get_epoch() + 1), true); 132 | // setup timer / progress for this one epoch 133 | mojo::progress progress(train_samples, " training:\t\t"); 134 | // set loss function 135 | cnn.start_epoch("cross_entropy"); 136 | 137 | // manually loop through data. batches are handled internally. 
if data is to be shuffled, the must be performed externally 138 | #pragma omp parallel for schedule(dynamic) // schedule dynamic to help make progress bar work correctly 139 | for (int k = 0; k old_accuracy) 173 | { 174 | cnn.reset_smart_training(); 175 | old_accuracy = accuracy; 176 | } 177 | 178 | // save model 179 | std::string model_file = "../models/snapshots/tmp_" + std::to_string((long long)cnn.get_epoch()) + ".txt"; 180 | cnn.write(model_file,true); 181 | std::cout << " saved model:\t\t" << model_file << std::endl << std::endl; 182 | 183 | // write log file 184 | std::string log_out; 185 | log_out += float2str(dt) + "\t"; 186 | log_out += float2str(overall_progress.elapsed_seconds()) + "\t"; 187 | log_out += float2str(cnn.get_learning_rate()) + "\t"; 188 | log_out += model_file; 189 | log.add_table_row(cnn.estimated_accuracy, accuracy, log_out); 190 | // will write this every epoch 191 | log.write("../models/snapshots/mojo_mnist_log.htm"); 192 | 193 | // can't seem to improve 194 | if (cnn.elvis_left_the_building()) 195 | { 196 | std::cout << "Elvis just left the building. No further improvement in training found.\nStopping.." 
<< std::endl; 197 | break; 198 | } 199 | 200 | }; 201 | std::cout << std::endl; 202 | return 0; 203 | } -------------------------------------------------------------------------------- /examples/train_mnist.vcxproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | No Opt 6 | x64 7 | 8 | 9 | Release 2010 10 | x64 11 | 12 | 13 | Release 2013 14 | x64 15 | 16 | 17 | Release 2015 18 | x64 19 | 20 | 21 | 22 | {DA6A4565-A674-4210-B93B-7870E038FE65} 23 | Win32Proj 24 | train 25 | train_mnist 26 | 8.1 27 | 28 | 29 | 30 | Application 31 | false 32 | true 33 | MultiByte 34 | v100 35 | 36 | 37 | Application 38 | false 39 | true 40 | MultiByte 41 | v120 42 | 43 | 44 | Application 45 | false 46 | true 47 | MultiByte 48 | v140 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | false 65 | ..\bin\ 66 | $(Platform)\$(Configuration)\$(ProjectName)\ 67 | 68 | 69 | false 70 | ..\bin\ 71 | $(Platform)\$(Configuration)\$(ProjectName)\ 72 | 73 | 74 | false 75 | ..\bin\ 76 | $(Platform)\$(Configuration)\$(ProjectName)\ 77 | 78 | 79 | false 80 | 81 | 82 | 83 | Level3 84 | NotUsing 85 | true 86 | true 87 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 88 | ../mojo;C:/opencv/build/include 89 | true 90 | MaxSpeed 91 | AdvancedVectorExtensions 92 | 93 | 94 | Console 95 | true 96 | true 97 | true 98 | $(OutDir)$(TargetName)$(TargetExt) 99 | c:/opencv/build/lib 100 | 101 | 102 | 103 | 104 | Level3 105 | NotUsing 106 | true 107 | true 108 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 109 | ../mojo;C:/opencv/build/include 110 | true 111 | MaxSpeed 112 | AdvancedVectorExtensions 113 | 114 | 115 | Console 116 | true 117 | true 118 | true 119 | $(OutDir)$(TargetName)$(TargetExt) 120 | c:/opencv/build/lib 121 | 122 | 123 | 124 | 125 | Level3 126 | NotUsing 127 | true 128 | true 129 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 130 | ../mojo;C:/opencv/build/include 131 | true 132 | MaxSpeed 133 | 
AdvancedVectorExtensions 134 | 135 | 136 | Console 137 | true 138 | true 139 | true 140 | $(OutDir)$(TargetName)$(TargetExt) 141 | c:/opencv/build/lib 142 | 143 | 144 | 145 | 146 | ../mojo;C:/opencv/build/include 147 | 148 | 149 | c:/opencv/build/lib 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /examples/vgg.cpp: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // Copyright (c) gnawice@gnawice.com. All rights reserved. 4 | // See LICENSE in root folder 5 | // 6 | // Permission is hereby granted, free of charge, to any person obtaining a 7 | // copy of this software and associated documentation files(the "Software"), 8 | // to deal in the Software without restriction, including without 9 | // limitation the rights to use, copy, modify, merge, publish, distribute, 10 | // sublicense, and/or sell copies of the Software, and to permit persons to 11 | // whom the Software is furnished to do so, subject to the following 12 | // conditions : 13 | // 14 | // The above copyright notice and this permission notice shall be included 15 | // in all copies or substantial portions of the Software. 16 | // 17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 21 | // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 22 | // OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 23 | // THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
24 | // 25 | // ============================================================================ 26 | // vgg.cpp: Simple example using pre-trained VGG16 model. See: 27 | // 28 | // Very Deep Convolutional Networks for Large-Scale Image Recognition 29 | // K. Simonyan, A. Zisserman 30 | // arXiv:1409.1556 31 | // 32 | // Instructions: 33 | // Download the mojo VGG16 model from https://drive.google.com/file/d/0B5Dx9ePCIXQAZU51T0MyQXpvOXc/view?usp=sharing 34 | // ==================================================================== mojo == 35 | 36 | #include // cout 37 | #include 38 | #include 39 | #include 40 | #include 41 | //#include 42 | 43 | char *labels[]={"tench, Tinca tinca", "goldfish, Carassius auratus", "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", "tiger shark, Galeocerdo cuvieri", "hammerhead, hammerhead shark", "electric ray, crampfish, numbfish, torpedo", "stingray", "cock", "hen", "ostrich, Struthio camelus", "brambling, Fringilla montifringilla", "goldfinch, Carduelis carduelis", "house finch, linnet, Carpodacus mexicanus", "junco, snowbird", "indigo bunting, indigo finch, indigo bird, Passerina cyanea", "robin, American robin, Turdus migratorius", "bulbul", "jay", "magpie", "chickadee", "water ouzel, dipper", "kite", "bald eagle, American eagle, Haliaeetus leucocephalus", "vulture", "great grey owl, great gray owl, Strix nebulosa", "European fire salamander, Salamandra salamandra", "common newt, Triturus vulgaris", "eft", "spotted salamander, Ambystoma maculatum", "axolotl, mud puppy, Ambystoma mexicanum", "bullfrog, Rana catesbeiana", "tree frog, tree-frog", "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", "loggerhead, loggerhead turtle, Caretta caretta", "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", "mud turtle", "terrapin", "box turtle, box tortoise", "banded gecko", "common iguana, iguana, Iguana iguana", "American chameleon, anole, Anolis carolinensis", "whiptail, 
whiptail lizard", "agama", "frilled lizard, Chlamydosaurus kingi", "alligator lizard", "Gila monster, Heloderma suspectum", "green lizard, Lacerta viridis", "African chameleon, Chamaeleo chamaeleon", "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", "African crocodile, Nile crocodile, Crocodylus niloticus", "American alligator, Alligator mississipiensis", "triceratops", "thunder snake, worm snake, Carphophis amoenus", "ringneck snake, ring-necked snake, ring snake", "hognose snake, puff adder, sand viper", "green snake, grass snake", "king snake, kingsnake", "garter snake, grass snake", "water snake", "vine snake", "night snake, Hypsiglena torquata", "boa constrictor, Constrictor constrictor", "rock python, rock snake, Python sebae", "Indian cobra, Naja naja", "green mamba", "sea snake", "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", "diamondback, diamondback rattlesnake, Crotalus adamanteus", "sidewinder, horned rattlesnake, Crotalus cerastes", "trilobite", "harvestman, daddy longlegs, Phalangium opilio", "scorpion", "black and gold garden spider, Argiope aurantia", "barn spider, Araneus cavaticus", "garden spider, Aranea diademata", "black widow, Latrodectus mactans", "tarantula", "wolf spider, hunting spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse, partridge, Bonasa umbellus", "prairie chicken, prairie grouse, prairie fowl", "peacock", "quail", "partridge", "African grey, African gray, Psittacus erithacus", "macaw", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "drake", "red-breasted merganser, Mergus serrator", "goose", "black swan, Cygnus atratus", "tusker", "echidna, spiny anteater, anteater", "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", "wallaby, brush kangaroo", "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", 
"wombat", "jellyfish", "sea anemone, anemone", "brain coral", "flatworm, platyhelminth", "nematode, nematode worm, roundworm", "conch", "snail", "slug", "sea slug, nudibranch", "chiton, coat-of-mail shell, sea cradle, polyplacophore", "chambered nautilus, pearly nautilus, nautilus", "Dungeness crab, Cancer magister", "rock crab, Cancer irroratus", "fiddler crab", "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", "American lobster, Northern lobster, Maine lobster, Homarus americanus", "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", "crayfish, crawfish, crawdad, crawdaddy", "hermit crab", "isopod", "white stork, Ciconia ciconia", "black stork, Ciconia nigra", "spoonbill", "flamingo", "little blue heron, Egretta caerulea", "American egret, great white heron, Egretta albus", "bittern", "crane", "limpkin, Aramus pictus", "European gallinule, Porphyrio porphyrio", "American coot, marsh hen, mud hen, water hen, Fulica americana", "bustard", "ruddy turnstone, Arenaria interpres", "red-backed sandpiper, dunlin, Erolia alpina", "redshank, Tringa totanus", "dowitcher", "oystercatcher, oyster catcher", "pelican", "king penguin, Aptenodytes patagonica", "albatross, mollymawk", "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", "dugong, Dugong dugon", "sea lion", "Chihuahua", "Japanese spaniel", "Maltese dog, Maltese terrier, Maltese", "Pekinese, Pekingese, Peke", "Shih-Tzu", "Blenheim spaniel", "papillon", "toy terrier", "Rhodesian ridgeback", "Afghan hound, Afghan", "basset, basset hound", "beagle", "bloodhound, sleuthhound", "bluetick", "black-and-tan coonhound", "Walker hound, Walker foxhound", "English foxhound", "redbone", "borzoi, Russian wolfhound", "Irish wolfhound", "Italian greyhound", "whippet", "Ibizan hound, Ibizan Podenco", "Norwegian elkhound, elkhound", "otterhound, otter hound", "Saluki, gazelle 
hound", "Scottish deerhound, deerhound", "Weimaraner", "Staffordshire bullterrier, Staffordshire bull terrier", "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", "Bedlington terrier", "Border terrier", "Kerry blue terrier", "Irish terrier", "Norfolk terrier", "Norwich terrier", "Yorkshire terrier", "wire-haired fox terrier", "Lakeland terrier", "Sealyham terrier, Sealyham", "Airedale, Airedale terrier", "cairn, cairn terrier", "Australian terrier", "Dandie Dinmont, Dandie Dinmont terrier", "Boston bull, Boston terrier", "miniature schnauzer", "giant schnauzer", "standard schnauzer", "Scotch terrier, Scottish terrier, Scottie", "Tibetan terrier, chrysanthemum dog", "silky terrier, Sydney silky", "soft-coated wheaten terrier", "West Highland white terrier", "Lhasa, Lhasa apso", "flat-coated retriever", "curly-coated retriever", "golden retriever", "Labrador retriever", "Chesapeake Bay retriever", "German short-haired pointer", "vizsla, Hungarian pointer", "English setter", "Irish setter, red setter", "Gordon setter", "Brittany spaniel", "clumber, clumber spaniel", "English springer, English springer spaniel", "Welsh springer spaniel", "cocker spaniel, English cocker spaniel, cocker", "Sussex spaniel", "Irish water spaniel", "kuvasz", "schipperke", "groenendael", "malinois", "briard", "kelpie", "komondor", "Old English sheepdog, bobtail", "Shetland sheepdog, Shetland sheep dog, Shetland", "collie", "Border collie", "Bouvier des Flandres, Bouviers des Flandres", "Rottweiler", "German shepherd, German shepherd dog, German police dog, alsatian", "Doberman, Doberman pinscher", "miniature pinscher", "Greater Swiss Mountain dog", "Bernese mountain dog", "Appenzeller", "EntleBucher", "boxer", "bull mastiff", "Tibetan mastiff", "French bulldog", "Great Dane", "Saint Bernard, St Bernard", "Eskimo dog, husky", "malamute, malemute, Alaskan malamute", "Siberian husky", "dalmatian, coach dog, carriage dog", "affenpinscher, monkey 
pinscher, monkey dog", "basenji", "pug, pug-dog", "Leonberg", "Newfoundland, Newfoundland dog", "Great Pyrenees", "Samoyed, Samoyede", "Pomeranian", "chow, chow chow", "keeshond", "Brabancon griffon", "Pembroke, Pembroke Welsh corgi", "Cardigan, Cardigan Welsh corgi", "toy poodle", "miniature poodle", "standard poodle", "Mexican hairless", "timber wolf, grey wolf, gray wolf, Canis lupus", "white wolf, Arctic wolf, Canis lupus tundrarum", "red wolf, maned wolf, Canis rufus, Canis niger", "coyote, prairie wolf, brush wolf, Canis latrans", "dingo, warrigal, warragal, Canis dingo", "dhole, Cuon alpinus", "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", "hyena, hyaena", "red fox, Vulpes vulpes", "kit fox, Vulpes macrotis", "Arctic fox, white fox, Alopex lagopus", "grey fox, gray fox, Urocyon cinereoargenteus", "tabby, tabby cat", "tiger cat", "Persian cat", "Siamese cat, Siamese", "Egyptian cat", "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", "lynx, catamount", "leopard, Panthera pardus", "snow leopard, ounce, Panthera uncia", "jaguar, panther, Panthera onca, Felis onca", "lion, king of beasts, Panthera leo", "tiger, Panthera tigris", "cheetah, chetah, Acinonyx jubatus", "brown bear, bruin, Ursus arctos", "American black bear, black bear, Ursus americanus, Euarctos americanus", "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", "sloth bear, Melursus ursinus, Ursus ursinus", "mongoose", "meerkat, mierkat", "tiger beetle", "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", "ground beetle, carabid beetle", "long-horned beetle, longicorn, longicorn beetle", "leaf beetle, chrysomelid", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant, emmet, pismire", "grasshopper, hopper", "cricket", "walking stick, walkingstick, stick insect", "cockroach, roach", "mantis, mantid", "cicada, cicala", "leafhopper", "lacewing, lacewing fly", "dragonfly, darning needle, devil's darning needle, sewing needle, 
snake feeder, snake doctor, mosquito hawk, skeeter hawk", "damselfly", "admiral", "ringlet, ringlet butterfly", "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", "cabbage butterfly", "sulphur butterfly, sulfur butterfly", "lycaenid, lycaenid butterfly", "starfish, sea star", "sea urchin", "sea cucumber, holothurian", "wood rabbit, cottontail, cottontail rabbit", "hare", "Angora, Angora rabbit", "hamster", "porcupine, hedgehog", "fox squirrel, eastern fox squirrel, Sciurus niger", "marmot", "beaver", "guinea pig, Cavia cobaya", "sorrel", "zebra", "hog, pig, grunter, squealer, Sus scrofa", "wild boar, boar, Sus scrofa", "warthog", "hippopotamus, hippo, river horse, Hippopotamus amphibius", "ox", "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", "bison", "ram, tup", "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", "ibex, Capra ibex", "hartebeest", "impala, Aepyceros melampus", "gazelle", "Arabian camel, dromedary, Camelus dromedarius", "llama", "weasel", "mink", "polecat, fitch, foulmart, foumart, Mustela putorius", "black-footed ferret, ferret, Mustela nigripes", "otter", "skunk, polecat, wood pussy", "badger", "armadillo", "three-toed sloth, ai, Bradypus tridactylus", "orangutan, orang, orangutang, Pongo pygmaeus", "gorilla, Gorilla gorilla", "chimpanzee, chimp, Pan troglodytes", "gibbon, Hylobates lar", "siamang, Hylobates syndactylus, Symphalangus syndactylus", "guenon, guenon monkey", "patas, hussar monkey, Erythrocebus patas", "baboon", "macaque", "langur", "colobus, colobus monkey", "proboscis monkey, Nasalis larvatus", "marmoset", "capuchin, ringtail, Cebus capucinus", "howler monkey, howler", "titi, titi monkey", "spider monkey, Ateles geoffroyi", "squirrel monkey, Saimiri sciureus", "Madagascar cat, ring-tailed lemur, Lemur catta", "indri, indris, Indri indri, Indri brevicaudatus", "Indian elephant, Elephas maximus", "African elephant, Loxodonta africana", "lesser panda, red 
panda, panda, bear cat, cat bear, Ailurus fulgens", "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", "barracouta, snoek", "eel", "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "rock beauty, Holocanthus tricolor", "anemone fish", "sturgeon", "gar, garfish, garpike, billfish, Lepisosteus osseus", "lionfish", "puffer, pufferfish, blowfish, globefish", "abacus", "abaya", "academic gown, academic robe, judge's robe", "accordion, piano accordion, squeeze box", "acoustic guitar", "aircraft carrier, carrier, flattop, attack aircraft carrier", "airliner", "airship, dirigible", "altar", "ambulance", "amphibian, amphibious vehicle", "analog clock", "apiary, bee house", "apron", "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", "assault rifle, assault gun", "backpack, back pack, knapsack, packsack, rucksack, haversack", "bakery, bakeshop, bakehouse", "balance beam, beam", "balloon", "ballpoint, ballpoint pen, ballpen, Biro", "Band Aid", "banjo", "bannister, banister, balustrade, balusters, handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel, cask", "barrow, garden cart, lawn cart, wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "bathing cap, swimming cap", "bath towel", "bathtub, bathing tub, bath, tub", "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", "beacon, lighthouse, beacon light, pharos", "beaker", "bearskin, busby, shako", "beer bottle", "beer glass", "bell cote, bell cot", "bib", "bicycle-built-for-two, tandem bicycle, tandem", "bikini, two-piece", "binder, ring-binder", "binoculars, field glasses, opera glasses", "birdhouse", "boathouse", "bobsled, bobsleigh, bob", "bolo tie, bolo, bola tie, bola", "bonnet, poke bonnet", "bookcase", "bookshop, bookstore, bookstall", "bottlecap", "bow", "bow tie, bow-tie, bowtie", "brass, memorial tablet, plaque", "brassiere, bra, bandeau", "breakwater, 
groin, groyne, mole, bulwark, seawall, jetty", "breastplate, aegis, egis", "broom", "bucket, pail", "buckle", "bulletproof vest", "bullet train, bullet", "butcher shop, meat market", "cab, hack, taxi, taxicab", "caldron, cauldron", "candle, taper, wax light", "cannon", "canoe", "can opener, tin opener", "cardigan", "car mirror", "carousel, carrousel, merry-go-round, roundabout, whirligig", "carpenter's kit, tool kit", "carton", "car wheel", "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello, violoncello", "cellular telephone, cellular phone, cellphone, cell, mobile phone", "chain", "chainlink fence", "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", "chain saw, chainsaw", "chest", "chiffonier, commode", "chime, bell, gong", "china cabinet, china closet", "Christmas stocking", "church, church building", "cinema, movie theater, movie theatre, movie house, picture palace", "cleaver, meat cleaver, chopper", "cliff dwelling", "cloak", "clog, geta, patten, sabot", "cocktail shaker", "coffee mug", "coffeepot", "coil, spiral, volute, whorl, helix", "combination lock", "computer keyboard, keypad", "confectionery, confectionary, candy store", "container ship, containership, container vessel", "convertible", "corkscrew, bottle screw", "cornet, horn, trumpet, trump", "cowboy boot", "cowboy hat, ten-gallon hat", "cradle", "crane", "crash helmet", "crate", "crib, cot", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam, dike, dyke", "desk", "desktop computer", "dial telephone, dial phone", "diaper, nappy, napkin", "digital clock", "digital watch", "dining table, board", "dishrag, dishcloth", "dishwasher, dish washer, dishwashing machine", "disk brake, disc brake", "dock, dockage, docking facility", "dogsled, dog sled, dog sleigh", "dome", "doormat, welcome mat", "drilling platform, offshore rig", 
"drum, membranophone, tympan", "drumstick", "dumbbell", "Dutch oven", "electric fan, blower", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso maker", "face powder", "feather boa, boa", "file, file cabinet, filing cabinet", "fireboat", "fire engine, fire truck", "fire screen, fireguard", "flagpole, flagstaff", "flute, transverse flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster", "freight car", "French horn, horn", "frying pan, frypan, skillet", "fur coat", "garbage truck, dustcart", "gasmask, respirator, gas helmet", "gas pump, gasoline pump, petrol pump, island dispenser", "goblet", "go-kart", "golf ball", "golfcart, golf cart", "gondola", "gong, tam-tam", "gown", "grand piano, grand", "greenhouse, nursery, glasshouse", "grille, radiator grille", "grocery store, grocery, food market, market", "guillotine", "hair slide", "hair spray", "half track", "hammer", "hamper", "hand blower, blow dryer, blow drier, hair dryer, hair drier", "hand-held computer, hand-held microcomputer", "handkerchief, hankie, hanky, hankey", "hard disc, hard disk, fixed disk", "harmonica, mouth organ, harp, mouth harp", "harp", "harvester, reaper", "hatchet", "holster", "home theater, home theatre", "honeycomb", "hook, claw", "hoopskirt, crinoline", "horizontal bar, high bar", "horse cart, horse-cart", "hourglass", "iPod", "iron, smoothing iron", "jack-o'-lantern", "jean, blue jean, denim", "jeep, landrover", "jersey, T-shirt, tee shirt", "jigsaw puzzle", "jinrikisha, ricksha, rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat, laboratory coat", "ladle", "lampshade, lamp shade", "laptop, laptop computer", "lawn mower, mower", "lens cap, lens cover", "letter opener, paper knife, paperknife", "library", "lifeboat", "lighter, light, igniter, ignitor", "limousine, limo", "liner, ocean liner", "lipstick, lip rouge", "Loafer", "lotion", "loudspeaker, speaker, speaker unit, loudspeaker system, speaker 
system", "loupe, jeweler's loupe", "lumbermill, sawmill", "magnetic compass", "mailbag, postbag", "mailbox, letter box", "maillot", "maillot, tank suit", "manhole cover", "maraca", "marimba, xylophone", "mask", "matchstick", "maypole", "maze, labyrinth", "measuring cup", "medicine chest, medicine cabinet", "megalith, megalithic structure", "microphone, mike", "microwave, microwave oven", "military uniform", "milk can", "minibus", "miniskirt, mini", "minivan", "missile", "mitten", "mixing bowl", "mobile home, manufactured home", "Model T", "modem", "monastery", "monitor", "moped", "mortar", "mortarboard", "mosque", "mosquito net", "motor scooter, scooter", "mountain bike, all-terrain bike, off-roader", "mountain tent", "mouse, computer mouse", "mousetrap", "moving van", "muzzle", "nail", "neck brace", "necklace", "nipple", "notebook, notebook computer", "obelisk", "oboe, hautboy, hautbois", "ocarina, sweet potato", "odometer, hodometer, mileometer, milometer", "oil filter", "organ, pipe organ", "oscilloscope, scope, cathode-ray oscilloscope, CRO", "overskirt", "oxcart", "oxygen mask", "packet", "paddle, boat paddle", "paddlewheel, paddle wheel", "padlock", "paintbrush", "pajama, pyjama, pj's, jammies", "palace", "panpipe, pandean pipe, syrinx", "paper towel", "parachute, chute", "parallel bars, bars", "park bench", "parking meter", "passenger car, coach, carriage", "patio, terrace", "pay-phone, pay-station", "pedestal, plinth, footstall", "pencil box, pencil case", "pencil sharpener", "perfume, essence", "Petri dish", "photocopier", "pick, plectrum, plectron", "pickelhaube", "picket fence, paling", "pickup, pickup truck", "pier", "piggy bank, penny bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate, pirate ship", "pitcher, ewer", "plane, carpenter's plane, woodworking plane", "planetarium", "plastic bag", "plate rack", "plow, plough", "plunger, plumber's helper", "Polaroid camera, Polaroid Land camera", "pole", "police van, police wagon, paddy 
wagon, patrol wagon, wagon, black Maria", "poncho", "pool table, billiard table, snooker table", "pop bottle, soda bottle", "pot, flowerpot", "potter's wheel", "power drill", "prayer rug, prayer mat", "printer", "prison, prison house", "projectile, missile", "projector", "puck, hockey puck", "punching bag, punch bag, punching ball, punchball", "purse", "quill, quill pen", "quilt, comforter, comfort, puff", "racer, race car, racing car", "racket, racquet", "radiator", "radio, wireless", "radio telescope, radio reflector", "rain barrel", "recreational vehicle, RV, R.V.", "reel", "reflex camera", "refrigerator, icebox", "remote control, remote", "restaurant, eating house, eating place, eatery", "revolver, six-gun, six-shooter", "rifle", "rocking chair, rocker", "rotisserie", "rubber eraser, rubber, pencil eraser", "rugby ball", "rule, ruler", "running shoe", "safe", "safety pin", "saltshaker, salt shaker", "sandal", "sarong", "sax, saxophone", "scabbard", "scale, weighing machine", "school bus", "schooner", "scoreboard", "screen, CRT screen", "screw", "screwdriver", "seat belt, seatbelt", "sewing machine", "shield, buckler", "shoe shop, shoe-shop, shoe store", "shoji", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "ski mask", "sleeping bag", "slide rule, slipstick", "sliding door", "slot, one-armed bandit", "snorkel", "snowmobile", "snowplow, snowplough", "soap dispenser", "soccer ball", "sock", "solar dish, solar collector, solar furnace", "sombrero", "soup bowl", "space bar", "space heater", "space shuttle", "spatula", "speedboat", "spider web, spider's web", "spindle", "sports car, sport car", "spotlight, spot", "stage", "steam locomotive", "steel arch bridge", "steel drum", "stethoscope", "stole", "stone wall", "stopwatch, stop watch", "stove", "strainer", "streetcar, tram, tramcar, trolley, trolley car", "stretcher", "studio couch, day bed", "stupa, tope", "submarine, pigboat, sub, U-boat", "suit, suit of clothes", 
"sundial", "sunglass", "sunglasses, dark glasses, shades", "sunscreen, sunblock, sun blocker", "suspension bridge", "swab, swob, mop", "sweatshirt", "swimming trunks, bathing trunks", "swing", "switch, electric switch, electrical switch", "syringe", "table lamp", "tank, army tank, armored combat vehicle, armoured combat vehicle", "tape player", "teapot", "teddy, teddy bear", "television, television system", "tennis ball", "thatch, thatched roof", "theater curtain, theatre curtain", "thimble", "thresher, thrasher, threshing machine", "throne", "tile roof", "toaster", "tobacco shop, tobacconist shop, tobacconist", "toilet seat", "torch", "totem pole", "tow truck, tow car, wrecker", "toyshop", "tractor", "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", "tray", "trench coat", "tricycle, trike, velocipede", "trimaran", "tripod", "triumphal arch", "trolleybus, trolley coach, trackless trolley", "trombone", "tub, vat", "turnstile", "typewriter keyboard", "umbrella", "unicycle, monocycle", "upright, upright piano", "vacuum, vacuum cleaner", "vase", "vault", "velvet", "vending machine", "vestment", "viaduct", "violin, fiddle", "volleyball", "waffle iron", "wall clock", "wallet, billfold, notecase, pocketbook", "wardrobe, closet, press", "warplane, military plane", "washbasin, handbasin, washbowl, lavabo, wash-hand basin", "washer, automatic washer, washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "wig", "window screen", "window shade", "Windsor tie", "wine bottle", "wing", "wok", "wooden spoon", "wool, woolen, woollen", "worm fence, snake fence, snake-rail fence, Virginia fence", "wreck", "yawl", "yurt", "web site, website, internet site, site", "comic book", "crossword puzzle, crossword", "street sign", "traffic light, traffic signal, stoplight", "book jacket, dust cover, dust jacket, dust wrapper", "menu", "plate", "guacamole", "consomme", "hot pot, hotpot", "trifle", "ice cream, icecream", "ice lolly, 
lolly, lollipop, popsicle", "French loaf", "bagel, beigel", "pretzel", "cheeseburger", "hotdog, hot dog, red hot", "mashed potato", "head cabbage", "broccoli", "cauliflower", "zucchini, courgette", "spaghetti squash", "acorn squash", "butternut squash", "cucumber, cuke", "artichoke, globe artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith", "strawberry", "orange", "lemon", "fig", "pineapple, ananas", "banana", "jackfruit, jak, jack", "custard apple", "pomegranate", "hay", "carbonara", "chocolate sauce, chocolate syrup", "dough", "meat loaf, meatloaf", "pizza, pizza pie", "potpie", "burrito", "red wine", "espresso", "cup", "eggnog", "alp", "bubble", "cliff, drop, drop-off", "coral reef", "geyser", "lakeside, lakeshore", "promontory, headland, head, foreland", "sandbar, sand bar", "seashore, coast, seacoast, sea-coast", "valley, vale", "volcano", "ballplayer, baseball player", "groom, bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", "corn", "acorn", "hip, rose hip, rosehip", "buckeye, horse chestnut, conker", "coral fungus", "agaric", "gyromitra", "stinkhorn, carrion fungus", "earthstar", "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", "bolete", "ear, spike, capitulum", "toilet tissue, toilet paper, bathroom tissue"}; 44 | 45 | #define MOJO_CV3 // this example requires opencv to read images 46 | #define MOJO_PROFILE_LAYERS // print out layer timing 47 | #include 48 | 49 | std::string data_path="../data/dog.jpg"; 50 | std::string model_file="../models/vgg16.mojo"; 51 | 52 | int main(int argc, char **argv) 53 | { 54 | if(argc>1) data_path = argv[1]; 55 | 56 | // read image 57 | cv::Mat im = cv::imread(data_path); 58 | if(im.empty() || im.cols<1) { std::cout << "Failed to read a valid image. 
(" << data_path <<")"< 2 | 3 | 4 | 5 | Release 2010 6 | x64 7 | 8 | 9 | Release 2013 10 | x64 11 | 12 | 13 | Release 2015 14 | x64 15 | 16 | 17 | 18 | {855AB941-7417-414E-8FBA-8581CAFB03EB} 19 | Win32Proj 20 | vgg 21 | 8.1 22 | 23 | 24 | 25 | Application 26 | false 27 | true 28 | MultiByte 29 | v100 30 | 31 | 32 | Application 33 | false 34 | true 35 | MultiByte 36 | v120 37 | 38 | 39 | Application 40 | false 41 | true 42 | MultiByte 43 | v140 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | false 60 | ..\bin\ 61 | $(Platform)\$(Configuration)\$(ProjectName)\ 62 | 63 | 64 | false 65 | ..\bin\ 66 | $(Platform)\$(Configuration)\$(ProjectName)\ 67 | 68 | 69 | false 70 | ..\bin\ 71 | $(Platform)\$(Configuration)\$(ProjectName)\ 72 | 73 | 74 | 75 | Level3 76 | NotUsing 77 | MaxSpeed 78 | true 79 | true 80 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 81 | ../mojo 82 | false 83 | 84 | 85 | Console 86 | true 87 | true 88 | true 89 | $(OutDir)$(TargetName)$(TargetExt) 90 | 91 | 92 | 93 | 94 | Level3 95 | NotUsing 96 | MaxSpeed 97 | true 98 | true 99 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 100 | ../mojo;C:/opencv/build/include 101 | false 102 | 103 | 104 | Console 105 | true 106 | true 107 | true 108 | $(OutDir)$(TargetName)$(TargetExt) 109 | c:/opencv/build/lib 110 | 111 | 112 | 113 | 114 | Level3 115 | NotUsing 116 | MaxSpeed 117 | true 118 | true 119 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 120 | ../mojo;C:/opencv/build/include 121 | false 122 | 123 | 124 | Console 125 | true 126 | true 127 | true 128 | $(OutDir)$(TargetName)$(TargetExt) 129 | c:/opencv/build/lib 130 | 131 | 132 | 133 | 134 | Level3 135 | NotUsing 136 | MaxSpeed 137 | true 138 | true 139 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 140 | ../mojo;C:/opencv/build/include 141 | false 142 | 143 | 144 | Console 145 | true 146 | true 147 | true 148 | $(OutDir)$(TargetName)$(TargetExt) 149 | c:/opencv/build/lib 150 | 151 | 152 | 153 | 154 | 155 | 156 
| 157 | 158 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 gnawice 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | Pre-trained models for mojo cnn: 2 | 3 | + **vgg16.mojo:** VGG 16 layer model converted from mxnet model. This model is a pretrained model on ILSVRC2012 dataset. It is able to achieve 71.0% Top-1 Accuracy and 89.8% Top-5 accuracy on ILSVRC2012-Validation Set. Details about the network architecture can be found in the following arXiv paper: 4 | ``` 5 | Very Deep Convolutional Networks for Large-Scale Image Recognition 6 | K. Simonyan, A. 
Zisserman 7 | arXiv:1409.1556 8 | ``` 9 | Please cite the paper if you use the model. 10 | [**Download vgg16.mojo**](https://drive.google.com/file/d/0B5Dx9ePCIXQAZU51T0MyQXpvOXc/view?usp=sharing) 11 | 12 | + **mnist_deepcnet.mojo:** MNIST model with 99.75% accuracy (0.25% error). Random +/-2 pixel translations on training data. No elastic distortions. Four convolution layers. Each deepcnet layer is a 2x2 convolution followed by a 2x2 max pool. It took a little more than 2 hours to reach this accuracy in the original mojo release. (The softmax output had a bug during training; to capture this bug, a backward-compatible layer 'brokemax' was added.) 13 | ``` 14 | input 28x28x1 identity 15 | convolution 3x3 40 elu 16 | max_pool 2x2 17 | deepcnet 80 elu 18 | deepcnet 160 elu 19 | deepcnet 320 elu 20 | fully_connected 10 brokemax 21 | ``` 22 | 23 | + **cifar_deepcnet.mojo:** CIFAR-10 model with 87.55% accuracy (12.45% error). No mean subtraction. Random mirror and +/-2 pixel translations on training data. No rotation, scale, or elastic augmentation. Five main convolution layers. Each deepcnet layer is a 2x2 convolution followed by a 2x2 max pool. It took a little more than 8.5 hours to reach this accuracy in the original mojo release. 
24 | ``` 25 | input 32x32x3 identity 26 | convolution 3x3 50 elu 27 | max_pool 2x2 28 | deepcnet 100 elu 29 | deepcnet 150 elu 30 | resize 7 7 31 | deepcnet 200 elu 32 | deepcnet 250 elu 33 | fully_connected 10 tanh 34 | ``` 35 | -------------------------------------------------------------------------------- /models/cifar_deepcnet.mojo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnawice/mojo-cnn/5bbe8c5c012a2dbd9811a355cb1c1b2bc89a782f/models/cifar_deepcnet.mojo -------------------------------------------------------------------------------- /models/mnist_deepcnet.mojo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gnawice/mojo-cnn/5bbe8c5c012a2dbd9811a355cb1c1b2bc89a782f/models/mnist_deepcnet.mojo -------------------------------------------------------------------------------- /models/mnist_quickstart.txt: -------------------------------------------------------------------------------- 1 | mojo: 2 | # first line and filetype identifier must be 'mojo:' for quick start files 3 | # comments in quick start file must start at the beginning of the line (not after the layer description) 4 | # the input is 28x28 and 1 channel 5 | input 28 28 1 6 | 7 | # 5x5 convolution layer with 8 outputs and stride 1. elu (exponential linear unit) activation. 8 | convolution 5 8 1 elu 9 | 10 | # pooling size 3x3 stride 3 11 | semi_stochastic_pool 3 3 12 | 13 | # inception layer (1x1 convolution) 14 | convolution 1 16 1 elu 15 | 16 | # 5x5 convolution layer with 48 outputs and stride 1. elu (exponential linear unit) activation. 
17 | convolution 5 48 1 elu 18 | 19 | # pooling size 2x2 stride 2 20 | semi_stochastic_pool 2 2 21 | 22 | # output softmax 10 channels 23 | softmax 10 24 | -------------------------------------------------------------------------------- /models/snapshots/README.md: -------------------------------------------------------------------------------- 1 | Placeholder folder for training snapshot models and log files. 2 | -------------------------------------------------------------------------------- /mojo/activation.h: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // Copyright (c) gnawice@gnawice.com. All rights reserved. 4 | // See LICENSE in root folder 5 | // 6 | // Permission is hereby granted, free of charge, to any person obtaining a 7 | // copy of this software and associated documentation files(the "Software"), 8 | // to deal in the Software without restriction, including without 9 | // limitation the rights to use, copy, modify, merge, publish, distribute, 10 | // sublicense, and/or sell copies of the Software, and to permit persons to 11 | // whom the Software is furnished to do so, subject to the following 12 | // conditions : 13 | // 14 | // The above copyright notice and this permission notice shall be included 15 | // in all copies or substantial portions of the Software. 16 | // 17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 21 | // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 22 | // OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 23 | // THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
24 | // 25 | // ============================================================================ 26 | // activation.h: neuron activation functions 27 | // ==================================================================== mojo == 28 | 29 | #pragma once 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | namespace mojo { 36 | 37 | #ifdef MOJO_LUTS 38 | const float_t tanh_lut[1024] = { -1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f, 39 | -1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-1.f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f, 40 | -0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f, 41 | -0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999999f,-0.999998f,-0.999998f, 42 | -0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999998f,-0.999997f,-0.999997f,-0.999997f, 43 | -0.999997f,-0.999997f,-0.999997f,-0.999997f,-0.999997f,-0.999997f,-0.999997f,-0.999996f,-0.999996f,-0.999996f,-0.999996f,-0.999996f,-0.999996f,-0.999996f,-0.999996f,-0.999995f, 44 | -0.999995f,-0.999995f,-0.999995f,-0.999995f,-0.999995f,-0.999995f,-0.999994f,-0.999994f,-0.999994f,-0.999994f,-0.999994f,-0.999993f,-0.999993f,-0.999993f,-0.999993f,-0.999992f, 45 | -0.999992f,-0.999992f,-0.999992f,-0.999992f,-0.999991f,-0.999991f,-0.999991f,-0.99999f,-0.99999f,-0.99999f,-0.99999f,-0.999989f,-0.999989f,-0.999988f,-0.999988f,-0.999988f, 46 | -0.999987f,-0.999987f,-0.999987f,-0.999986f,-0.999986f,-0.999985f,-0.999985f,-0.999984f,-0.999984f,-0.999983f,-0.999983f,-0.999982f,-0.999981f,-0.999981f,-0.99998f,-0.99998f, 47 | 
-0.999979f,-0.999978f,-0.999978f,-0.999977f,-0.999976f,-0.999975f,-0.999975f,-0.999974f,-0.999973f,-0.999972f,-0.999971f,-0.99997f,-0.99997f,-0.999969f,-0.999968f,-0.999967f, 48 | -0.999966f,-0.999964f,-0.999963f,-0.999962f,-0.999961f,-0.99996f,-0.999958f,-0.999957f,-0.999956f,-0.999954f,-0.999953f,-0.999951f,-0.99995f,-0.999948f,-0.999947f,-0.999945f, 49 | -0.999943f,-0.999941f,-0.99994f,-0.999938f,-0.999936f,-0.999934f,-0.999931f,-0.999929f,-0.999927f,-0.999925f,-0.999922f,-0.99992f,-0.999917f,-0.999915f,-0.999912f,-0.999909f, 50 | -0.999906f,-0.999903f,-0.9999f,-0.999897f,-0.999894f,-0.99989f,-0.999887f,-0.999884f,-0.99988f,-0.999876f,-0.999872f,-0.999868f,-0.999864f,-0.999859f,-0.999855f,-0.99985f, 51 | -0.999846f,-0.999841f,-0.999836f,-0.99983f,-0.999825f,-0.99982f,-0.999814f,-0.999808f,-0.999802f,-0.999795f,-0.999789f,-0.999782f,-0.999775f,-0.999768f,-0.999761f,-0.999753f, 52 | -0.999745f,-0.999737f,-0.999729f,-0.99972f,-0.999712f,-0.999702f,-0.999693f,-0.999683f,-0.999673f,-0.999663f,-0.999652f,-0.999641f,-0.99963f,-0.999618f,-0.999606f,-0.999593f, 53 | -0.99958f,-0.999567f,-0.999553f,-0.999539f,-0.999524f,-0.999509f,-0.999494f,-0.999478f,-0.999461f,-0.999444f,-0.999426f,-0.999408f,-0.999389f,-0.99937f,-0.99935f,-0.999329f, 54 | -0.999308f,-0.999286f,-0.999263f,-0.99924f,-0.999216f,-0.999191f,-0.999165f,-0.999139f,-0.999112f,-0.999083f,-0.999054f,-0.999024f,-0.998993f,-0.998961f,-0.998929f,-0.998894f, 55 | -0.998859f,-0.998823f,-0.998786f,-0.998747f,-0.998708f,-0.998667f,-0.998624f,-0.998581f,-0.998536f,-0.998489f,-0.998441f,-0.998392f,-0.998341f,-0.998288f,-0.998234f,-0.998178f, 56 | -0.99812f,-0.998061f,-0.997999f,-0.997936f,-0.99787f,-0.997803f,-0.997733f,-0.997661f,-0.997587f,-0.99751f,-0.997431f,-0.99735f,-0.997266f,-0.997179f,-0.99709f,-0.996998f, 57 | -0.996902f,-0.996804f,-0.996703f,-0.996599f,-0.996491f,-0.99638f,-0.996265f,-0.996146f,-0.996024f,-0.995898f,-0.995769f,-0.995635f,-0.995496f,-0.995354f,-0.995207f,-0.995055f, 58 | 
-0.994898f,-0.994737f,-0.99457f,-0.994398f,-0.994221f,-0.994038f,-0.993849f,-0.993655f,-0.993454f,-0.993247f,-0.993033f,-0.992813f,-0.992585f,-0.992351f,-0.992109f,-0.99186f, 59 | -0.991602f,-0.991337f,-0.991063f,-0.990781f,-0.99049f,-0.990189f,-0.989879f,-0.98956f,-0.98923f,-0.98889f,-0.98854f,-0.988178f,-0.987805f,-0.98742f,-0.987023f,-0.986614f, 60 | -0.986192f,-0.985757f,-0.985308f,-0.984845f,-0.984368f,-0.983876f,-0.983368f,-0.982845f,-0.982305f,-0.981749f,-0.981175f,-0.980583f,-0.979973f,-0.979344f,-0.978695f,-0.978026f, 61 | -0.977336f,-0.976626f,-0.975892f,-0.975137f,-0.974357f,-0.973554f,-0.972726f,-0.971873f,-0.970993f,-0.970086f,-0.969151f,-0.968187f,-0.967194f,-0.96617f,-0.965115f,-0.964028f, 62 | -0.962907f,-0.961752f,-0.960562f,-0.959335f,-0.958072f,-0.956769f,-0.955428f,-0.954045f,-0.952621f,-0.951154f,-0.949642f,-0.948085f,-0.946481f,-0.944829f,-0.943128f,-0.941376f, 63 | -0.939571f,-0.937712f,-0.935799f,-0.933828f,-0.931799f,-0.92971f,-0.92756f,-0.925346f,-0.923068f,-0.920722f,-0.918309f,-0.915825f,-0.913269f,-0.910638f,-0.907932f,-0.905148f, 64 | -0.902284f,-0.899339f,-0.896309f,-0.893193f,-0.889989f,-0.886695f,-0.883308f,-0.879827f,-0.876248f,-0.87257f,-0.86879f,-0.864907f,-0.860916f,-0.856818f,-0.852607f,-0.848284f, 65 | -0.843844f,-0.839285f,-0.834605f,-0.829802f,-0.824872f,-0.819814f,-0.814624f,-0.809301f,-0.803841f,-0.798243f,-0.792503f,-0.786619f,-0.780588f,-0.774409f,-0.768079f,-0.761594f, 66 | -0.754954f,-0.748155f,-0.741195f,-0.734071f,-0.726783f,-0.719328f,-0.711702f,-0.703906f,-0.695935f,-0.68779f,-0.679468f,-0.670967f,-0.662286f,-0.653424f,-0.644378f,-0.635149f, 67 | -0.625735f,-0.616134f,-0.606348f,-0.596374f,-0.586212f,-0.575862f,-0.565325f,-0.5546f,-0.543687f,-0.532587f,-0.521301f,-0.50983f,-0.498174f,-0.486336f,-0.474316f,-0.462117f, 68 | -0.449741f,-0.437189f,-0.424464f,-0.41157f,-0.398509f,-0.385284f,-0.371899f,-0.358357f,-0.344663f,-0.330821f,-0.316835f,-0.30271f,-0.28845f,-0.274062f,-0.259549f,-0.244919f, 69 | 
-0.230176f,-0.215326f,-0.200377f,-0.185333f,-0.170202f,-0.154991f,-0.139705f,-0.124353f,-0.108941f,-0.0934763f,-0.0779665f,-0.0624188f,-0.0468407f,-0.0312398f,-0.0156237f,0.f, 70 | 0.0156237f,0.0312398f,0.0468407f,0.0624188f,0.0779665f,0.0934763f,0.108941f,0.124353f,0.139705f,0.154991f,0.170202f,0.185333f,0.200377f,0.215326f,0.230176f,0.244919f, 71 | 0.259549f,0.274062f,0.28845f,0.30271f,0.316835f,0.330821f,0.344663f,0.358357f,0.371899f,0.385284f,0.398509f,0.41157f,0.424464f,0.437189f,0.449741f,0.462117f, 72 | 0.474316f,0.486336f,0.498174f,0.50983f,0.521301f,0.532587f,0.543687f,0.5546f,0.565325f,0.575862f,0.586212f,0.596374f,0.606348f,0.616134f,0.625735f,0.635149f, 73 | 0.644378f,0.653424f,0.662286f,0.670967f,0.679468f,0.68779f,0.695935f,0.703906f,0.711702f,0.719328f,0.726783f,0.734071f,0.741195f,0.748155f,0.754954f,0.761594f, 74 | 0.768079f,0.774409f,0.780588f,0.786619f,0.792503f,0.798243f,0.803841f,0.809301f,0.814624f,0.819814f,0.824872f,0.829802f,0.834605f,0.839285f,0.843844f,0.848284f, 75 | 0.852607f,0.856818f,0.860916f,0.864907f,0.86879f,0.87257f,0.876248f,0.879827f,0.883308f,0.886695f,0.889989f,0.893193f,0.896309f,0.899339f,0.902284f,0.905148f, 76 | 0.907932f,0.910638f,0.913269f,0.915825f,0.918309f,0.920722f,0.923068f,0.925346f,0.92756f,0.92971f,0.931799f,0.933828f,0.935799f,0.937712f,0.939571f,0.941376f, 77 | 0.943128f,0.944829f,0.946481f,0.948085f,0.949642f,0.951154f,0.952621f,0.954045f,0.955428f,0.956769f,0.958072f,0.959335f,0.960562f,0.961752f,0.962907f,0.964028f, 78 | 0.965115f,0.96617f,0.967194f,0.968187f,0.969151f,0.970086f,0.970993f,0.971873f,0.972726f,0.973554f,0.974357f,0.975137f,0.975892f,0.976626f,0.977336f,0.978026f, 79 | 0.978695f,0.979344f,0.979973f,0.980583f,0.981175f,0.981749f,0.982305f,0.982845f,0.983368f,0.983876f,0.984368f,0.984845f,0.985308f,0.985757f,0.986192f,0.986614f, 80 | 0.987023f,0.98742f,0.987805f,0.988178f,0.98854f,0.98889f,0.98923f,0.98956f,0.989879f,0.990189f,0.99049f,0.990781f,0.991063f,0.991337f,0.991602f,0.99186f, 81 | 
0.992109f,0.992351f,0.992585f,0.992813f,0.993033f,0.993247f,0.993454f,0.993655f,0.993849f,0.994038f,0.994221f,0.994398f,0.99457f,0.994737f,0.994898f,0.995055f, 82 | 0.995207f,0.995354f,0.995496f,0.995635f,0.995769f,0.995898f,0.996024f,0.996146f,0.996265f,0.99638f,0.996491f,0.996599f,0.996703f,0.996804f,0.996902f,0.996998f, 83 | 0.99709f,0.997179f,0.997266f,0.99735f,0.997431f,0.99751f,0.997587f,0.997661f,0.997733f,0.997803f,0.99787f,0.997936f,0.997999f,0.998061f,0.99812f,0.998178f, 84 | 0.998234f,0.998288f,0.998341f,0.998392f,0.998441f,0.998489f,0.998536f,0.998581f,0.998624f,0.998667f,0.998708f,0.998747f,0.998786f,0.998823f,0.998859f,0.998894f, 85 | 0.998929f,0.998961f,0.998993f,0.999024f,0.999054f,0.999083f,0.999112f,0.999139f,0.999165f,0.999191f,0.999216f,0.99924f,0.999263f,0.999286f,0.999308f,0.999329f, 86 | 0.99935f,0.99937f,0.999389f,0.999408f,0.999426f,0.999444f,0.999461f,0.999478f,0.999494f,0.999509f,0.999524f,0.999539f,0.999553f,0.999567f,0.99958f,0.999593f, 87 | 0.999606f,0.999618f,0.99963f,0.999641f,0.999652f,0.999663f,0.999673f,0.999683f,0.999693f,0.999702f,0.999712f,0.99972f,0.999729f,0.999737f,0.999745f,0.999753f, 88 | 0.999761f,0.999768f,0.999775f,0.999782f,0.999789f,0.999795f,0.999802f,0.999808f,0.999814f,0.99982f,0.999825f,0.99983f,0.999836f,0.999841f,0.999846f,0.99985f, 89 | 0.999855f,0.999859f,0.999864f,0.999868f,0.999872f,0.999876f,0.99988f,0.999884f,0.999887f,0.99989f,0.999894f,0.999897f,0.9999f,0.999903f,0.999906f,0.999909f, 90 | 0.999912f,0.999915f,0.999917f,0.99992f,0.999922f,0.999925f,0.999927f,0.999929f,0.999931f,0.999934f,0.999936f,0.999938f,0.99994f,0.999941f,0.999943f,0.999945f, 91 | 0.999947f,0.999948f,0.99995f,0.999951f,0.999953f,0.999954f,0.999956f,0.999957f,0.999958f,0.99996f,0.999961f,0.999962f,0.999963f,0.999964f,0.999966f,0.999967f, 92 | 0.999968f,0.999969f,0.99997f,0.99997f,0.999971f,0.999972f,0.999973f,0.999974f,0.999975f,0.999975f,0.999976f,0.999977f,0.999978f,0.999978f,0.999979f,0.99998f, 93 | 
0.99998f,0.999981f,0.999981f,0.999982f,0.999983f,0.999983f,0.999984f,0.999984f,0.999985f,0.999985f,0.999986f,0.999986f,0.999987f,0.999987f,0.999987f,0.999988f, 94 | 0.999988f,0.999988f,0.999989f,0.999989f,0.99999f,0.99999f,0.99999f,0.99999f,0.999991f,0.999991f,0.999991f,0.999992f,0.999992f,0.999992f,0.999992f,0.999992f, 95 | 0.999993f,0.999993f,0.999993f,0.999993f,0.999994f,0.999994f,0.999994f,0.999994f,0.999994f,0.999995f,0.999995f,0.999995f,0.999995f,0.999995f,0.999995f,0.999995f, 96 | 0.999996f,0.999996f,0.999996f,0.999996f,0.999996f,0.999996f,0.999996f,0.999996f,0.999997f,0.999997f,0.999997f,0.999997f,0.999997f,0.999997f,0.999997f,0.999997f, 97 | 0.999997f,0.999997f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f,0.999998f, 98 | 0.999998f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f, 99 | 0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f,0.999999f, 100 | 0.999999f,0.999999f,0.999999f,0.999999f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f, 101 | 1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f,1.f }; 102 | #endif 103 | 104 | // not using class because I thought this may be faster than vptrs 105 | namespace tan_h 106 | { 107 | #ifndef MOJO_LUTS 108 | inline void f(float *in, const int size, const float *bias) // this is activation f(x) 109 | { 110 | for(int i=0; i= 1024) return 1.f; // iff exceed max index size 133 | else if (index<0) return -1.f; // or below min index 0 134 | in[i]= tanh_lut[index]; 135 | } 136 | } 137 | inline void fc(float *in, const int size, const float bias) // this is activation f(x) 138 | { 139 | for(int i=0; i= 1024) return 1.f; // iff exceed max index size 143 | else if (index<0) return -1.f; // or below min index 0 144 | in[i]= tanh_lut[index]; 
145 | } 146 | } 147 | #endif // MOJO_LUTS 148 | inline float df(float *in, int i, const int size) { return (1.f - in[i]*in[i]); } // this is df(x), but we pass in the activated value f(x) and not x 149 | const char name[]="tanh"; 150 | } 151 | 152 | namespace elu 153 | { 154 | inline void f(float *in, const int size, const float *bias) 155 | { 156 | for(int i=0; i 0) return 1.f; else return 0.1f*std::exp(in[i]);} 172 | const char name[]="elu"; 173 | } 174 | 175 | namespace identity 176 | { 177 | inline void f(float *in, const int size, const float *bias) 178 | { 179 | for(int i=0; i 0) return 1.0f; else return 0.0f; } 206 | const char name[]="relu"; 207 | }; 208 | namespace lrelu 209 | { 210 | inline void f(float *in, const int size, const float *bias) 211 | { 212 | for(int i=0; i 0) return 1.0f; else return 0.01f; } 225 | const char name[]="lrelu"; 226 | }; 227 | namespace vlrelu 228 | { 229 | inline void f(float *in, const int size, const float *bias) 230 | { 231 | for(int i=0; i 0) return 1.0f; else return 0.33f; } 244 | const char name[]="vlrelu"; 245 | }; 246 | 247 | namespace sigmoid 248 | { 249 | inline void f(float *in, const int size, const float *bias) 250 | { 251 | for(int i=0; i max) max = in[j]; 268 | 269 | float denom = 0; 270 | for (int j = 0; j max) max = in[j]; 278 | 279 | float denom = 0; 280 | for (int j = 0; j max) max = in[j]; 307 | 308 | float denom = 0; 309 | for (int j = 0; j max) max = in[j]; 319 | 320 | float denom = 0; 321 | for (int j = 0; jf = &tan_h::f; p->fc = &tan_h::fc; p->df = &tan_h::df; p->name=tan_h::name;return p;} 359 | if(act.compare(identity::name)==0) { p->f = &identity::f; p->fc = &identity::fc; p->df = &identity::df; p->name=identity::name; return p;} 360 | if(act.compare(vlrelu::name)==0) { p->f = &vlrelu::f; p->fc = &vlrelu::fc; p->df = &vlrelu::df; p->name=vlrelu::name; return p;} 361 | if(act.compare(lrelu::name)==0) { p->f = &lrelu::f; p->fc = &lrelu::fc; p->df = &lrelu::df; p->name=lrelu::name; return p;} 362 | 
if(act.compare(relu::name)==0) { p->f = &relu::f; p->fc = &relu::fc;p->df = &relu::df; p->name=relu::name;return p;} 363 | if(act.compare(sigmoid::name)==0) { p->f = &sigmoid::f; p->fc = &sigmoid::fc;p->df = &sigmoid::df; p->name=sigmoid::name; return p;} 364 | if(act.compare(elu::name)==0) { p->f = &elu::f; p->fc = &elu::fc; p->df = &elu::df; p->name=elu::name; return p;} 365 | if(act.compare(none::name)==0) { p->f = &none::f; p->fc = &none::fc; p->df = &none::df; p->name=none::name; return p;} 366 | if(act.compare(softmax::name) == 0) { p->f = &softmax::f; p->fc = &softmax::fc;p->df = &softmax::df; p->name = softmax::name; return p; } 367 | if(act.compare(brokemax::name) == 0) { p->f = &brokemax::f; p->fc = &brokemax::fc;p->df = &brokemax::df; p->name = brokemax::name; return p; } 368 | delete p; 369 | return NULL; 370 | } 371 | 372 | activation_function* new_activation_function(const char *type) 373 | { 374 | std::string act(type); 375 | return new_activation_function(act); 376 | } 377 | 378 | } // namespace -------------------------------------------------------------------------------- /mojo/core_math.h: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // Copyright (c) gnawice@gnawice.com. All rights reserved. 
4 | // See LICENSE in root folder 5 | // 6 | // Permission is hereby granted, free of charge, to any person obtaining a 7 | // copy of this software and associated documentation files(the "Software"), 8 | // to deal in the Software without restriction, including without 9 | // limitation the rights to use, copy, modify, merge, publish, distribute, 10 | // sublicense, and/or sell copies of the Software, and to permit persons to 11 | // whom the Software is furnished to do so, subject to the following 12 | // conditions : 13 | // 14 | // The above copyright notice and this permission notice shall be included 15 | // in all copies or substantial portions of the Software. 16 | // 17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 21 | // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 22 | // OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 23 | // THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
24 | // 25 | // ============================================================================ 26 | // core_math.h: defines matrix class and math functions 27 | // ==================================================================== mojo == 28 | 29 | 30 | #pragma once 31 | 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | 39 | 40 | namespace mojo 41 | { 42 | 43 | enum pad_type { zero = 0, edge = 1, median_edge = 2 }; 44 | 45 | inline float dot(const float *x1, const float *x2, const int size) 46 | { 47 | switch (size) 48 | { 49 | case 1: return x1[0] * x2[0]; 50 | case 2: return x1[0] * x2[0] + x1[1] * x2[1]; 51 | case 3: return x1[0] * x2[0] + x1[1] * x2[1] + x1[2] * x2[2]; 52 | case 4: return x1[0] * x2[0] + x1[1] * x2[1] + x1[2] * x2[2] + x1[3] * x2[3]; 53 | case 5: return x1[0] * x2[0] + x1[1] * x2[1] + x1[2] * x2[2] + x1[3] * x2[3] + x1[4] * x2[4]; 54 | default: 55 | float v = 0; 56 | for (int i = 0; i 7) { off = 0; c1 += inc_off; } 153 | } 154 | 155 | } 156 | } 157 | 158 | 159 | inline void dotsum_unwrapped_NxN(const int N, const float *im, const float *filter_ptr, float *out, const int outsize) 160 | { 161 | const int NN=N*N; 162 | for (int j = 0; j < outsize; j += 8) 163 | { 164 | float *c = out+j; 165 | for(int i=0; i 0) s += 8 - remainder; 420 | return s; 421 | } 422 | else return w*h; 423 | } 424 | 425 | matrix( ): cols(0), rows(0), chans(0), _size(0), _capacity(0), chan_stride(0), x(NULL), chan_aligned(0)/*, empty_chan(NULL)*/{} 426 | 427 | 428 | matrix( int _w, int _h, int _c=1, const float *data=NULL, int align_chan=0): cols(_w), rows(_h), chans(_c) 429 | { 430 | chan_aligned = align_chan; 431 | chan_stride = calc_chan_stride(cols, rows); 432 | _size= chan_stride*chans; _capacity=_size; x = new_x(_size); 433 | if(data!=NULL) memcpy(x,data,_size*sizeof(float)); 434 | } 435 | 436 | // copy constructor - deep copy 437 | matrix( const matrix &m) : cols(m.cols), rows(m.rows), chan_aligned(m.chan_aligned), 
chans(m.chans), chan_stride(m.chan_stride), _size(m._size), _capacity(m._size) {x = new_x(_size); memcpy(x,m.x,sizeof(float)*_size); /*empty_chan = new unsigned char[chans]; memcpy(empty_chan, m.empty_chan, chans);*/} // { v=m.v; x=(float*)v.data();} 438 | // copy and pad constructor 439 | matrix( const matrix &m, int pad_cols, int pad_rows, mojo::pad_type padding= mojo::zero, int threads=1) : cols(m.cols), rows(m.rows), chans(m.chans), chan_aligned(m.chan_aligned), chan_stride(m.chan_stride), _size(m._size), _capacity(m._size) 440 | { 441 | x = new_x(_size); memcpy(x, m.x, sizeof(float)*_size); 442 | *this = pad(pad_cols, pad_rows, padding, threads); 443 | } 444 | 445 | ~matrix() { if (x) delete_x(); } 446 | 447 | matrix get_chans(int start_channel, int num_chans=1) const 448 | { 449 | return matrix(cols,rows,num_chans,&x[start_channel*chan_stride]); 450 | } 451 | 452 | 453 | // if edge_pad==0, then the padded area is just 0. 454 | // if edge_pad==1 it fills with edge pixel colors 455 | // if edge_pad==2 it fills with median edge pixel color 456 | matrix pad(int dx, int dy, mojo::pad_type edge_pad = mojo::zero, int threads=1) const 457 | { 458 | return pad(dx, dy, dx, dy, edge_pad, threads); 459 | } 460 | matrix pad(int dx, int dy, int dx_right, int dy_bottom, mojo::pad_type edge_pad = mojo::zero, int threads=1) const 461 | { 462 | matrix v(cols+dx+dx_right,rows+dy+dy_bottom,chans);//,NULL,this->chan_aligned); 463 | v.fill(0); 464 | 465 | //float *new_x = new float[chans*w*h]; 466 | #pragma omp parallel for num_threads(threads) 467 | for(int k=0; k d(perimeter); 477 | for (int i = 0; i < cols; i++) 478 | { 479 | d[i] = x[i+ chan_offset]; d[i + cols] = x[i + cols*(rows - 1)+ chan_offset]; 480 | } 481 | for (int i = 1; i < (rows - 1); i++) 482 | { 483 | d[i + cols * 2] = x[cols*i+ chan_offset]; 484 | // file from back so i dont need to cal index 485 | d[perimeter - i] = x[cols - 1 + cols*i+ chan_offset]; 486 | } 487 | 488 | std::nth_element(d.begin(), d.begin() + 
perimeter / 2, d.end()); 489 | median = d[perimeter / 2]; 490 | //for (int i = 0; i < v.rows*v.cols; i++) v.x[v_chan_offset + i] = solid_fill; 491 | } 492 | 493 | for(int j=0; j max) x[i]=max; 584 | } 585 | } 586 | 587 | 588 | void min_max(float *min, float *max, int *min_i=NULL, int *max_i=NULL) 589 | { 590 | int s = rows*cols; 591 | int mini = 0; 592 | int maxi = 0; 593 | for (int c = 0; c < chans; c++) 594 | { 595 | const int t = chan_stride*c; 596 | for (int i = t; i < t+s; i++) 597 | { 598 | if (x[i] < x[mini]) mini = i; 599 | if (x[i] > x[maxi]) maxi = i; 600 | } 601 | } 602 | *min = x[mini]; 603 | *max = x[maxi]; 604 | if (min_i) *min_i = mini; 605 | if (max_i) *max_i = maxi; 606 | } 607 | 608 | float mean() 609 | { 610 | const int s = rows*cols; 611 | int cnt = 0;// channel*s; 612 | float average = 0; 613 | for (int c = 0; c < chans; c++) 614 | { 615 | const int t = chan_stride*c; 616 | for (int i = 0; i < s; i++) 617 | average += x[i + t]; 618 | } 619 | average = average / (float)(s*chans); 620 | return average; 621 | } 622 | float remove_mean(int channel) 623 | { 624 | int s = rows*cols; 625 | int offset = channel*chan_stride; 626 | float average=0; 627 | for(int i=0; i dst(-range, range); 647 | for (int i = 0; i<_size; i++) x[i] = dst(gen); 648 | } 649 | void fill_random_normal(float std) 650 | { 651 | std::mt19937 gen(0); 652 | std::normal_distribution dst(0, std); 653 | for (int i = 0; i<_size; i++) x[i] = dst(gen); 654 | } 655 | 656 | 657 | // deep copy 658 | inline matrix& operator =(const matrix &m) 659 | { 660 | resize(m.cols, m.rows, m.chans, m.chan_aligned); 661 | memcpy(x,m.x,sizeof(float)*_size); 662 | // memcpy(empty_chan, m.empty_chan, chans); 663 | return *this; 664 | } 665 | 666 | int size() const {return _size;} 667 | 668 | void resize(int _w, int _h, int _c, int align_chans=0) { 669 | chan_aligned = align_chans; 670 | int new_stride = calc_chan_stride(_w,_h); 671 | int s = new_stride*_c; 672 | if(s>_capacity) 673 | { 674 | if(_capacity>0) 
delete_x(); _size = s; _capacity=_size; x = new_x(_size); 675 | } 676 | cols = _w; rows = _h; chans = _c; _size = s; chan_stride = new_stride; 677 | } 678 | 679 | // dot vector to 2d mat 680 | inline matrix dot_1dx2d(const matrix &m_2d) const 681 | { 682 | mojo::matrix v(m_2d.rows, 1, 1); 683 | for(int j=0; j 33 | #include 34 | #include 35 | 36 | namespace mojo { 37 | 38 | namespace mse 39 | { 40 | inline float cost(float out, float target) {return 0.5f*(out-target)*(out-target);}; 41 | inline float d_cost(float out, float target) {return (out-target);}; 42 | const char name[]="mse"; 43 | } 44 | /* 45 | namespace triplet_loss 46 | { 47 | inline float E(float out1, float out2, float out3) {return 0.5f*(out-target)*(out-target);}; 48 | inline float dE(float out, float target) {return (out-target);}; 49 | const char name[]="triplet_loss"; 50 | } 51 | */ 52 | namespace cross_entropy 53 | { 54 | inline float cost(float out, float target) {return (-target * std::log(out) - (1.f - target) * std::log(1.f - out));}; 55 | inline float d_cost(float out, float target) {return ((out - target) / (out*(1.f - out)));}; 56 | const char name[]="cross_entropy"; 57 | } 58 | 59 | 60 | typedef struct 61 | { 62 | public: 63 | float (*cost)(float, float); 64 | float (*d_cost)(float, float); 65 | const char *name; 66 | } cost_function; 67 | 68 | cost_function* new_cost_function(std::string loss) 69 | { 70 | cost_function *p = new cost_function; 71 | if(loss.compare(cross_entropy::name)==0) { p->cost = &cross_entropy::cost; p->d_cost = &cross_entropy::d_cost; p->name=cross_entropy::name;return p;} 72 | if(loss.compare(mse::name)==0) { p->cost = &mse::cost; p->d_cost = &mse::d_cost; p->name=mse::name; return p;} 73 | //if(loss.compare(triplet_loss::name)==0) { p->E = &triplet_loss::E; p->dE = &triplet_loss::dE; p->name=triplet_loss::name; return p;} 74 | delete p; 75 | return NULL; 76 | } 77 | 78 | cost_function* new_cost_function(const char *type) 79 | { 80 | std::string loss(type); 81 | 
return new_cost_function(loss); 82 | } 83 | 84 | } -------------------------------------------------------------------------------- /mojo/mojo.h: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // Copyright (c) gnawice@gnawice.com. All rights reserved. 4 | // See LICENSE in root folder 5 | // 6 | // Permission is hereby granted, free of charge, to any person obtaining a 7 | // copy of this software and associated documentation files(the "Software"), 8 | // to deal in the Software without restriction, including without 9 | // limitation the rights to use, copy, modify, merge, publish, distribute, 10 | // sublicense, and/or sell copies of the Software, and to permit persons to 11 | // whom the Software is furnished to do so, subject to the following 12 | // conditions : 13 | // 14 | // The above copyright notice and this permission notice shall be included 15 | // in all copies or substantial portions of the Software. 16 | // 17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 21 | // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 22 | // OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 23 | // THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
24 | // 25 | // ============================================================================ 26 | // mojo.h: include this one file to use mojo cnn without OpenMP 27 | // ==================================================================== mojo == 28 | 29 | #pragma once 30 | 31 | #define MOJO_AVX // turn on AVX / SSE3 / SIMD optimizations 32 | #define MOJO_OMP // allow multi-threading through openmp 33 | //#define MOJO_LUTS // use look up tables, uses more memory 34 | //#define MOJO_CV3 // use OpenCV 3.x utilities 35 | //#define MOJO_CV2 // use OpenCV 2.x utilities 36 | //#define MOJO_PROFILE_LAYERS // std::cout layer names and latencies 37 | #define MOJO_INTERNAL_THREADING // try to speed forward pass with internal threading 38 | 39 | #ifdef MOJO_OMP 40 | #include 41 | #ifndef MOJO_INTERNAL_THREADING 42 | #define MOJO_THREAD_THIS_LOOP(a) 43 | #define MOJO_THREAD_THIS_LOOP_DYNAMIC(a) 44 | #else 45 | #ifdef _WIN32 46 | // this macro uses OMP where the loop is split up into equal chunks per thread 47 | #define MOJO_THREAD_THIS_LOOP(a) __pragma(omp parallel for num_threads(a)) 48 | // this macro uses OMP where the loop is split up dynamically and work is taken by next available thread 49 | #define MOJO_THREAD_THIS_LOOP_DYNAMIC(a) __pragma(omp parallel for schedule(dynamic) num_threads(a)) 50 | #else 51 | #define MOJO_THREAD_THIS_LOOP(a) 52 | #define MOJO_THREAD_THIS_LOOP_DYNAMIC(a) 53 | //#define MOJO_THREAD_THIS_LOOP(a) _Pragma("omp parallel for schedule(dynamic) num_threads(" #a ")") 54 | //#define MOJO_THREAD_THIS_LOOP_DYNAMIC(a) _Pragma("omp parallel for schedule(dynamic) num_threads(" #a ")") 55 | #endif 56 | #endif //_WIN32 57 | #else 58 | #define MOJO_THREAD_THIS_LOOP(a) 59 | #define MOJO_THREAD_THIS_LOOP_DYNAMIC(a) 60 | #endif 61 | 62 | 63 | #ifdef MOJO_AF 64 | #include 65 | #define AF_RELEASE 66 | #ifdef MOJO_CUDA 67 | #define AF_CUDA 68 | #pragma comment(lib, "afcuda") 69 | #else 70 | #define AF_CPU 71 | #pragma comment(lib, "afcpu") 72 | #endif 73 | #endif 
74 | 75 | 76 | #include "core_math.h" 77 | 78 | #include "network.h" // this is the important thing 79 | #include "util.h" 80 | 81 | 82 | -------------------------------------------------------------------------------- /mojo/solver.h: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // Copyright (c) gnawice@gnawice.com. All rights reserved. 4 | // See LICENSE in root folder 5 | // 6 | // Permission is hereby granted, free of charge, to any person obtaining a 7 | // copy of this software and associated documentation files(the "Software"), 8 | // to deal in the Software without restriction, including without 9 | // limitation the rights to use, copy, modify, merge, publish, distribute, 10 | // sublicense, and/or sell copies of the Software, and to permit persons to 11 | // whom the Software is furnished to do so, subject to the following 12 | // conditions : 13 | // 14 | // The above copyright notice and this permission notice shall be included 15 | // in all copies or substantial portions of the Software. 16 | // 17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 21 | // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 22 | // OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 23 | // THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
24 | // 25 | // ============================================================================ 26 | // solver.h: stochastic optimization approaches 27 | // ==================================================================== mojo == 28 | 29 | #pragma once 30 | 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | 37 | #include "core_math.h" 38 | 39 | // hack for VS2010 to handle c++11 for(:) 40 | #if (_MSC_VER == 1600) 41 | #ifndef __for__ 42 | #define __for__ for each 43 | #define __in__ in 44 | #endif 45 | #else 46 | #ifndef __for__ 47 | #define __for__ for 48 | #define __in__ : 49 | #endif 50 | #endif 51 | 52 | namespace mojo { 53 | 54 | 55 | 56 | class solver 57 | { 58 | public: 59 | // learning rates are 'tweaked' in inc_w function so that they can be similar for all solvers 60 | float learning_rate; 61 | solver(): learning_rate(0.01f) {} 62 | virtual ~solver(){} 63 | virtual void reset() {} 64 | // this increments the weight matrix w, which corresponds to connection index 'g' 65 | // bottom is the number of grads coming up from the lower layer 66 | // top is the current output node value of the upper layer 67 | virtual void increment_w(matrix *w, int g, const matrix &dW, const float custom_factor=1.0f){}//, matrix *top){} 68 | virtual void push_back(int w, int h, int c){} 69 | }; 70 | 71 | #ifndef MOJO_NO_TRAINING 72 | 73 | 74 | class sgd: public solver 75 | { 76 | public: 77 | static const char *name(){return "sgd";} 78 | 79 | virtual void increment_w(matrix *w, int g, const matrix &dW, const float custom_factor = 1.0f) 80 | { 81 | const float w_decay=0.01f;//1; 82 | const float lr=custom_factor*learning_rate; 83 | for(int s=0; ssize(); s++) 84 | w->x[s] -= lr*(dW.x[s] + w_decay*w->x[s]); 85 | } 86 | }; 87 | 88 | class adagrad: public solver 89 | { 90 | // persistent variables that mirror size of weight matrix 91 | std::vector G1; 92 | public: 93 | static const char *name(){return "adagrad";} 94 | 95 | virtual ~adagrad(){__for__(auto g 
__in__ G1) delete g;} 96 | virtual void push_back(int w, int h, int c) { 97 | G1.push_back(new matrix(w, h, c)); 98 | }// G1[G1.size() - 1]->fill(0); 99 | 100 | 101 | virtual void reset() { __for__(auto g __in__ G1) g->fill(0.f);} 102 | virtual void increment_w(matrix *w, int g, const matrix &dW, const float custom_factor = 1.0f) 103 | { 104 | float *g1 = G1[g]->x; 105 | //float min, max; 106 | //G1[g]->min_max(&min, &max); 107 | //std::cout << "((" << min << "," << max << ")"; 108 | const float eps = 1.e-8f; 109 | // if (G1[g]->size() != w->size()) throw; 110 | const float lr = custom_factor*learning_rate; 111 | for(int s=0; ssize(); s++) 112 | { 113 | g1[s] += dW.x[s] * dW.x[s]; 114 | //if (g1[s] < 1) throw; 115 | w->x[s] -= lr*dW.x[s]/(std::sqrt(g1[s]) + eps); 116 | } 117 | }; 118 | }; 119 | 120 | class rmsprop: public solver 121 | { 122 | // persistent variables that mirror size of weight matrix 123 | std::vector G1; 124 | public: 125 | static const char *name(){return "rmsprop";} 126 | virtual ~rmsprop(){__for__(auto g __in__ G1) delete g;} 127 | 128 | virtual void push_back(int w, int h, int c){ G1.push_back(new matrix(w,h,c)); G1[G1.size() - 1]->fill(0);} 129 | virtual void reset() { __for__(auto g __in__ G1) g->fill(0.f);} 130 | virtual void increment_w(matrix *w, int g, const matrix &dW, const float custom_factor = 1.0f) 131 | { 132 | float *g1 = G1[g]->x; 133 | const float eps = 1.e-8f; 134 | const float mu = 0.999f; 135 | const float lr = 0.01f*custom_factor*learning_rate; 136 | 137 | for(int s=0; s<(int)w->size(); s++) 138 | { 139 | g1[s] = mu * g1[s]+(1-mu) * dW.x[s] * dW.x[s]; 140 | w->x[s] -= lr*dW.x[s]/(std::sqrt(g1[s]) + eps); 141 | } 142 | }; 143 | 144 | }; 145 | 146 | class adam: public solver 147 | { 148 | float b1_t, b2_t; 149 | const float b1, b2; 150 | // persistent variables that mirror size of weight matrix 151 | std::vector G1; 152 | std::vector G2; 153 | public: 154 | static const char *name(){return "adam";} 155 | adam(): b1(0.9f), 
b1_t(0.9f), b2(0.999f), b2_t(0.999f), solver() {} 156 | virtual ~adam(){__for__(auto g __in__ G1) delete g; __for__(auto g __in__ G2) delete g;} 157 | 158 | virtual void reset() 159 | { 160 | b1_t*=b1; b2_t*=b2; 161 | __for__(auto g __in__ G1) g->fill(0.f); 162 | __for__(auto g __in__ G2) g->fill(0.f); 163 | } 164 | 165 | virtual void push_back(int w, int h, int c) 166 | { 167 | G1.push_back(new matrix(w,h,c)); G1[G1.size() - 1]->fill(0); 168 | G2.push_back(new matrix(w,h,c)); G2[G2.size() - 1]->fill(0); 169 | } 170 | 171 | virtual void increment_w(matrix *w, int g, const matrix &dW, const float custom_factor = 1.0f) 172 | { 173 | float *g1 = G1[g]->x; 174 | float *g2 = G2[g]->x; 175 | const float eps = 1.e-8f; 176 | const float b1=0.9f, b2=0.999f; 177 | const float lr = 0.1f*custom_factor*learning_rate; 178 | for(int s=0; s<(int)w->size(); s++) 179 | { 180 | g1[s] = b1* g1[s]+(1-b1) * dW.x[s]; 181 | g2[s] = b2* g2[s]+(1-b2) * dW.x[s]*dW.x[s]; 182 | w->x[s] -= lr* (g1[s]/(1.f-b1_t)) / ((float)std::sqrt(g2[s]/(1.-b2_t)) + eps); 183 | } 184 | }; 185 | 186 | }; 187 | 188 | 189 | solver* new_solver(const char *type) 190 | { 191 | if(type==NULL) return NULL; 192 | std::string act(type); 193 | if(act.compare(sgd::name())==0) { return new sgd();} 194 | if(act.compare(rmsprop::name())==0) { return new rmsprop();} 195 | if(act.compare(adagrad::name())==0) { return new adagrad();} 196 | if(act.compare(adam::name())==0) { return new adam();} 197 | 198 | return NULL; 199 | } 200 | 201 | #else 202 | 203 | 204 | solver* new_solver(const char *type) {return NULL;} 205 | solver* new_solver(std::string act){return NULL;} 206 | 207 | #endif 208 | 209 | 210 | } // namespace -------------------------------------------------------------------------------- /mojo/util.h: -------------------------------------------------------------------------------- 1 | // == mojo ==================================================================== 2 | // 3 | // Copyright (c) gnawice@gnawice.com. 
All rights reserved. 4 | // See LICENSE in root folder 5 | // 6 | // Permission is hereby granted, free of charge, to any person obtaining a 7 | // copy of this software and associated documentation files(the "Software"), 8 | // to deal in the Software without restriction, including without 9 | // limitation the rights to use, copy, modify, merge, publish, distribute, 10 | // sublicense, and/or sell copies of the Software, and to permit persons to 11 | // whom the Software is furnished to do so, subject to the following 12 | // conditions : 13 | // 14 | // The above copyright notice and this permission notice shall be included 15 | // in all copies or substantial portions of the Software. 16 | // 17 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 18 | // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 20 | // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 21 | // CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT 22 | // OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 23 | // THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
24 | // 25 | // ============================================================================ 26 | // util.h: various stuff- progress, html log, opencv 27 | // ==================================================================== mojo == 28 | 29 | #pragma once 30 | 31 | 32 | #include 33 | #include 34 | #if (_MSC_VER != 1600) 35 | #include 36 | #else 37 | #include 38 | #endif 39 | #include "core_math.h" 40 | #include "network.h" 41 | 42 | 43 | #ifdef MOJO_CV2 44 | #include "opencv2/opencv.hpp" 45 | #include "opencv2/highgui/highgui.hpp" 46 | #include "opencv2/imgproc/imgproc.hpp" 47 | #include "opencv2/contrib/contrib.hpp" 48 | 49 | #pragma comment(lib, "opencv_core249") 50 | #pragma comment(lib, "opencv_highgui249") 51 | #pragma comment(lib, "opencv_imgproc249") 52 | #pragma comment(lib, "opencv_contrib249") 53 | #endif 54 | 55 | #ifdef MOJO_CV3 56 | #include "opencv2/opencv.hpp" 57 | #include "opencv2/highgui/highgui.hpp" 58 | #include "opencv2/imgproc/imgproc.hpp" 59 | 60 | #pragma comment(lib, "opencv_world310") 61 | #endif 62 | 63 | namespace mojo 64 | { 65 | 66 | // class to handle timing and drawing text progress output 67 | class progress 68 | { 69 | 70 | public: 71 | progress(int size=-1, const char *label=NULL ) {reset(size, label);} 72 | 73 | #if (_MSC_VER == 1600) 74 | unsigned int start_progress_time; 75 | #else 76 | std::chrono::time_point start_progress_time; 77 | #endif 78 | unsigned int total_progress_items; 79 | std::string label_progress; 80 | // if default values used, the values won't be changed from last call 81 | void reset(int size=-1, const char *label=NULL ) 82 | { 83 | #if (_MSC_VER == 1600) 84 | start_progress_time= clock(); 85 | #else 86 | start_progress_time= std::chrono::system_clock::now(); 87 | #endif 88 | if(size>0) total_progress_items=size; if(label!=NULL) label_progress=label; 89 | } 90 | float elapsed_seconds() 91 | { 92 | #if (_MSC_VER == 1600) 93 | float time_span = (float)(clock() - start_progress_time)/CLOCKS_PER_SEC; 94 | 
return time_span; 95 | #else 96 | std::chrono::duration time_span = std::chrono::duration_cast>(std::chrono::system_clock::now() - start_progress_time); 97 | return (float)time_span.count(); 98 | #endif 99 | } 100 | float remaining_seconds(int item_index) 101 | { 102 | float elapsed_dt = elapsed_seconds(); 103 | float percent_complete = 100.f*item_index/total_progress_items; 104 | if(percent_complete>0) return ((elapsed_dt/percent_complete*100.f)-elapsed_dt); 105 | return 0.f; 106 | } 107 | // this doesn't work correctly with g++/Cygwin 108 | // the carriage return seems to delete the text... 109 | void draw_progress(int item_index) 110 | { 111 | int time_remaining = (int)remaining_seconds(item_index); 112 | float percent_complete = 100.f*item_index/total_progress_items; 113 | if (percent_complete > 0) 114 | { 115 | std::cout << label_progress << (int)percent_complete << "% (" << (int)time_remaining << "sec remaining) \r"< log; 156 | std::string header; 157 | std::string notes; 158 | public: 159 | html_log() {}; 160 | 161 | // the header you set here should have tab \t separated column headers that match what will go in the row 162 | // the first 3 columns are always epoch, test accuracy, est accuracy 163 | void set_table_header(std::string tab_header) { header=tab_header;} 164 | // tab_row should be \t separated things to put after first 3 columns 165 | void add_table_row(float train_acccuracy, float test_accuracy, std::string tab_row) 166 | { 167 | log_stuff s; 168 | s.str = tab_row; s.test_accurracy = test_accuracy; s.train_accurracy_est = train_acccuracy; 169 | log.push_back(s); 170 | } 171 | void set_note(std::string msg) {notes = msg;} 172 | bool write(std::string filename) { 173 | 174 | std::string top = "Mojo CNN Training Report
Training Summary
"; 193 | 194 | std::string msg = ""; 195 | int N = (int)log.size(); 196 | msg += ""; 197 | int best = N - 1; 198 | int best_est = N - 1; 199 | for (int i = N - 1; i >= 0; i--) 200 | { 201 | if (log[i].test_accurracy > log[best].test_accurracy) best = i; 202 | if (log[i].train_accurracy_est > log[best_est].train_accurracy_est) best_est = i; 203 | } 204 | for (int i = N - 1; i >= 0; i--) 205 | { 206 | msg += ""; 215 | } 216 | replace_str(msg, "\t", "
" + header + "
" + int2str(i + 1); 207 | // make best green 208 | if (i == best) msg += ""; 209 | else msg += ""; 210 | msg+=float2str(log[i].test_accurracy); 211 | // mark bad trend in training 212 | if (i > best_est) msg += ""; 213 | else msg += ""; 214 | msg+=float2str(log[i].train_accurracy_est)+ "" + log[i].str + "
"); 217 | 218 | replace_str(notes, "\n", "
"); 219 | std::string bottom = "

"+notes+""; 220 | 221 | std::ofstream f(filename.c_str()); 222 | f << top; f << data; f << mid; f << msg; f << bottom; 223 | 224 | f.close(); 225 | return true; 226 | } 227 | 228 | }; 229 | 230 | #if defined(MOJO_CV2) || defined(MOJO_CV3) 231 | 232 | 233 | // transforms image. 234 | // x_center, y_center of input 235 | // out dim is size of output w or h 236 | // theta in degrees 237 | /* matrix2cv: wraps m in a CV_32F cv::Mat of m.rows x m.cols. If m.chans == 3 the three planes are merged into a new interleaved 3-channel Mat; otherwise only the first plane is wrapped without copying, so with uc8 == false the result aliases m.x and must not outlive m. With uc8 == true the values are min-max scaled into a new 0..255 8-bit Mat. (The four comment lines above describe transform(), not this function.) */ cv::Mat matrix2cv(const mojo::matrix &m, bool uc8)// = false) 238 | { 239 | cv::Mat cv_m; 240 | if (m.chans != 3) 241 | { 242 | cv_m = cv::Mat(m.rows, m.cols, CV_32FC1, m.x); // cv::Mat takes (rows, cols) 243 | } 244 | if (m.chans == 3) 245 | { 246 | cv::Mat in[3]; 247 | in[0] = cv::Mat(m.rows, m.cols, CV_32FC1, m.x); 248 | in[1] = cv::Mat(m.rows, m.cols, CV_32FC1, &m.x[m.cols*m.rows]); 249 | in[2] = cv::Mat(m.rows, m.cols, CV_32FC1, &m.x[2 * m.cols*m.rows]); 250 | cv::merge(in, 3, cv_m); 251 | } 252 | if (uc8) 253 | { 254 | double min_, max_; const int cn = cv_m.channels(); // 1 unless m.chans == 3 255 | cv_m = cv_m.reshape(1); 256 | cv::minMaxIdx(cv_m, &min_, &max_); 257 | cv_m = cv_m - min_; 258 | max_ = max_ - min_; 259 | if (max_ > 0) cv_m /= max_; // guard: a constant input would otherwise divide 0 by 0 260 | cv_m *= 255; 261 | cv_m = cv_m.reshape(cn, m.rows); // restore the channel count actually wrapped above; m.chans may be other than 1 or 3 262 | if (m.chans != 3) 263 | cv_m.convertTo(cv_m, CV_8UC1); 264 | else 265 | cv_m.convertTo(cv_m, CV_8UC3); 266 | } 267 | return cv_m; 268 | } 269 | /* cv2matrix: converts m to a mojo::matrix (cols, rows, chans). CV_8UC1 is converted to float and scaled by 1/255; CV_8UC3 is converted to float but not rescaled; CV_32FC3 is de-interleaved into three planes; any other type yields mojo::matrix(0, 0, 0). m itself is converted in place, so the caller's Mat may change type. NOTE(review): the 1/255 scaling is applied only to 8UC1, not 8UC3 - confirm this asymmetry is intended before changing it. */ 270 | mojo::matrix cv2matrix(cv::Mat &m) 271 | { 272 | if (m.type() == CV_8UC1) 273 | { 274 | m.convertTo(m, CV_32FC1); 275 | m = m / 255.; 276 | } 277 | if (m.type() == CV_8UC3) 278 | { 279 | m.convertTo(m, CV_32FC3); 280 | } 281 | if (m.type() == CV_32FC1) 282 | { 283 | if (!m.isContinuous()) m = m.clone(); /* the mojo::matrix below is built from m.data as one dense rows*cols float block */ return mojo::matrix(m.cols, m.rows, 1, (float*)m.data); 284 | } 285 | if (m.type() == CV_32FC3) 286 | { 287 | cv::Mat in[3]; 288 | cv::split(m, in); 289 | mojo::matrix out(m.cols, m.rows, 3); 290 | memcpy(out.x, in[0].data, m.cols*m.rows * sizeof(float)); 291 | memcpy(&out.x[m.cols*m.rows], in[1].data, m.cols*m.rows * sizeof(float)); 292 | memcpy(&out.x[2 * m.cols*m.rows], in[2].data, m.cols*m.rows * sizeof(float)); 293 | return out; 294 | } 295 | return 
mojo::matrix(0, 0, 0); 296 | } 297 | /* transform: warps in with a perspective transform: the square of side out_dim/scale centred on (x_center, y_center) is mapped onto a square of side out_dim rotated by theta degrees about that same point, and an out_dim x out_dim result is sampled with INTER_AREA and BORDER_REPLICATE. NOTE(review): the destination corners are offset by (x_center, y_center) rather than out_dim/2, so the output is centred on (x_center, y_center) only when both equal out_dim/2 - confirm this is intended. */ mojo::matrix transform(const mojo::matrix in, const int x_center, const int y_center, 298 | int out_dim, float theta, float scale)// theta =0 scale= 1.f) 299 | { 300 | const double _pi = 3.14159265358979323846; 301 | float cos_theta = (float)std::cos(theta / 180.*_pi); 302 | float sin_theta = (float)std::sin(theta / 180.*_pi); 303 | float half_dim = 0.5f*(float)out_dim / scale; 304 | 305 | cv::Point2f pts1[4], pts2[4]; 306 | pts1[0] = cv::Point2f(x_center - half_dim, y_center - half_dim); 307 | pts1[1] = cv::Point2f(x_center + half_dim, y_center - half_dim); 308 | pts1[2] = cv::Point2f(x_center + half_dim, y_center + half_dim); 309 | pts1[3] = cv::Point2f(x_center - half_dim, y_center + half_dim); 310 | 311 | pts2[0] = cv::Point2f(-half_dim, -half_dim); 312 | pts2[1] = cv::Point2f(half_dim, -half_dim); 313 | pts2[2] = cv::Point2f(half_dim, half_dim); 314 | pts2[3] = cv::Point2f(-half_dim, half_dim); 315 | 316 | // rotate around center spot 317 | for (int pt = 0; pt<4; pt++) 318 | { 319 | float x_t = (pts2[pt].x)*scale; 320 | float y_t = (pts2[pt].y)*scale; 321 | float x = cos_theta*x_t - sin_theta*y_t; 322 | float y = sin_theta*x_t + cos_theta*y_t; 323 | 324 | pts2[pt].x = x + (float)x_center; 325 | pts2[pt].y = y + (float)y_center; 326 | 327 | // we want to control how data is scaled down 328 | // if (scale>1) 329 | // { 330 | // pts1[pt].x = pts1[pt].x / (float)scale; 331 | // pts1[pt].y = pts1[pt].y / (float)scale; 332 | // } 333 | } 334 | 335 | cv::Mat input = mojo::matrix2cv(in,false); 336 | 337 | // if (scale>1) 338 | // cv::resize(in, input, cv::Size(0, 0), 1. / scale, 1. 
/ scale); 339 | // else 340 | // input = in; 341 | 342 | 343 | cv::Mat M = cv::getPerspectiveTransform(pts1, pts2); 344 | cv::Mat cv_out; 345 | 346 | cv::warpPerspective(input, cv_out, 347 | M, // reuse M instead of computing the same transform a second time 348 | cv::Size((int)((float)out_dim), (int)((float)out_dim)), 349 | cv::INTER_AREA, cv::BORDER_REPLICATE); //cv::INTER_LINEAR 350 | 351 | //INTER_AREA 352 | 353 | 354 | // double min; 355 | // cv::minMaxIdx(cv_out, &min); 356 | // std::cout << "min: " << min << "||"; 357 | return mojo::cv2matrix(cv_out); 358 | } 359 | /* cv2matrix overload for const or temporary Mats, e.g. cv2matrix(colorize(...)) in draw_cnn_state below: standard C++ cannot bind a temporary to the cv::Mat& overload (only MSVC accepts that as an extension). It converts a shallow header copy, so the caller's Mat is not modified. */ mojo::matrix cv2matrix(const cv::Mat &m) { cv::Mat tmp = m; return cv2matrix(tmp); } 360 | /* bgr2ycrcb: min-max normalises m to [0,1], converts it with CV_BGR2YCrCb, maps every value v to 2*(v - 0.5), writes the result back into m and returns it. Presumably requires m.chans == 3 with planes in BGR order - TODO confirm at call sites. */ mojo::matrix bgr2ycrcb(mojo::matrix &m) 361 | { 362 | cv::Mat cv_m = matrix2cv(m,false); 363 | double min_, max_; 364 | cv_m = cv_m.reshape(1); 365 | cv::minMaxIdx(cv_m, &min_, &max_); 366 | cv_m = cv_m - min_; 367 | max_ = max_ - min_; 368 | if (max_ > 0) cv_m /= max_; // guard: a constant input would otherwise divide 0 by 0 369 | 370 | cv_m = cv_m.reshape(m.chans, m.rows); 371 | cv::Mat cv_Y; 372 | cv::cvtColor(cv_m, cv_Y, CV_BGR2YCrCb); 373 | cv_Y = cv_Y.reshape(1); 374 | cv_Y -= 0.5f; 375 | cv_Y *= 2.f; 376 | cv_Y = cv_Y.reshape(m.chans); 377 | 378 | m = cv2matrix(cv_Y); 379 | return m; 380 | } 381 | /* save: writes m to filename via cv::imwrite as an 8-bit image min-max scaled to 0..255 (matrix2cv with uc8 == true). */ 382 | void save(mojo::matrix &m, std::string filename) 383 | { 384 | cv::Mat m2 = matrix2cv(m,true); 385 | //cv::resize(m2, m2, cv::Size(0, 0), 4, 4); 386 | cv::imwrite(filename, m2); 387 | } 388 | /* show: displays m in OpenCV window win_name after min-max normalising it to [0,1] and resizing it by zoom when zoom != 1, then waits wait_ms ms via cv::waitKey. Returns immediately if any dimension is non-positive. When m.chans != 3 only the first channel is displayed. */ 389 | void show(const mojo::matrix &m, float zoom = 1.0f, const char *win_name = "", int wait_ms=1) 390 | { 391 | if (m.cols <= 0 || m.rows <= 0 || m.chans <= 0) return; 392 | cv::Mat cv_m = matrix2cv(m,false); 393 | 394 | double min_, max_; const int cn = cv_m.channels(); // 1 unless m.chans == 3 395 | cv_m = cv_m.reshape(1); 396 | cv::minMaxIdx(cv_m, &min_, &max_); 397 | cv_m = cv_m - min_; 398 | max_ = max_ - min_; 399 | if (max_ > 0) cv_m /= max_; // guard: a constant input would otherwise divide 0 by 0 400 | // cv_m += 1.f; 401 | // cv_m *= 0.5; 402 | cv_m = cv_m.reshape(cn, m.rows); // m.chans may differ from the channel count matrix2cv wrapped 403 | 404 | if (zoom != 1.f) cv::resize(cv_m, cv_m, cv::Size(0, 0), zoom, zoom,0); 405 | cv::imshow(win_name, cv_m); 406 | cv::waitKey(wait_ms); 407 | } 408 | 409 | // null name hides all windows 410 | void hide(const char 
*win_name = "") 411 | { 412 | if (win_name == NULL) cv::destroyAllWindows(); 413 | else cv::destroyWindow(win_name); 414 | } 415 | 416 | enum mojo_palette{ gray=0, hot=1, tensorglow=2, voodoo=3, saltnpepa=4}; 417 | 418 | /* colorize: maps each byte of im through color_palette into a new 3-channel Mat (channel 0 = blue, as merged at the end); 255 becomes white in every palette, and gray (no palette branch matches) just replicates im into all three channels. Returns im unchanged when it is empty. NOTE(review): im.data is indexed linearly over rows*cols, so im is presumably a continuous CV_8UC1 image - TODO confirm. */ 419 | cv::Mat colorize(cv::Mat im, mojo::mojo_palette color_palette = mojo::gray) 420 | { 421 | 422 | if (im.cols <= 0 || im.rows <= 0) return im; 423 | 424 | cv::Mat RGB[3]; 425 | RGB[0] = im.clone(); // blue 426 | RGB[1] = im.clone(); 427 | RGB[2] = im.clone(); 428 | 429 | for (int i = 0; i < im.rows*im.cols; i++) 430 | { 431 | unsigned char c = (unsigned char)im.data[i]; 432 | // tensor flow colors (red black blue) 433 | if (color_palette == mojo::tensorglow) 434 | { 435 | if (c == 255) { RGB[2].data[i] = 255; RGB[1].data[i] = 255; RGB[0].data[i] = 255; } 436 | else if (c < 128) { RGB[2].data[i] = 0; RGB[1].data[i] = 0; RGB[0].data[i] = 2*(127 - c); } 437 | else { RGB[2].data[i] = 2* (c - 128); RGB[1].data[i] = 0; RGB[0].data[i] = 0; } 438 | } 439 | else if (color_palette == mojo::hot) 440 | { 441 | if (c == 255) { RGB[2].data[i] = 255; RGB[1].data[i] = 255; RGB[0].data[i] = 255; } 442 | else if (c < 128) { RGB[0].data[i] = 0; RGB[1].data[i] = 0; RGB[2].data[i] = c * 2; } 443 | else { RGB[0].data[i] = 0; RGB[1].data[i] = (c - 128) * 2; RGB[2].data[i] = 255; } 444 | } 445 | else if (color_palette == mojo::saltnpepa) 446 | { 447 | if (c == 255) { RGB[2].data[i] = 255; RGB[1].data[i] = 255; RGB[0].data[i] = 255; } 448 | else if (c&1) { RGB[0].data[i] = 0; RGB[1].data[i] = 0; RGB[2].data[i] = 0; } 449 | else { RGB[0].data[i] = 255; RGB[1].data[i] = 255; RGB[2].data[i] = 255; } 450 | } 451 | else if (color_palette == mojo::voodoo) 452 | { 453 | if (c == 255) { RGB[2].data[i] = 255; RGB[1].data[i] = 255; RGB[0].data[i] = 255; } 454 | else if (c < 128) 455 | { RGB[2].data[i] = (127-c); RGB[1].data[i] = 0; RGB[0].data[i] = 2 * (127 - c); } 456 | else { RGB[2].data[i] = (c - 128); RGB[1].data[i] = 2*(c-128); RGB[0].data[i] = 0; } 457 | } 458 | } 459 | 
460 | cv::Mat out; 461 | cv::merge(RGB, 3, out); 462 | return out; 463 | //cv::applyColorMap(im, im, cv::COLORMAP_WINTER);// COLORMAP_HOT); // cv::COLORMAP_JET); COLORMAP_RAINBOW 464 | } 465 | 466 | mojo::matrix draw_cnn_weights(mojo::network &cnn, int layer_index, mojo::mojo_palette color_palette=mojo::gray) 467 | { 468 | int w = (int)cnn.W.size(); 469 | cv::Mat im; 470 | 471 | 472 | std::vector im_layers; 473 | 474 | int layers = (int)cnn.layer_sets[0].size(); 475 | //for (int k = 0; k < layers; k++) 476 | int k=layer_index-1; 477 | { 478 | base_layer *layer = cnn.layer_sets[0][k]; 479 | // if (dynamic_cast (layer) == NULL) return mojo::matrix();// continue; 480 | 481 | __for__(auto &link __in__ layer->forward_linked_layers) 482 | { 483 | int connection_index = link.first; 484 | base_layer *p_bottom = link.second; 485 | if (!p_bottom->has_weights()) continue; 486 | for (auto i = 0; i < cnn.W[connection_index]->chans; i++) 487 | { 488 | cv::Mat im = matrix2cv(cnn.W[connection_index]->get_chans(i), true); 489 | cv::resize(im, im, cv::Size(0, 0), 2., 2., 0); 490 | im_layers.push_back(im); 491 | } 492 | // draw these nicely 493 | int s = im_layers[0].cols; 494 | cv::Mat tmp(layer->node.chans*(s + 1) + 1, p_bottom->node.chans*(s+1)+1, CV_8UC1);// = im.clone(); 495 | tmp = 255; 496 | for (int j = 0; j < layer->node.chans; j++) 497 | { 498 | for (int i = 0; i < p_bottom->node.chans; i++) 499 | { 500 | // make colors go 0 to 254 501 | double min, max; 502 | int index = i+j*p_bottom->node.chans; 503 | cv::minMaxIdx(im_layers[index], &min, &max); 504 | im_layers[index] -= min; 505 | im_layers[index] /= (max - min) / 254; 506 | 507 | im_layers[index].convertTo(im_layers[index], CV_8UC1); 508 | im_layers[index].copyTo(tmp(cv::Rect(i*s + 1 + i, j*s+ 1+j, s, s))); 509 | } 510 | } 511 | im = tmp; 512 | } 513 | } 514 | /* 515 | int imgs = (int)im_layers.size(); 516 | cv::Mat im; 517 | if (imgs <= 0) return im; 518 | 519 | im = im_layers[0].clone(); //(im_layers[0].rows, 
im_layers[0].cols, CV_8UC1); 520 | int W = im.cols; 521 | 522 | if (W<400) 523 | { 524 | W = 400; 525 | float S = (float)W / (float)im.cols; 526 | cv::resize(im, im, cv::Size(W, (int)(S*im.rows)), 0, 0, 0); 527 | } 528 | 529 | for (auto i = 1; i= layers) return mojo::matrix(); 563 | 564 | std::vector im_layers; 565 | base_layer *layer = cnn.layer_sets[0][layer_index]; 566 | 567 | for (int i = 0; i < layer->node.chans; i++) 568 | { 569 | cv::Mat im = matrix2cv(layer->node.get_chans(i), true); 570 | cv::resize(im, im, cv::Size(0, 0), 2., 2., 0); 571 | im_layers.push_back(im); 572 | } 573 | // draw these nicely 574 | int s = im_layers[0].cols; 575 | cv::Mat tmp(s + 2, (int)im_layers.size()*(1+s) + 1, CV_8UC1);// = im.clone(); 576 | tmp = 255; 577 | for (int i = 0; i < im_layers.size(); i++) 578 | { 579 | // make colors go 0 to 254 580 | double min, max; 581 | cv::minMaxIdx(im_layers[i], &min, &max); 582 | im_layers[i] -= min; 583 | im_layers[i] /= (max - min) / 254; 584 | 585 | im_layers[i].convertTo(im_layers[i], CV_8UC1); 586 | im_layers[i].copyTo(tmp(cv::Rect(i*s + 1 + i, 1, s, s))); 587 | } 588 | im = tmp; 589 | 590 | return cv2matrix(colorize(im, color_palette)); 591 | } 592 | 593 | mojo::matrix draw_cnn_state(mojo::network &cnn, std::string layer_name, mojo::mojo_palette color_palette = mojo::gray) 594 | { 595 | int layer_index = cnn.layer_map[layer_name]; 596 | return draw_cnn_state(cnn, layer_index, color_palette); 597 | } 598 | 599 | 600 | 601 | #endif // MOJO_CV# 602 | 603 | }// namespace --------------------------------------------------------------------------------