├── .DS_Store ├── Client-Android ├── .gitignore ├── .idea │ ├── .gitignore │ ├── codeStyles │ │ └── Project.xml │ ├── compiler.xml │ ├── gradle.xml │ ├── jarRepositories.xml │ ├── misc.xml │ ├── runConfigurations.xml │ └── vcs.xml ├── app │ ├── .gitignore │ ├── CMakeLists.txt │ ├── build.gradle │ ├── includes │ │ ├── MNN │ │ │ ├── AutoTime.hpp │ │ │ ├── ErrorCode.hpp │ │ │ ├── HalideRuntime.h │ │ │ ├── ImageProcess.hpp │ │ │ ├── Interpreter.hpp │ │ │ ├── MNNDefine.h │ │ │ ├── MNNForwardType.h │ │ │ ├── Matrix.h │ │ │ ├── Rect.h │ │ │ ├── Tensor.hpp │ │ │ ├── expr │ │ │ │ ├── Executor.hpp │ │ │ │ ├── Expr.hpp │ │ │ │ ├── ExprCreator.hpp │ │ │ │ ├── MathOp.hpp │ │ │ │ ├── NeuralNetWorkOp.hpp │ │ │ │ └── Optimizer.hpp │ │ │ └── plugin │ │ │ │ ├── PluginContext.hpp │ │ │ │ ├── PluginKernel.hpp │ │ │ │ └── PluginShapeInference.hpp │ │ └── MNNTrain │ │ │ ├── BlockingQueue.hpp │ │ │ ├── DataLoader.hpp │ │ │ ├── DataLoaderConfig.hpp │ │ │ ├── Dataset.hpp │ │ │ ├── Distributions.hpp │ │ │ ├── Example.hpp │ │ │ ├── ImageDataset.hpp │ │ │ ├── LearningRateScheduler.hpp │ │ │ ├── Lenet.hpp │ │ │ ├── Loss.hpp │ │ │ ├── MnistDataset.hpp │ │ │ ├── Module.hpp │ │ │ ├── NN.hpp │ │ │ ├── ParameterOptimizer.hpp │ │ │ ├── PipelineModule.hpp │ │ │ ├── SGD.hpp │ │ │ └── Transformer.hpp │ ├── libs │ │ └── arm64-v8a │ │ │ ├── libMNN.so │ │ │ ├── libMNNTrain.so │ │ │ └── libMNN_Express.so │ ├── proguard-rules.pro │ └── src │ │ ├── androidTest │ │ └── java │ │ │ └── com │ │ │ └── example │ │ │ └── websocket │ │ │ └── ExampleInstrumentedTest.java │ │ ├── main │ │ ├── AndroidManifest.xml │ │ ├── java │ │ │ └── com │ │ │ │ ├── demo │ │ │ │ ├── App.java │ │ │ │ ├── EmptyActivity.java │ │ │ │ ├── MainActivity.java │ │ │ │ └── MnnTrainFragment.java │ │ │ │ └── example │ │ │ │ ├── MainActivity_test.java │ │ │ │ ├── nativemnn │ │ │ │ ├── mnn │ │ │ │ │ └── MNNDataNative.java │ │ │ │ └── utils │ │ │ │ │ └── Common.java │ │ │ │ └── websocket │ │ │ │ ├── client │ │ │ │ └── ClientWebSocketListener.java │ 
│ │ │ ├── constants │ │ │ │ └── Constants.java │ │ │ │ ├── service │ │ │ │ ├── ClientWebSocketService.java │ │ │ │ └── MnnTrainService.java │ │ │ │ └── utils │ │ │ │ ├── CommonUtil.java │ │ │ │ ├── DeviceUtil.java │ │ │ │ └── FileUtil.java │ │ ├── jni │ │ │ └── mnndatanative.cpp │ │ └── res │ │ │ ├── drawable-v24 │ │ │ └── ic_launcher_foreground.xml │ │ │ ├── drawable │ │ │ └── ic_launcher_background.xml │ │ │ ├── layout │ │ │ ├── activity_empty.xml │ │ │ ├── activity_main.xml │ │ │ └── fragment_mnn_train.xml │ │ │ ├── mipmap-anydpi-v26 │ │ │ ├── ic_launcher.xml │ │ │ └── ic_launcher_round.xml │ │ │ ├── mipmap-hdpi │ │ │ ├── ic_launcher.png │ │ │ └── ic_launcher_round.png │ │ │ ├── mipmap-mdpi │ │ │ ├── ic_launcher.png │ │ │ └── ic_launcher_round.png │ │ │ ├── mipmap-xhdpi │ │ │ ├── ic_launcher.png │ │ │ └── ic_launcher_round.png │ │ │ ├── mipmap-xxhdpi │ │ │ ├── ic_launcher.png │ │ │ └── ic_launcher_round.png │ │ │ ├── mipmap-xxxhdpi │ │ │ ├── ic_launcher.png │ │ │ └── ic_launcher_round.png │ │ │ └── values │ │ │ ├── colors.xml │ │ │ ├── strings.xml │ │ │ └── styles.xml │ │ └── test │ │ └── java │ │ └── com │ │ └── example │ │ └── websocket │ │ └── ExampleUnitTest.java ├── build.gradle ├── gradle.properties ├── gradle │ └── wrapper │ │ ├── gradle-wrapper.jar │ │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat └── settings.gradle ├── README.md ├── README_CN.md ├── Server-Python ├── model │ ├── init.mnn │ └── mnist.snapshot.mnn ├── py │ ├── __pycache__ │ │ ├── clientThread.cpython-38.pyc │ │ ├── parameter.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── aggregateModel.out │ ├── clientThread.py │ ├── parameter.py │ ├── server.py │ └── utils.py ├── result │ ├── acc.npy │ └── time.npy ├── static │ └── picture │ │ ├── 0.bmp │ │ ├── 2.bmp │ │ ├── 4.bmp │ │ ├── 6.bmp │ │ └── 8.bmp └── test.html └── data ├── mnist.snapshot.mnn └── mnist_data ├── .DS_Store ├── t10k-images-idx3-ubyte ├── t10k-labels-idx1-ubyte ├── train-images-idx3-ubyte └── 
train-labels-idx1-ubyte /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UbiquitousLearning/End2end-Federated-Learning/e2cdcd9829779798fc56f2f63b19ee6cdc2307d0/.DS_Store -------------------------------------------------------------------------------- /Client-Android/.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | .gradle 3 | /local.properties 4 | /.idea/caches 5 | /.idea/libraries 6 | /.idea/modules.xml 7 | /.idea/workspace.xml 8 | /.idea/navEditor.xml 9 | /.idea/assetWizardSettings.xml 10 | .DS_Store 11 | /build 12 | /captures 13 | .externalNativeBuild 14 | .cxx 15 | -------------------------------------------------------------------------------- /Client-Android/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /Client-Android/.idea/codeStyles/Project.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 20 | 21 | 22 | 23 | 25 | 26 | 27 |
28 | 29 | 30 | 31 | xmlns:android 32 | 33 | ^$ 34 | 35 | 36 | 37 |
38 |
39 | 40 | 41 | 42 | xmlns:.* 43 | 44 | ^$ 45 | 46 | 47 | BY_NAME 48 | 49 |
50 |
51 | 52 | 53 | 54 | .*:id 55 | 56 | http://schemas.android.com/apk/res/android 57 | 58 | 59 | 60 |
61 |
62 | 63 | 64 | 65 | .*:name 66 | 67 | http://schemas.android.com/apk/res/android 68 | 69 | 70 | 71 |
72 |
73 | 74 | 75 | 76 | name 77 | 78 | ^$ 79 | 80 | 81 | 82 |
83 |
84 | 85 | 86 | 87 | style 88 | 89 | ^$ 90 | 91 | 92 | 93 |
94 |
95 | 96 | 97 | 98 | .* 99 | 100 | ^$ 101 | 102 | 103 | BY_NAME 104 | 105 |
106 |
107 | 108 | 109 | 110 | .* 111 | 112 | http://schemas.android.com/apk/res/android 113 | 114 | 115 | ANDROID_ATTRIBUTE_ORDER 116 | 117 |
118 |
119 | 120 | 121 | 122 | .* 123 | 124 | .* 125 | 126 | 127 | BY_NAME 128 | 129 |
130 |
131 |
132 |
133 |
134 |
-------------------------------------------------------------------------------- /Client-Android/.idea/compiler.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Client-Android/.idea/gradle.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 20 | 21 | -------------------------------------------------------------------------------- /Client-Android/.idea/jarRepositories.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 9 | 10 | 14 | 15 | 19 | 20 | 24 | 25 | -------------------------------------------------------------------------------- /Client-Android/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 9 | -------------------------------------------------------------------------------- /Client-Android/.idea/runConfigurations.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 12 | 13 | -------------------------------------------------------------------------------- /Client-Android/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Client-Android/app/.gitignore: -------------------------------------------------------------------------------- 1 | /build -------------------------------------------------------------------------------- /Client-Android/app/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.4.1) 2 | 3 | set(lib_DIR ${CMAKE_SOURCE_DIR}/libs) 4 | include_directories(${CMAKE_SOURCE_DIR}/includes) 5 | 6 | add_library( MNN SHARED IMPORTED ) 7 | 
set_target_properties( 8 | MNN 9 | PROPERTIES IMPORTED_LOCATION 10 | ${lib_DIR}/${ANDROID_ABI}/libMNN.so 11 | ) 12 | 13 | add_library( MNNTrain SHARED IMPORTED ) 14 | set_target_properties( 15 | MNNTrain 16 | PROPERTIES IMPORTED_LOCATION 17 | ${lib_DIR}/${ANDROID_ABI}/libMNNTrain.so 18 | ) 19 | 20 | add_library( MNN_Express SHARED IMPORTED ) 21 | set_target_properties( 22 | MNN_Express 23 | PROPERTIES IMPORTED_LOCATION 24 | ${lib_DIR}/${ANDROID_ABI}/libMNN_Express.so 25 | ) 26 | 27 | #file(GLOB_RECURSE CPP_SRCS src/main/jni/*.cpp) 28 | file(GLOB_RECURSE CPP_SRCS src/main/jni/mnndatanative.cpp) 29 | add_library( mnncore SHARED ${CPP_SRCS} ) 30 | 31 | find_library( log-lib log ) 32 | find_library( jnigraphics-lib jnigraphics ) 33 | 34 | add_definitions(-DMNN_USE_LOGCAT) 35 | target_link_libraries( mnncore MNN ${log-lib} ${jnigraphics-lib}) 36 | target_link_libraries( mnncore MNNTrain ${log-lib} ${jnigraphics-lib}) 37 | target_link_libraries( mnncore MNN_Express ${log-lib} ${jnigraphics-lib}) -------------------------------------------------------------------------------- /Client-Android/app/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'com.android.application' 2 | 3 | android { 4 | compileSdkVersion 30 5 | buildToolsVersion "30.0.2" 6 | 7 | defaultConfig { 8 | applicationId "com.demo" 9 | minSdkVersion 16 10 | targetSdkVersion 30 11 | versionCode 1 12 | versionName "1.0" 13 | 14 | testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" 15 | 16 | externalNativeBuild { 17 | cmake { 18 | cppFlags "" 19 | abiFilters 'arm64-v8a' 20 | } 21 | } 22 | ndk{ 23 | abiFilters 'arm64-v8a' 24 | } 25 | } 26 | 27 | buildTypes { 28 | debug { 29 | ndk { 30 | abiFilters 'arm64-v8a' 31 | } 32 | } 33 | release { 34 | minifyEnabled false 35 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' 36 | } 37 | } 38 | 39 | externalNativeBuild { 40 | cmake { 41 | path 
"CMakeLists.txt" 42 | } 43 | } 44 | 45 | sourceSets { 46 | main { 47 | jniLibs.srcDirs = ['libs'] 48 | } 49 | } 50 | 51 | packagingOptions { 52 | pickFirst 'lib/arm64-v8a/libMNNTrain.so' 53 | pickFirst 'lib/arm64-v8a/libMNN_Express.so' 54 | pickFirst 'lib/arm64-v8a/libMNN.so' 55 | } 56 | 57 | dataBinding{ 58 | enabled = true 59 | } 60 | } 61 | 62 | dependencies { 63 | implementation fileTree(dir: "libs", include: ["*.jar"]) 64 | implementation 'androidx.appcompat:appcompat:1.1.0' 65 | implementation 'androidx.constraintlayout:constraintlayout:1.1.3' 66 | testImplementation 'junit:junit:4.12' 67 | androidTestImplementation 'androidx.test.ext:junit:1.1.1' 68 | androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0' 69 | implementation 'com.google.code.gson:gson:2.8.6' 70 | implementation 'androidx.legacy:legacy-support-v4:1.0.0' 71 | implementation 'com.google.android.material:material:1.0.0' 72 | implementation 'androidx.cardview:cardview:1.0.0' 73 | 74 | implementation("com.squareup.okhttp3:okhttp:4.9.0") 75 | 76 | 77 | } -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/AutoTime.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // AutoTime.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2018/07/27. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef AutoTime_hpp 10 | #define AutoTime_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | namespace MNN { 17 | 18 | /** time tracing util. prints duration between init and deinit. 
*/ 19 | class MNN_PUBLIC AutoTime { 20 | public: 21 | AutoTime(int line, const char* func); 22 | ~AutoTime(); 23 | AutoTime(const AutoTime&) = delete; 24 | AutoTime(const AutoTime&&) = delete; 25 | AutoTime& operator=(const AutoTime&) = delete; 26 | AutoTime& operator=(const AutoTime&&) = delete; 27 | 28 | private: 29 | int mLine; 30 | char* mName; 31 | uint64_t mCurrentTime; 32 | }; 33 | } // namespace MNN 34 | 35 | #ifdef MNN_OPEN_TIME_TRACE 36 | #define AUTOTIME MNN::AutoTime ___t(__LINE__, __func__) 37 | #else 38 | #define AUTOTIME 39 | #endif 40 | 41 | #endif /* AutoTime_hpp */ 42 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/ErrorCode.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ErrorCode.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2018/09/18. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef ErrorCode_h 10 | #define ErrorCode_h 11 | 12 | namespace MNN { 13 | enum ErrorCode { 14 | #ifdef NO_ERROR 15 | #undef NO_ERROR 16 | #endif //NO_ERROR 17 | NO_ERROR = 0, 18 | OUT_OF_MEMORY = 1, 19 | NOT_SUPPORT = 2, 20 | COMPUTE_SIZE_ERROR = 3, 21 | NO_EXECUTION = 4, 22 | 23 | // User error 24 | INPUT_DATA_ERROR = 10, 25 | CALL_BACK_STOP = 11, 26 | 27 | // Op Resize Error 28 | TENSOR_NOT_SUPPORT = 20, 29 | TENSOR_NEED_DIVIDE = 21, 30 | }; 31 | } 32 | 33 | #endif /* ErrorCode_h */ 34 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/HalideRuntime.h: -------------------------------------------------------------------------------- 1 | #ifndef HALIDE_HALIDERUNTIME_H 2 | #define HALIDE_HALIDERUNTIME_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #ifdef __cplusplus 9 | extern "C" { 10 | #endif 11 | 12 | // Note that you should not use "inline" along with HALIDE_ALWAYS_INLINE; 13 | // it is not necessary, and may produce warnings for some build 
configurations. 14 | #ifdef _MSC_VER 15 | #define HALIDE_ALWAYS_INLINE __forceinline 16 | #define HALIDE_NEVER_INLINE __declspec(noinline) 17 | #else 18 | #define HALIDE_ALWAYS_INLINE __attribute__((always_inline)) inline 19 | #define HALIDE_NEVER_INLINE __attribute__((noinline)) 20 | #endif 21 | 22 | /** \file 23 | * 24 | * This file declares the routines used by Halide internally in its 25 | * runtime. On platforms that support weak linking, these can be 26 | * replaced with user-defined versions by defining an extern "C" 27 | * function with the same name and signature. 28 | * 29 | * When doing Just In Time (JIT) compilation methods on the Func being 30 | * compiled must be called instead. The corresponding methods are 31 | * documented below. 32 | * 33 | * All of these functions take a "void *user_context" parameter as their 34 | * first argument; if the Halide kernel that calls back to any of these 35 | * functions has been compiled with the UserContext feature set on its Target, 36 | * then the value of that pointer passed from the code that calls the 37 | * Halide kernel is piped through to the function. 38 | * 39 | * Some of these are also useful to call when using the default 40 | * implementation. E.g. halide_shutdown_thread_pool. 41 | * 42 | * Note that even on platforms with weak linking, some linker setups 43 | * may not respect the override you provide. E.g. if the override is 44 | * in a shared library and the halide object files are linked directly 45 | * into the output, the builtin versions of the runtime functions will 46 | * be called. See your linker documentation for more details. On 47 | * Linux, LD_DYNAMIC_WEAK=1 may help. 48 | * 49 | */ 50 | 51 | // Forward-declare to suppress warnings if compiling as C. 52 | struct halide_buffer_t; 53 | 54 | /** Types in the halide type system. They can be ints, unsigned ints, 55 | * or floats (of various bit-widths), or a handle (which is always 64-bits). 
56 | * Note that the int/uint/float values do not imply a specific bit width 57 | * (the bit width is expected to be encoded in a separate value). 58 | */ 59 | typedef enum halide_type_code_t 60 | { 61 | halide_type_int = 0, //!< signed integers 62 | halide_type_uint = 1, //!< unsigned integers 63 | halide_type_float = 2, //!< floating point numbers 64 | halide_type_handle = 3 //!< opaque pointer type (void *) 65 | } halide_type_code_t; 66 | 67 | // Note that while __attribute__ can go before or after the declaration, 68 | // __declspec apparently is only allowed before. 69 | #ifndef HALIDE_ATTRIBUTE_ALIGN 70 | #ifdef _MSC_VER 71 | #define HALIDE_ATTRIBUTE_ALIGN(x) __declspec(align(x)) 72 | #else 73 | #define HALIDE_ATTRIBUTE_ALIGN(x) __attribute__((aligned(x))) 74 | #endif 75 | #endif 76 | 77 | /** A runtime tag for a type in the halide type system. Can be ints, 78 | * unsigned ints, or floats of various bit-widths (the 'bits' 79 | * field). Can also be vectors of the same (by setting the 'lanes' 80 | * field to something larger than one). This struct should be 81 | * exactly 32-bits in size. */ 82 | struct halide_type_t { 83 | /** The basic type code: signed integer, unsigned integer, or floating point. */ 84 | #if __cplusplus >= 201103L 85 | HALIDE_ATTRIBUTE_ALIGN(1) halide_type_code_t code; // halide_type_code_t 86 | #else 87 | HALIDE_ATTRIBUTE_ALIGN(1) uint8_t code; // halide_type_code_t 88 | #endif 89 | 90 | /** The number of bits of precision of a single scalar value of this type. */ 91 | HALIDE_ATTRIBUTE_ALIGN(1) uint8_t bits; 92 | 93 | /** How many elements in a vector. This is 1 for scalar types. */ 94 | HALIDE_ATTRIBUTE_ALIGN(2) uint16_t lanes; 95 | 96 | #ifdef __cplusplus 97 | /** Construct a runtime representation of a Halide type from: 98 | * code: The fundamental type from an enum. 99 | * bits: The bit size of one element. 100 | * lanes: The number of vector elements in the type. 
*/ 101 | HALIDE_ALWAYS_INLINE halide_type_t(halide_type_code_t code, uint8_t bits, uint16_t lanes = 1) 102 | : code(code), bits(bits), lanes(lanes) { 103 | } 104 | 105 | /** Default constructor is required e.g. to declare halide_trace_event 106 | * instances. */ 107 | HALIDE_ALWAYS_INLINE halide_type_t() : code((halide_type_code_t)0), bits(0), lanes(0) {} 108 | 109 | /** Compare two types for equality. */ 110 | HALIDE_ALWAYS_INLINE bool operator==(const halide_type_t &other) const { 111 | return (code == other.code && 112 | bits == other.bits && 113 | lanes == other.lanes); 114 | } 115 | 116 | HALIDE_ALWAYS_INLINE bool operator!=(const halide_type_t &other) const { 117 | return !(*this == other); 118 | } 119 | 120 | /** Size in bytes for a single element, even if width is not 1, of this type. */ 121 | HALIDE_ALWAYS_INLINE int bytes() const { return (bits + 7) / 8; } 122 | #endif 123 | }; 124 | 125 | /** An opaque struct containing per-GPU API implementations of the 126 | * device functions. */ 127 | struct halide_device_interface_impl_t; 128 | 129 | /** Each GPU API provides a halide_device_interface_t struct pointing 130 | * to the code that manages device allocations. You can access these 131 | * functions directly from the struct member function pointers, or by 132 | * calling the functions declared below. Note that the global 133 | * functions are not available when using Halide as a JIT compiler. 134 | * If you are using raw halide_buffer_t in that context you must use 135 | * the function pointers in the device_interface struct. 136 | * 137 | * The function pointers below are currently the same for every GPU 138 | * API; only the impl field varies. These top-level functions do the 139 | * bookkeeping that is common across all GPU APIs, and then dispatch 140 | * to more API-specific functions via another set of function pointers 141 | * hidden inside the impl field. 
142 | */ 143 | struct halide_device_interface_t { 144 | int (*device_malloc)(void *user_context, struct halide_buffer_t *buf, 145 | const struct halide_device_interface_t *device_interface); 146 | int (*device_free)(void *user_context, struct halide_buffer_t *buf); 147 | int (*device_sync)(void *user_context, struct halide_buffer_t *buf); 148 | void (*device_release)(void *user_context, 149 | const struct halide_device_interface_t *device_interface); 150 | int (*copy_to_host)(void *user_context, struct halide_buffer_t *buf); 151 | int (*copy_to_device)(void *user_context, struct halide_buffer_t *buf, 152 | const struct halide_device_interface_t *device_interface); 153 | int (*device_and_host_malloc)(void *user_context, struct halide_buffer_t *buf, 154 | const struct halide_device_interface_t *device_interface); 155 | int (*device_and_host_free)(void *user_context, struct halide_buffer_t *buf); 156 | int (*buffer_copy)(void *user_context, struct halide_buffer_t *src, 157 | const struct halide_device_interface_t *dst_device_interface, struct halide_buffer_t *dst); 158 | int (*device_crop)(void *user_context, const struct halide_buffer_t *src, 159 | struct halide_buffer_t *dst); 160 | int (*device_release_crop)(void *user_context, struct halide_buffer_t *buf); 161 | int (*wrap_native)(void *user_context, struct halide_buffer_t *buf, uint64_t handle, 162 | const struct halide_device_interface_t *device_interface); 163 | int (*detach_native)(void *user_context, struct halide_buffer_t *buf); 164 | const struct halide_device_interface_impl_t *impl; 165 | }; 166 | 167 | typedef struct halide_dimension_t { 168 | int32_t min, extent, stride; 169 | 170 | // Per-dimension flags. None are defined yet (This is reserved for future use). 
171 | uint32_t flags; 172 | 173 | #ifdef __cplusplus 174 | HALIDE_ALWAYS_INLINE halide_dimension_t() : min(0), extent(0), stride(0), flags(0) {} 175 | HALIDE_ALWAYS_INLINE halide_dimension_t(int32_t m, int32_t e, int32_t s, uint32_t f = 0) : 176 | min(m), extent(e), stride(s), flags(f) {} 177 | 178 | HALIDE_ALWAYS_INLINE bool operator==(const halide_dimension_t &other) const { 179 | return (min == other.min) && 180 | (extent == other.extent) && 181 | (stride == other.stride) && 182 | (flags == other.flags); 183 | } 184 | 185 | HALIDE_ALWAYS_INLINE bool operator!=(const halide_dimension_t &other) const { 186 | return !(*this == other); 187 | } 188 | #endif 189 | } halide_dimension_t; 190 | 191 | #ifdef __cplusplus 192 | } // extern "C" 193 | #endif 194 | 195 | typedef enum {halide_buffer_flag_host_dirty = 1, 196 | halide_buffer_flag_device_dirty = 2} halide_buffer_flags; 197 | 198 | /** 199 | * The raw representation of an image passed around by generated 200 | * Halide code. It includes some stuff to track whether the image is 201 | * not actually in main memory, but instead on a device (like a 202 | * GPU). For a more convenient C++ wrapper, use Halide::Buffer. */ 203 | typedef struct halide_buffer_t { 204 | /** A device-handle for e.g. GPU memory used to back this buffer. */ 205 | uint64_t device; 206 | 207 | /** The interface used to interpret the above handle. */ 208 | const struct halide_device_interface_t *device_interface; 209 | 210 | /** A pointer to the start of the data in main memory. In terms of 211 | * the Halide coordinate system, this is the address of the min 212 | * coordinates (defined below). */ 213 | uint8_t* host; 214 | 215 | /** flags with various meanings. */ 216 | uint64_t flags; 217 | 218 | /** The type of each buffer element. */ 219 | struct halide_type_t type; 220 | 221 | /** The dimensionality of the buffer. */ 222 | int32_t dimensions; 223 | 224 | /** The shape of the buffer. 
Halide does not own this array - you 225 | * must manage the memory for it yourself. */ 226 | halide_dimension_t *dim; 227 | 228 | /** Pads the buffer up to a multiple of 8 bytes */ 229 | void *padding; 230 | } halide_buffer_t; 231 | 232 | 233 | #ifdef __cplusplus 234 | 235 | namespace { 236 | template struct check_is_pointer; 237 | template struct check_is_pointer {}; 238 | } 239 | 240 | /** Construct the halide equivalent of a C type */ 241 | template 242 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 243 | // Create a compile-time error if T is not a pointer (without 244 | // using any includes - this code goes into the runtime). 245 | check_is_pointer check; 246 | (void)check; 247 | return halide_type_t(halide_type_handle, 64); 248 | } 249 | 250 | template<> 251 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 252 | return halide_type_t(halide_type_float, 32); 253 | } 254 | 255 | template<> 256 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 257 | return halide_type_t(halide_type_float, 64); 258 | } 259 | 260 | template<> 261 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 262 | return halide_type_t(halide_type_uint, 1); 263 | } 264 | 265 | template<> 266 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 267 | return halide_type_t(halide_type_uint, 8); 268 | } 269 | 270 | template<> 271 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 272 | return halide_type_t(halide_type_uint, 16); 273 | } 274 | 275 | template<> 276 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 277 | return halide_type_t(halide_type_uint, 32); 278 | } 279 | 280 | template<> 281 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 282 | return halide_type_t(halide_type_uint, 64); 283 | } 284 | 285 | template<> 286 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 287 | return halide_type_t(halide_type_int, 8); 288 | } 289 | 290 | template<> 291 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 292 | return 
halide_type_t(halide_type_int, 16); 293 | } 294 | 295 | template<> 296 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 297 | return halide_type_t(halide_type_int, 32); 298 | } 299 | 300 | template<> 301 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 302 | return halide_type_t(halide_type_int, 64); 303 | } 304 | 305 | #endif 306 | 307 | #endif // HALIDE_HALIDERUNTIME_H 308 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/ImageProcess.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ImageProcess.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2018/09/19. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef ImageProcess_hpp 10 | #define ImageProcess_hpp 11 | 12 | #include 13 | #include "Matrix.h" 14 | #include 15 | 16 | namespace MNN { 17 | namespace CV { 18 | enum ImageFormat { 19 | RGBA = 0, 20 | RGB, 21 | BGR, 22 | GRAY, 23 | BGRA, 24 | YUV_NV21 = 11, 25 | }; 26 | 27 | enum Filter { NEAREST = 0, BILINEAR = 1, BICUBIC = 2 }; 28 | 29 | enum Wrap { CLAMP_TO_EDGE = 0, ZERO = 1, REPEAT = 2 }; 30 | 31 | /** 32 | * handle image process for tensor. 33 | * step: 34 | * 1: Do transform compute and get points 35 | * 2: Sample line and do format convert 36 | * 3: Turn RGBA to float tensor, and do sub and normalize 37 | */ 38 | class MNN_PUBLIC ImageProcess { 39 | public: 40 | struct Inside; 41 | struct Config { 42 | /** data filter */ 43 | Filter filterType = NEAREST; 44 | /** format of source data */ 45 | ImageFormat sourceFormat = RGBA; 46 | /** format of destination data */ 47 | ImageFormat destFormat = RGBA; 48 | 49 | // Only valid if the dest type is float 50 | float mean[4] = {0.0f, 0.0f, 0.0f, 0.0f}; 51 | float normal[4] = {1.0f, 1.0f, 1.0f, 1.0f}; 52 | 53 | /** edge wrapper */ 54 | Wrap wrap = CLAMP_TO_EDGE; 55 | }; 56 | 57 | public: 58 | /** 59 | * @brief create image process with given config for given tensor. 
60 | * @param config given config. 61 | * @param dstTensor given tensor. 62 | * @return image processor. 63 | */ 64 | static ImageProcess* create(const Config& config, const Tensor* dstTensor = nullptr); 65 | 66 | /** 67 | * @brief create image process with given config for given tensor. 68 | * @param means given means 69 | * @param meanCount given means count 70 | * @param normals given normals 71 | * @param normalCount given normal count 72 | * @param sourceFormat format of source data 73 | * @param destFormat format of destination data 74 | * @param dstTensor given tensor. 75 | * @return image processor. 76 | */ 77 | static ImageProcess* create(const ImageFormat sourceFormat = RGBA, const ImageFormat destFormat = RGBA, 78 | const float* means = nullptr, const int meanCount = 0, const float* normals = nullptr, 79 | const int normalCount = 0, const Tensor* dstTensor = nullptr); 80 | 81 | ~ImageProcess(); 82 | 83 | /** 84 | * @brief get affine transform matrix. 85 | * @return affine transform matrix. 86 | */ 87 | inline const Matrix& matrix() const { 88 | return mTransform; 89 | } 90 | void setMatrix(const Matrix& matrix); 91 | 92 | /** 93 | * @brief convert source data to given tensor. 94 | * @param source source data. 95 | * @param iw source width. 96 | * @param ih source height. 97 | * @param stride number of elements per row. eg: 100 width RGB contains at least 300 elements. 98 | * @param dest given tensor. 99 | * @return result code. 100 | */ 101 | ErrorCode convert(const uint8_t* source, int iw, int ih, int stride, Tensor* dest); 102 | 103 | /** 104 | * @brief create tensor with given data. 105 | * @param w image width. 106 | * @param h image height. 107 | * @param bpp bytes per pixel. 108 | * @param p pixel data pointer. 109 | * @return created tensor. 
110 | */ 111 | template 112 | static Tensor* createImageTensor(int w, int h, int bpp, void* p = nullptr) { 113 | return createImageTensor(halide_type_of(), w, h, bpp, p); 114 | } 115 | static Tensor* createImageTensor(halide_type_t type, int w, int h, int bpp, void* p = nullptr); 116 | 117 | private: 118 | ImageProcess(const Config& config); 119 | Matrix mTransform; 120 | Matrix mTransformInvert; 121 | Inside* mInside; 122 | }; 123 | } // namespace CV 124 | } // namespace MNN 125 | 126 | #endif /* ImageProcess_hpp */ 127 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/Interpreter.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Interpreter.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2018/07/23. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef Interpreter_hpp 10 | #define Interpreter_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | namespace MNN { 20 | 21 | /** session schedule config */ 22 | struct ScheduleConfig { 23 | /** which tensor should be kept */ 24 | std::vector saveTensors; 25 | /** forward type */ 26 | MNNForwardType type = MNN_FORWARD_CPU; 27 | /** number of threads in parallel */ 28 | int numThread = 4; 29 | 30 | /** subpath to run */ 31 | struct Path { 32 | std::vector inputs; 33 | std::vector outputs; 34 | 35 | enum Mode { 36 | /** 37 | * Op Mode 38 | * - inputs means the source op, can NOT be empty. 39 | * - outputs means the sink op, can be empty. 40 | * The path will start from source op, then flow when encounter the sink op. 41 | * The sink op will not be compute in this path. 42 | */ 43 | Op = 0, 44 | 45 | /** 46 | * Tensor Mode (NOT supported yet) 47 | * - inputs means the inputs tensors, can NOT be empty. 48 | * - outputs means the outputs tensors, can NOT be empty. 49 | * It will find the pipeline that compute outputs from inputs. 
50 | */ 51 | Tensor = 1 52 | }; 53 | 54 | /** running mode */ 55 | Mode mode = Op; 56 | }; 57 | Path path; 58 | 59 | /** backup backend used to create execution when desinated backend do NOT support any op */ 60 | MNNForwardType backupType = MNN_FORWARD_CPU; 61 | 62 | /** extra backend config */ 63 | BackendConfig* backendConfig = nullptr; 64 | }; 65 | 66 | class Session; 67 | struct Content; 68 | class Tensor; 69 | class Backend; 70 | 71 | class MNN_PUBLIC OperatorInfo { 72 | struct Info; 73 | 74 | public: 75 | /** Operator's name*/ 76 | const std::string& name() const; 77 | 78 | /** Operator's type*/ 79 | const std::string& type() const; 80 | 81 | /** Operator's flops, in M*/ 82 | float flops() const; 83 | 84 | protected: 85 | OperatorInfo(); 86 | ~OperatorInfo(); 87 | Info* mContent; 88 | }; 89 | 90 | typedef std::function&, const std::string& /*opName*/)> TensorCallBack; 91 | typedef std::function&, const OperatorInfo*)> TensorCallBackWithInfo; 92 | 93 | /** net data holder. multiple sessions could share same net. */ 94 | class MNN_PUBLIC Interpreter { 95 | public: 96 | /** 97 | * @brief create net from file. 98 | * @param file given file. 99 | * @return created net if success, NULL otherwise. 100 | */ 101 | static Interpreter* createFromFile(const char* file); 102 | /** 103 | * @brief create net from buffer. 104 | * @param buffer given data buffer. 105 | * @param size size of data buffer. 106 | * @return created net if success, NULL otherwise. 107 | */ 108 | static Interpreter* createFromBuffer(const void* buffer, size_t size); 109 | ~Interpreter(); 110 | 111 | public: 112 | /** 113 | * @brief create session with schedule config. created session will be managed in net. 114 | * @param config session schedule config. 115 | * @return created session if success, NULL otherwise. 116 | */ 117 | Session* createSession(const ScheduleConfig& config); 118 | 119 | /** 120 | * @brief create multi-path session with schedule configs. created session will be managed in net. 
121 | * @param configs session schedule configs. 122 | * @return created session if success, NULL otherwise. 123 | */ 124 | Session* createMultiPathSession(const std::vector& configs); 125 | 126 | /** 127 | * @brief release session. 128 | * @param session given session. 129 | * @return true if given session is held by net and is freed. 130 | */ 131 | bool releaseSession(Session* session); 132 | 133 | /** 134 | * @brief call this function to get tensors ready. output tensor buffer (host or deviceId) should be retrieved 135 | * after resize of any input tensor. 136 | * @param session given session. 137 | */ 138 | void resizeSession(Session* session); 139 | 140 | /** 141 | * @brief call this function if don't need resize or create session any more, it will save a few memory that equal 142 | * to the size of model buffer 143 | */ 144 | void releaseModel(); 145 | 146 | /** 147 | * @brief Get the model buffer for user to save 148 | * @return std::make_pair(modleBuffer, modelSize). 149 | * @example: 150 | * std::ofstream output("trainResult.alinn") 151 | * auto buffer = net->getModelBuffer(); 152 | * output.write((const char*)buffer.first, buffer.second); 153 | */ 154 | std::pair getModelBuffer() const; 155 | 156 | /** 157 | * @brief update Session's Tensor to model's Const Op 158 | * @param session given session. 159 | * @return result of running. 160 | */ 161 | ErrorCode updateSessionToModel(Session* session); 162 | 163 | /** 164 | * @brief run session. 165 | * @param session given session. 166 | * @return result of running. 167 | */ 168 | ErrorCode runSession(Session* session) const; 169 | 170 | /* 171 | * @brief run session. 172 | * @param session given session. 173 | * @param before callback before each op. return true to run the op; return false to skip the op. 174 | * @param after callback after each op. return true to continue running; return false to interrupt the session. 175 | * @param sync synchronously wait for finish of execution or not. 
176 | * @return result of running. 177 | */ 178 | ErrorCode runSessionWithCallBack(const Session* session, const TensorCallBack& before, const TensorCallBack& end, 179 | bool sync = false) const; 180 | 181 | /* 182 | * @brief run session. 183 | * @param session given session. 184 | * @param before callback before each op. return true to run the op; return false to skip the op. 185 | * @param after callback after each op. return true to continue running; return false to interrupt the session. 186 | * @param sync synchronously wait for finish of execution or not. 187 | * @return result of running. 188 | */ 189 | ErrorCode runSessionWithCallBackInfo(const Session* session, const TensorCallBackWithInfo& before, 190 | const TensorCallBackWithInfo& end, bool sync = false) const; 191 | 192 | /** 193 | * @brief get input tensor for given name. 194 | * @param session given session. 195 | * @param name given name. if NULL, return first input. 196 | * @return tensor if found, NULL otherwise. 197 | */ 198 | Tensor* getSessionInput(const Session* session, const char* name); 199 | /** 200 | * @brief get output tensor for given name. 201 | * @param session given session. 202 | * @param name given name. if NULL, return first output. 203 | * @return tensor if found, NULL otherwise. 204 | */ 205 | Tensor* getSessionOutput(const Session* session, const char* name); 206 | 207 | /** 208 | * @brief get all input tensors. 209 | * @param session given session. 210 | * @return all input tensors mapped with name. 211 | */ 212 | const std::map& getSessionOutputAll(const Session* session) const; 213 | /** 214 | * @brief get all output tensors. 215 | * @param session given session. 216 | * @return all output tensors mapped with name. 217 | */ 218 | const std::map& getSessionInputAll(const Session* session) const; 219 | 220 | public: 221 | /** 222 | * @brief resize given tensor. 223 | * @param tensor given tensor. 224 | * @param dims new dims. at most 6 dims. 
225 | */ 226 | void resizeTensor(Tensor* tensor, const std::vector& dims); 227 | 228 | /** 229 | * @brief resize given tensor by nchw. 230 | * @param batch / N. 231 | * @param channel / C. 232 | * @param height / H. 233 | * @param width / W 234 | */ 235 | void resizeTensor(Tensor* tensor, int batch, int channel, int height, int width); 236 | 237 | /** 238 | * @brief get backend used to create given tensor. 239 | * @param session given session. 240 | * @param tensor given tensor. 241 | * @return backend used to create given tensor, may be NULL. 242 | */ 243 | const Backend* getBackend(const Session* session, const Tensor* tensor) const; 244 | 245 | /** 246 | * @brief get business code (model identifier). 247 | * @return business code. 248 | */ 249 | const char* bizCode() const; 250 | 251 | private: 252 | static Interpreter* createFromBufferInternal(Content* net); 253 | 254 | Content* mNet = nullptr; 255 | Interpreter(Content* net); 256 | 257 | Interpreter(const Interpreter&) = delete; 258 | Interpreter(const Interpreter&&) = delete; 259 | Interpreter& operator=(const Interpreter&) = delete; 260 | Interpreter& operator=(const Interpreter&&) = delete; 261 | }; 262 | } // namespace MNN 263 | 264 | #endif /* Interpreter_hpp */ 265 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/MNNDefine.h: -------------------------------------------------------------------------------- 1 | // 2 | // MNNDefine.h 3 | // MNN 4 | // 5 | // Created by MNN on 2018/08/09. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNNDefine_h 10 | #define MNNDefine_h 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__APPLE__) 16 | #include "TargetConditionals.h" 17 | #if TARGET_OS_IPHONE 18 | #define MNN_BUILD_FOR_IOS 19 | #endif 20 | #endif 21 | 22 | #ifdef MNN_USE_LOGCAT 23 | #include 24 | #define MNN_ERROR(format, ...) 
__android_log_print(ANDROID_LOG_ERROR, "MNNJNI", format, ##__VA_ARGS__) 25 | #define MNN_PRINT(format, ...) __android_log_print(ANDROID_LOG_INFO, "MNNJNI", format, ##__VA_ARGS__) 26 | #else 27 | #define MNN_PRINT(format, ...) printf(format, ##__VA_ARGS__) 28 | #define MNN_ERROR(format, ...) printf(format, ##__VA_ARGS__) 29 | #endif 30 | 31 | #ifdef DEBUG 32 | #define MNN_ASSERT(x) \ 33 | { \ 34 | int res = (x); \ 35 | if (!res) { \ 36 | MNN_ERROR("Error for %s, %d\n", __FILE__, __LINE__); \ 37 | assert(res); \ 38 | } \ 39 | } 40 | #else 41 | #define MNN_ASSERT(x) \ 42 | { \ 43 | int res = (x); \ 44 | if (!res) { \ 45 | MNN_ERROR("Error for %s, %d\n", __FILE__, __LINE__); \ 46 | } \ 47 | } 48 | #endif 49 | 50 | #define FUNC_PRINT(x) MNN_PRINT(#x "=%d in %s, %d \n", x, __func__, __LINE__); 51 | #define FUNC_PRINT_ALL(x, type) MNN_PRINT(#x "=" #type " %" #type " in %s, %d \n", x, __func__, __LINE__); 52 | 53 | #if defined(_MSC_VER) 54 | #ifdef BUILDING_DLL 55 | #define MNN_PUBLIC __declspec(dllexport) 56 | #else 57 | #define MNN_PUBLIC __declspec(dllimport) 58 | #endif 59 | #else 60 | #define MNN_PUBLIC __attribute__((visibility("default"))) 61 | #endif 62 | 63 | #endif /* MNNDefine_h */ 64 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/MNNForwardType.h: -------------------------------------------------------------------------------- 1 | // 2 | // MNNForwardType.h 3 | // MNN 4 | // 5 | // Created by MNN on 2019/01/19. 
// Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef MNNForwardType_h
#define MNNForwardType_h

/**
 * Backend (forward) types a session can be scheduled on.
 * Every enumerator except MNN_FORWARD_ALL has an explicit value; note that the
 * declaration order does not follow numeric order (AUTO is 4, NN is 5).
 */
typedef enum {
    MNN_FORWARD_CPU = 0,

    /*
     First look for an available backend other than CPU;
     if no other backend is available, use CPU.
     */
    MNN_FORWARD_AUTO = 4,

    /* Hand-written Metal */
    MNN_FORWARD_METAL = 1,

    /* Use iOS's MPS instead of hand-written Metal. Not supported yet. */
    MNN_FORWARD_MPS = 2,

    /* Android / common device GPU APIs */
    MNN_FORWARD_OPENCL = 3,
    MNN_FORWARD_OPENGL = 6,
    MNN_FORWARD_VULKAN = 7,

    /* Android 8.1's NNAPI. Not supported yet. */
    MNN_FORWARD_NN = 5,

    /* Users can use the API from Backend.hpp to add or search for a backend */
    MNN_FORWARD_USER_0 = 8,
    MNN_FORWARD_USER_1 = 9,
    MNN_FORWARD_USER_2 = 10,
    MNN_FORWARD_USER_3 = 11,

    /* Implicitly MNN_FORWARD_USER_3 + 1, i.e. 12. NOTE(review): presumably a
       count/sentinel rather than a selectable backend -- confirm before using it. */
    MNN_FORWARD_ALL
} MNNForwardType;

#ifdef __cplusplus
namespace MNN {
/**
 * Backend configuration: memory, power and precision modes (each defaulting to
 * its *_Normal value) plus a user-defined context pointer (defaults to nullptr).
 */
struct BackendConfig {
    enum MemoryMode {
        Memory_Normal = 0,
        Memory_High,
        Memory_Low
    };

    MemoryMode memory = Memory_Normal;

    enum PowerMode {
        Power_Normal = 0,
        Power_High,
        Power_Low
    };

    PowerMode power = Power_Normal;

    enum PrecisionMode {
        Precision_Normal = 0,
        Precision_High,
        Precision_Low
    };

    PrecisionMode precision = Precision_Normal;

    /** user defined context */
    void* sharedContext = nullptr;
};
}  // namespace MNN
#endif
#endif /* MNNForwardType_h */

// -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/Tensor.hpp: --------------------------------------------------------------------------------
//
// Tensor.hpp
// MNN
//
// Created by MNN on 2018/08/14.
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef Tensor_hpp 10 | #define Tensor_hpp 11 | 12 | #include 13 | #include "HalideRuntime.h" 14 | #include 15 | 16 | namespace MNN { 17 | 18 | /** 19 | * data container. 20 | * data for host tensor is saved in `host` field. its memory is allocated malloc directly. 21 | * data for device tensor is saved in `deviceId` field. its memory is allocated by session's backend. 22 | * usually, device tensors are created by engine (like net, session). 23 | * meanwhile, host tensors could be created by engine or user. 24 | */ 25 | class MNN_PUBLIC Tensor { 26 | public: 27 | struct InsideDescribe; 28 | 29 | /** dimension type used to create tensor */ 30 | enum DimensionType { 31 | /** for tensorflow net type. uses NHWC as data format. */ 32 | TENSORFLOW, 33 | /** for caffe net type. uses NCHW as data format. */ 34 | CAFFE, 35 | /** for caffe net type. uses NC4HW4 as data format. */ 36 | CAFFE_C4 37 | }; 38 | 39 | /** handle type */ 40 | enum HandleDataType { 41 | /** default handle type */ 42 | HANDLE_NONE = 0, 43 | /** string handle type */ 44 | HANDLE_STRING = 1 45 | }; 46 | 47 | /** dimension reorder flag */ 48 | enum DataReorderType { 49 | /** default reorder type, do not reorder */ 50 | NO_REORDER = 0, 51 | /** reorder dimension 4 by 4. usually used with NC4HW4 or NHWC4 while data type is float. */ 52 | REORDER_4 = 1, 53 | /** reorder dimension 8 by 8. usually used with NC4HW4 or NHWC4 while data type is uint8 or int8. */ 54 | REORDER_8 55 | }; 56 | 57 | public: 58 | /** 59 | * @brief create a tensor with dimension size and type without acquire memory for data. 60 | * @param dimSize dimension size. 61 | * @param type dimension type. 62 | */ 63 | Tensor(int dimSize = 4, DimensionType type = CAFFE); 64 | 65 | /** 66 | * @brief create a tensor with same shape as given tensor. 67 | * @param tensor shape provider. 68 | * @param type dimension type. 69 | * @param allocMemory acquire memory for data or not. 
70 | * @warning tensor data won't be copied. 71 | */ 72 | Tensor(const Tensor* tensor, DimensionType type = CAFFE, bool allocMemory = true); 73 | 74 | /** deinitializer */ 75 | ~Tensor(); 76 | 77 | private: 78 | // remove all assignment operator 79 | Tensor(const Tensor& tensor) = delete; 80 | Tensor(const Tensor&& tensor) = delete; 81 | Tensor& operator=(const Tensor&) = delete; 82 | Tensor& operator=(const Tensor&&) = delete; 83 | 84 | public: 85 | /** 86 | * @brief create tensor with shape, data type and dimension type. 87 | * @param shape tensor shape. 88 | * @param type data type. 89 | * @param dimType dimension type. 90 | * @return created tensor. 91 | * @warning memory for data won't be acquired. call backend's onAcquireBuffer to get memory ready. 92 | */ 93 | static Tensor* createDevice(const std::vector& shape, halide_type_t type, DimensionType dimType = TENSORFLOW); 94 | 95 | /** 96 | * @brief create tensor with shape and dimension type. data type is represented by `T`. 97 | * @param shape tensor shape. 98 | * @param dimType dimension type. 99 | * @return created tensor. 100 | * @warning memory for data won't be acquired. call backend's onAcquireBuffer to get memory ready. 101 | */ 102 | template 103 | static Tensor* createDevice(const std::vector& shape, DimensionType dimType = TENSORFLOW) { 104 | return createDevice(shape, halide_type_of(), dimType); 105 | } 106 | 107 | /** 108 | * @brief create tensor with shape, data type, data and dimension type. 109 | * @param shape tensor shape. 110 | * @param type data type. 111 | * @param data data to save. 112 | * @param dimType dimension type. 113 | * @return created tensor. 114 | */ 115 | static Tensor* create(const std::vector& shape, halide_type_t type, void* data = NULL, 116 | DimensionType dimType = TENSORFLOW); 117 | 118 | /** 119 | * @brief create tensor with shape, data and dimension type. data type is represented by `T`. 120 | * @param shape tensor shape. 121 | * @param data data to save. 
122 | * @param dimType dimension type. 123 | * @return created tensor. 124 | */ 125 | template 126 | static Tensor* create(const std::vector& shape, void* data = NULL, DimensionType dimType = TENSORFLOW) { 127 | return create(shape, halide_type_of(), data, dimType); 128 | } 129 | 130 | public: 131 | /** 132 | * @brief for DEVICE tensor, copy data from given host tensor. 133 | * @param hostTensor host tensor, the data provider. 134 | * @return true for DEVICE tensor, and false for HOST tensor. 135 | */ 136 | bool copyFromHostTensor(const Tensor* hostTensor); 137 | 138 | /** 139 | * @brief for DEVICE tensor, copy data to given host tensor. 140 | * @param hostTensor host tensor, the data consumer. 141 | * @return true for DEVICE tensor, and false for HOST tensor. 142 | */ 143 | bool copyToHostTensor(Tensor* hostTensor) const; 144 | 145 | /** 146 | * @brief create HOST tensor from DEVICE tensor, with or without data copying. 147 | * @param deviceTensor given device tensor. 148 | * @param copyData copy data or not. 149 | * @return created host tensor. 150 | */ 151 | static Tensor* createHostTensorFromDevice(const Tensor* deviceTensor, bool copyData = true); 152 | 153 | public: 154 | const halide_buffer_t& buffer() const { 155 | return mBuffer; 156 | } 157 | halide_buffer_t& buffer() { 158 | return mBuffer; 159 | } 160 | 161 | /** 162 | * @brief get dimension type. 163 | * @return dimension type. 164 | */ 165 | DimensionType getDimensionType() const; 166 | 167 | /** 168 | * @brief handle data type. used when data type code is halide_type_handle. 169 | * @return handle data type. 170 | */ 171 | HandleDataType getHandleDataType() const; 172 | 173 | /** 174 | * @brief set data type. 175 | * @param type data type defined in 'Type_generated.h'. 176 | */ 177 | void setType(int type); 178 | 179 | /** 180 | * @brief get data type. 181 | * @return data type. 
182 | */ 183 | inline halide_type_t getType() const { 184 | return mBuffer.type; 185 | } 186 | 187 | /** 188 | * @brief visit host memory, data type is represented by `T`. 189 | * @return data point in `T` type. 190 | */ 191 | template 192 | T* host() const { 193 | return (T*)mBuffer.host; 194 | } 195 | 196 | /** 197 | * @brief visit device memory. 198 | * @return device data ID. what the ID means varies between backends. 199 | */ 200 | uint64_t deviceId() const { 201 | return mBuffer.device; 202 | } 203 | 204 | public: 205 | int dimensions() const { 206 | return mBuffer.dimensions; 207 | } 208 | 209 | /** 210 | * @brief get all dimensions' extent. 211 | * @return dimensions' extent. 212 | */ 213 | std::vector shape() const; 214 | 215 | /** 216 | * @brief calculate number of bytes needed to store data taking reordering flag into account. 217 | * @return bytes needed to store data 218 | */ 219 | int size() const; 220 | 221 | /** 222 | * @brief calculate number of elements needed to store data taking reordering flag into account. 
223 | * @return elements needed to store data 224 | */ 225 | inline int elementSize() const { 226 | return size() / mBuffer.type.bytes(); 227 | } 228 | 229 | public: 230 | inline int width() const { 231 | if (getDimensionType() == TENSORFLOW) { 232 | return mBuffer.dim[2].extent; 233 | } 234 | 235 | return mBuffer.dim[3].extent; 236 | } 237 | inline int height() const { 238 | if (getDimensionType() == TENSORFLOW) { 239 | return mBuffer.dim[1].extent; 240 | } 241 | return mBuffer.dim[2].extent; 242 | } 243 | inline int channel() const { 244 | if (getDimensionType() == TENSORFLOW) { 245 | return mBuffer.dim[3].extent; 246 | } 247 | return mBuffer.dim[1].extent; 248 | } 249 | inline int batch() const { 250 | return mBuffer.dim[0].extent; 251 | } 252 | 253 | // visit dimension's extent & stride 254 | inline int stride(int index) const { 255 | return mBuffer.dim[index].stride; 256 | } 257 | inline int length(int index) const { 258 | return mBuffer.dim[index].extent; 259 | } 260 | inline void setStride(int index, int stride) { 261 | mBuffer.dim[index].stride = stride; 262 | } 263 | inline void setLength(int index, int length) { 264 | mBuffer.dim[index].extent = length; 265 | } 266 | 267 | public: 268 | /** 269 | * @brief print tensor data. for DEBUG use only. 270 | */ 271 | void print() const; 272 | 273 | private: 274 | halide_buffer_t mBuffer; 275 | struct InsideDescribe* mDescribe; 276 | 277 | private: 278 | friend class TensorUtils; 279 | }; 280 | } // namespace MNN 281 | 282 | #endif /* Tensor_hpp */ 283 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/expr/Executor.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Executor.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/07/25. 
// Copyright © 2018, Alibaba Group Holding Limited
//
// NOTE(review): in this listing the targets of the `#include` directives and most
// template argument lists (`std::shared_ptr`, `std::vector`, `std::set`, `std::map`,
// `std::pair`, `std::function`, `template`, `halide_type_of()`) appear to have been
// stripped. Restore them from the upstream MNN headers before compiling.
#ifndef Executor_hpp
#define Executor_hpp
#include
#include
#include
#include
#include
#include
#include
namespace MNN {
class Backend;
class Execution;
namespace Express {
/**
 * Expression executor. Declares cache construction and execution (makeCache, runCache),
 * global backend configuration, garbage collection (gc) and profiling hooks.
 * Holds a primary and a backup backend (mBackend, mBackupBackend) and a mutex (mMutex).
 */
class MNN_PUBLIC Executor {
public:
    /** Abstract compiled unit; subclasses implement compute() and resize(). */
    class ComputeCache {
    public:
        void setShapeDirty(int offset, Variable::Info* info);
        void setContentDirty();
        void setContentReady();
        void syncInput(int offset, const Variable::Info* info);
        void syncOutput(int offset, Variable::Info* info);

        struct TensorContent {
            std::shared_ptr tensor;
            int refCount = 0;
            void reset();
            bool aliveOutside = false;
        };
        struct Unit;
        virtual ~ ComputeCache() {}
        ComputeCache() {}
        virtual ErrorCode compute() = 0;
        virtual ErrorCode resize() = 0;
    protected:
        // Get the tensor at `index`, taking the needed backend into account.
        // If the tensor doesn't belong to that backend, it is allocated with the needed backend and returned.
        virtual Tensor* getTensor(int index, bool host) = 0;
        void _setShapeDirty();
        friend class Executor;
        // Both flags start dirty, so a fresh cache is treated as needing shape and content work.
        bool mContentDirty = true;
        bool mShapeDirty = true;
    };
    struct Requirement {
        std::vector contentNeedContent;
        std::vector shapeNeedContent;
        std::vector supportError;
    };
    ~Executor();
    Requirement getRequirement(Expr* expr) const;
    ErrorCode computeInfo(Expr* expr);
    void makeCache(const std::vector& expr, bool forceCPU = false);
    ErrorCode runCache(std::shared_ptr cache);
    void setGlobalExecutorConfig(MNNForwardType type, const BackendConfig& config, int numberThread);
    enum GCFlag {
        FULL,
        PART
    };
    void gc(GCFlag flag = FULL);
    // NOTE(review): presumably returns a process-wide shared executor -- confirm in Executor.cpp.
    static std::shared_ptr getGlobalExecutor();
    void resetProfile();
    void dumpProfile();
    void addOpCostTime(int op, float costTime);
    class Profiler;
private:
    void _createSingle(EXPRP expr);
    void _create(const std::vector& outputs, std::set>&& inputCaches, std::vector&& tensors, bool forceCPU);

    void _addToCache(const std::vector>& caches);
    void _resetCache();
    void _visit(EXPRP expr, std::set>& inputCaches, std::vector& tensors);

    // Constructor is private; instances are obtained through the static accessor above.
    Executor(std::shared_ptr backend);
    std::shared_ptr mBackend;
    std::shared_ptr mBackupBackend;
    std::mutex mMutex;
    std::vector> mStack;
    std::vector mStackInputs;
    std::vector mStackOutputs;
    std::shared_ptr mProfiler;
};
} // namespace Express
} // namespace MNN
#endif

// -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/expr/Expr.hpp: --------------------------------------------------------------------------------
//
// Expr.hpp
// MNN
//
// Created by MNN on 2019/06/10.
// Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef Expr_hpp
#define Expr_hpp

#include
#include
#include
#include
#include
#include
#include
#include

namespace MNN {
struct OpT;
struct Op;
struct NetT;
namespace Express {
class Variable;
class Expr;
class Executor;
typedef std::shared_ptr EXPRP;
typedef std::weak_ptr WeakEXPRP;
typedef std::vector INTS;
enum Dimensionformat { NHWC, NC4HW4, NCHW };
/**
 * Handle to a Variable, stored in the shared-pointer member mContent.
 * Copies share the same Variable; constructing from a raw Variable* hands it to mContent (reset).
 * Comparison operators compare the held pointers, not the variables' values.
 */
class MNN_PUBLIC VARP {
public:
    VARP() {
        // Do nothing
    }
    VARP(std::shared_ptr c) {
        mContent = std::move(c);
    }
    VARP(Variable* c) {
        mContent.reset(c);
    }
    Variable* get() const {
        return mContent.get();
    }
    ~ VARP() {
        // Do nothing
    }
    VARP(const VARP& var) {
        mContent = var.mContent;
    }
    VARP(VARP&& var) {
        mContent = std::move(var.mContent);
    }
    VARP operator+(VARP var) const;
    VARP operator-(VARP var) const;
    VARP operator*(VARP var) const;
    VARP operator/(VARP var) const;
    VARP mean(INTS dims) const;
    VARP sum(INTS dims) const;

    bool operator==(const VARP& var) const {
        return var.mContent == mContent;
    }
    bool operator<(const VARP& var) const {
        return mContent < var.mContent;
    }
    bool operator<=(const VARP& var) const {
        return mContent <= var.mContent;
    }
    VARP& operator=(const VARP& var) {
        mContent = var.mContent;
        return *this;
    }
    VARP& operator=(Variable* var) {
        mContent.reset(var);
        return *this;
    }
    Variable* operator->() const {
        return mContent.get();
    }
    enum InputType {
        INPUT = 0,
        CONSTANT = 1,
        TRAINABLE = 2,
    };
    bool fix(InputType type) const;
private:
    std::shared_ptr mContent;
};
inline bool operator==(Variable* src, VARP dst) {
    return src == dst.get();
}
inline bool operator!=(Variable* src, VARP dst) {
    return src != dst.get();
}
// inline bool operator<(VARP src, VARP dst) {
//     return src.get() < dst.get();
// }
typedef std::vector VARPS;

/** A value in the expression graph: refers to output index mFromIndex of expression mFrom. */
class MNN_PUBLIC Variable {
public:
    struct Info {
        Dimensionformat order = NHWC;
        INTS dim;
        halide_type_t type;
        int size;
        void* ptr = nullptr;
        void syncSize();
    };
    const std::string& name() const;
    void setName(const std::string& name);
    std::pair expr() const {
        return std::make_pair(mFrom, mFromIndex);
    }
    // If computing the info fails, returns nullptr
    const Info* getInfo();
    bool resize(INTS dims);
    template
    const T* readMap() {
        return (const T*)readInternal();
    }

    template
    T* writeMap() {
        return (T*)writeInternal();
    }

    // Deprecated
    void unMap();

    bool input(VARP src);
    static void replace(VARP dst, VARP src);

    static VARP create(EXPRP expr, int index = 0);

    static std::vector load(const char* fileName);
    static std::map loadMap(const char* fileName);
    static std::vector load(const uint8_t* buffer, size_t length);
    static std::map loadMap(const uint8_t* buffer, size_t length);
    static std::pair, std::map> getInputAndOutput(const std::map& allVariable);
    static std::vector mapToSequence(const std::map& source);
    static std::vector getExecuteOrder(const std::vector& output);
    static void save(const std::vector& vars, const char* fileName);
    static void save(const std::vector& vars, NetT* dest);

    // Pack several Variables so they are computed in one pipeline
    static void prepareCompute(const std::vector& vars, bool forceCPU = false);

    size_t linkNumber() const;
    const std::vector& toExprs() const;
    void setExpr(EXPRP expr, int index) {
        mFrom = expr;
        mFromIndex = index;
    }
private:
    Variable(EXPRP expr, int index) {
        mFrom = expr;
        mFromIndex = index;
    }

    void* readInternal(bool forShape = false);
    void* writeInternal(bool inform=true);
    void informDirty();

    friend class Expr;
    EXPRP mFrom;
    int mFromIndex;
};

/** A node in the expression graph: an Op (mOp) with its inputs (mInputs) and named outputs (mOutputNames). */
class MNN_PUBLIC Expr {
public:
    struct Inside;
    static EXPRP create(Variable::Info&& info);
    static EXPRP create(const OpT* op, std::vector inputs, int outputSize = 1);
    static EXPRP create(std::pair, int> extra, std::vector&& inputs, int outputSize = 1);
    static EXPRP create(std::unique_ptr&& op, std::vector inputs, int outputSize = 1) {
        return create(op.get(), inputs, outputSize);
    }
    void setName(const std::string& name);

    const Op* get() const {
        return mOp;
    }
    const std::vector& inputs() const {
        return mInputs;
    }
    // Number of outputs, taken from the size of mOutputNames (narrowed to int).
    int outputSize() const {
        return mOutputNames.size();
    }
    static void replace(EXPRP oldExpr, EXPRP newExpr);
    bool requireInfo();
    void visitOutputs(const std::function& visit);
    static void visit(EXPRP expr, const std::function& before, const std::function& after);

    const std::vector& outputs() const {
        return mTo;
    }
    ~Expr();

    bool visited() const {
        return mVisited;
    }
    void setVisited(bool visited) {
        mVisited = visited;
    }
    const std::string& name() const {
        return mName;
    }
    // No bounds check: `index` must be below outputSize().
    const std::string& outputName(int index) {
        return mOutputNames[index];
    }

    VARP::InputType inputType() const {return mType;}
    Variable::Info* outputInfo(int index) const;
    std::pair, int> extra() const {
        return std::make_pair(mExtraBuffer, mOpBufferSize);
    }
    bool setInfoDirty();
    std::shared_ptr inside() const {
        return mInside;
    }
    bool valid() const {
        return mValid;
    }

    void setEntry(const std::vector& entries) {
        mEntries = entries;
    }

    const std::vector& getEntry() const {
        return mEntries;
    }

private:
    static void _addLinkForInputs(EXPRP expr);

    Expr(int outputSize);

    friend class Variable;
    friend class VARP;
    VARP::InputType mType;
    const Op* mOp;
    std::vector mInputs;
    std::vector mOutputNames;

    bool mValid = true;
    std::shared_ptr mExtraBuffer;
    int mOpBufferSize = 0;
    std::string mName;
    std::shared_ptr mInside = nullptr;
    bool mVisited = false;
    std::vector mTo;

    // Only the enter input has entries, and it helps to get info for enter
    // input expression.
    std::vector mEntries;
};
} // namespace Express
} // namespace MNN

#endif /* Expr_hpp */

// -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/expr/ExprCreator.hpp: --------------------------------------------------------------------------------
//
// ExprCreator.hpp
// MNN
//
// Created by MNN on 2019/06/27.
// Copyright © 2018, Alibaba Group Holding Limited
//
// Umbrella header: contains only three #include directives (their targets are
// missing from this listing -- NOTE(review): restore from upstream MNN).

#ifndef ExprCreator_hpp
#define ExprCreator_hpp

#include
#include
#include

#endif

// -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/expr/MathOp.hpp: --------------------------------------------------------------------------------
//
// MathOp.hpp
// MNN
//
// Created by MNN on 2019/06/27.
// Copyright © 2018, Alibaba Group Holding Limited
//
// Declarations of the math operators of the Express API. Every function takes and
// returns VARP handles; no definitions live in this header except the _Cast template.

#ifndef MathOp_HPP
#define MathOp_HPP

namespace MNN {
namespace Express {
// Binary ops
MNN_PUBLIC VARP _Add(VARP x, VARP y);
MNN_PUBLIC VARP _Subtract(VARP x, VARP y);
MNN_PUBLIC VARP _Multiply(VARP x, VARP y);
MNN_PUBLIC VARP _Divide(VARP x, VARP y);
MNN_PUBLIC VARP _Pow(VARP x, VARP y);
MNN_PUBLIC VARP _Minimum(VARP x, VARP y);
MNN_PUBLIC VARP _Maximum(VARP x, VARP y);
MNN_PUBLIC VARP _BiasAdd(VARP value, VARP bias);
MNN_PUBLIC VARP _Greater(VARP x, VARP y);
MNN_PUBLIC VARP _GreaterEqual(VARP x, VARP y);
MNN_PUBLIC VARP _Less(VARP x, VARP y);
MNN_PUBLIC VARP _FloorDiv(VARP x, VARP y);
MNN_PUBLIC VARP _SquaredDifference(VARP x, VARP y);
MNN_PUBLIC VARP _Equal(VARP x, VARP y);
MNN_PUBLIC VARP _LessEqual(VARP x, VARP y);
MNN_PUBLIC VARP _FloorMod(VARP x, VARP y);
MNN_PUBLIC VARP _Atan2(VARP x, VARP y);
MNN_PUBLIC VARP _LogicalOr(VARP x, VARP y);
MNN_PUBLIC VARP _NotEqual(VARP x, VARP y);

// Unary ops
MNN_PUBLIC VARP _Sign(VARP a);
MNN_PUBLIC VARP _Abs(VARP x);
MNN_PUBLIC VARP _Negative(VARP x);
MNN_PUBLIC VARP _Floor(VARP x);
MNN_PUBLIC VARP _Round(VARP x);
MNN_PUBLIC VARP _Ceil(VARP x);
MNN_PUBLIC VARP _Square(VARP x);
MNN_PUBLIC VARP _Sqrt(VARP x);
MNN_PUBLIC VARP _Rsqrt(VARP x);
MNN_PUBLIC VARP _Exp(VARP x);
MNN_PUBLIC VARP _Log(VARP x);
MNN_PUBLIC VARP _Sin(VARP x);
MNN_PUBLIC VARP _Sinh(VARP x);
MNN_PUBLIC VARP _Cos(VARP x);
MNN_PUBLIC VARP _Cosh(VARP x);
MNN_PUBLIC VARP _Tan(VARP x);
MNN_PUBLIC VARP _Asin(VARP x);
MNN_PUBLIC VARP _Asinh(VARP x);
MNN_PUBLIC VARP _Acos(VARP x);
MNN_PUBLIC VARP _Acosh(VARP x);
MNN_PUBLIC VARP _Atan(VARP x);
MNN_PUBLIC VARP _Atanh(VARP x);
MNN_PUBLIC VARP _Reciprocal(VARP x);
MNN_PUBLIC VARP _Log1p(VARP x);
// Single-input ops that are not part of the UnaryOp family
MNN_PUBLIC VARP _Tanh(VARP x);
MNN_PUBLIC VARP _Sigmoid(VARP x);
MNN_PUBLIC VARP _Erf(VARP x);
MNN_PUBLIC VARP _Erfc(VARP x);
MNN_PUBLIC VARP _Erfinv(VARP x);
MNN_PUBLIC VARP _Expm1(VARP x);


// Reduce ops (axis given as a fixed list; keepDims defaults to false)
MNN_PUBLIC VARP _ReduceSum(VARP input_variable, INTS axis = {}, bool keepDims = false);
MNN_PUBLIC VARP _ReduceMean(VARP input_variable, INTS axis = {}, bool keepDims = false);
MNN_PUBLIC VARP _ReduceMax(VARP input_variable, INTS axis = {}, bool keepDims = false);
MNN_PUBLIC VARP _ReduceMin(VARP input_variable, INTS axis = {}, bool keepDims = false);
MNN_PUBLIC VARP _ReduceProd(VARP input_variable, INTS axis = {}, bool keepDims = false);
MNN_PUBLIC VARP _ReduceAny(VARP input_variable, INTS axis = {}, bool keepDims = false);
MNN_PUBLIC VARP _ReduceAll(VARP input_variable, INTS axis = {}, bool keepDims = false);

// Reduce ops whose axis is itself a VARP
MNN_PUBLIC VARP _ReduceSumMutable(VARP input_variable, VARP axis, bool keepDims = false);
MNN_PUBLIC VARP _ReduceMeanMutable(VARP input_variable, VARP axis, bool keepDims = false);
MNN_PUBLIC VARP _ReduceMaxMutable(VARP input_variable, VARP axis, bool keepDims = false);
MNN_PUBLIC VARP _ReduceMinMutable(VARP input_variable, VARP axis, bool keepDims = false);
MNN_PUBLIC VARP _ReduceProdMutable(VARP input_variable, VARP axis, bool keepDims = false);
MNN_PUBLIC VARP _ReduceAnyMutable(VARP input_variable, VARP axis, bool keepDims = false);
MNN_PUBLIC VARP _ReduceAllMutable(VARP input_variable, VARP axis, bool keepDims = false);

// Eltwise ops
MNN_PUBLIC VARP _Prod(VARP a, VARP b, std::vector coeff);
MNN_PUBLIC VARP _Sum(VARP a, VARP b, std::vector coeff);
MNN_PUBLIC VARP _Max(VARP a, VARP b, std::vector coeff);
MNN_PUBLIC VARP _Sub(VARP a, VARP b, std::vector coeff);
MNN_PUBLIC VARP _EltwiseProdInt8(VARP x, VARP y,
                    std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale,
                    std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale,
                    std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale);
MNN_PUBLIC VARP _EltwiseSumInt8(VARP x, VARP y,
                    std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale,
                    std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale,
                    std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale);
MNN_PUBLIC VARP _EltwiseSubInt8(VARP x, VARP y,
                    std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale,
                    std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale,
                    std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale);
MNN_PUBLIC VARP _EltwiseMaxInt8(VARP x, VARP y,
                    std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale,
                    std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale,
                    std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale);


// Other ops
// Typed convenience overload: forwards to _Cast(VARP, halide_type_t) using the type of T.
template
VARP _Cast(VARP x) {
    return _Cast(x, halide_type_of());
}
MNN_PUBLIC VARP _Cast(VARP x, halide_type_t dtype);
MNN_PUBLIC VARP _MatMul(VARP a, VARP b, bool tranposeA = false, bool tranposeB = false);
MNN_PUBLIC VARP _Normalize(VARP x, int32_t acrossSpatial, int32_t channelShared, float eps, std::vector scale);
MNN_PUBLIC VARP _ArgMax(VARP input, int axis = 0);
MNN_PUBLIC VARP _ArgMin(VARP input, int axis = 0);
MNN_PUBLIC VARP _BatchMatMul(VARP x, VARP y, bool adj_x = false, bool adj_y = false);
MNN_PUBLIC VARP _UnravelIndex(VARP indices, VARP dims);
MNN_PUBLIC VARP _ScatterNd(VARP indices, VARP updates, VARP shape);
MNN_PUBLIC VARP _OneHot(VARP indices, VARP depth, VARP onValue, VARP offValue, int axis = -1);
MNN_PUBLIC VARP _BroadcastTo(VARP a, VARP shape);
MNN_PUBLIC VARP _LinSpace(VARP start, VARP stop, VARP num);
// NOTE(review): the trailing ';' after these namespace braces is a redundant empty declaration.
}; // namespace Express
}; // namespace MNN

#endif /* MathOp_HPP */

// -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/expr/NeuralNetWorkOp.hpp: --------------------------------------------------------------------------------
//
// NeuralNetWorkOp.hpp
// MNN
//
// Created by MNN on 2019/06/27.
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef NeuralNetWorkOp_HPP 10 | #define NeuralNetWorkOp_HPP 11 | 12 | namespace MNN { 13 | namespace Express { 14 | enum PaddingMode {CAFFE, VALID, SAME}; 15 | enum PoolingMode {MAXPOOL, AVEPOOL}; 16 | enum PadValueMode {CONSTANT, REFLECT, SYMMETRIC}; 17 | MNN_PUBLIC VARP _Input(INTS shape = {}, Dimensionformat data_format = NC4HW4, halide_type_t dtype = halide_type_of()) ; 18 | MNN_PUBLIC VARP _Clone(VARP source, bool deepCopy = false); 19 | 20 | MNN_PUBLIC VARP _Scalar(const void* ptr, halide_type_t type); 21 | 22 | template 23 | VARP _Scalar(T value) { 24 | return _Scalar(&value, halide_type_of()); 25 | } 26 | 27 | 28 | MNN_PUBLIC VARP _Const(float value, INTS shape = {}, Dimensionformat format = NHWC); 29 | MNN_PUBLIC VARP _Const(const void* ptr, INTS shape = {}, Dimensionformat format = NHWC, 30 | halide_type_t type = halide_type_of()); 31 | MNN_PUBLIC VARP _TrainableParam(float value, INTS dims, Dimensionformat format); 32 | MNN_PUBLIC VARP _TrainableParam(const void* ptr, INTS dims, Dimensionformat format, 33 | halide_type_t type = halide_type_of()); 34 | MNN_PUBLIC VARP _Conv(VARP weight, VARP bias, VARP x, PaddingMode pad = VALID, INTS stride = {1, 1}, 35 | INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}); 36 | 37 | MNN_PUBLIC VARP _Conv(float weight, float bias, VARP x, INTS channel, INTS kernelSize, PaddingMode pad = VALID, 38 | INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1); 39 | MNN_PUBLIC VARP _Conv(std::vector&& weight, std::vector&& bias, VARP x, INTS channel, INTS kernelSize, 40 | PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}, bool relu = false, bool relu6 = false); 41 | MNN_PUBLIC VARP _Conv(std::vector&& weight, std::vector&& bias, VARP x, INTS channel, INTS kernelSize, 42 | PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}, bool relu = false, bool 
relu6 = false); 43 | MNN_PUBLIC VARP _Deconv(VARP weight, VARP bias, VARP x, PaddingMode pad = VALID, INTS stride = {1, 1}, 44 | INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}); 45 | MNN_PUBLIC VARP _MaxPool(VARP x, INTS kernel, INTS stride = {1, 1}, PaddingMode pad = VALID, INTS pads= {0, 0}); 46 | MNN_PUBLIC VARP _AvePool(VARP x, INTS kernel, INTS stride = {1, 1}, PaddingMode pad = VALID, INTS pads= {0, 0}); 47 | MNN_PUBLIC VARP _Reshape(VARP x, INTS shape, Dimensionformat original_format = NHWC); 48 | MNN_PUBLIC VARP _Reshape(VARP x, VARP shape); 49 | MNN_PUBLIC VARP _Scale(VARP x, int channels, std::vector&& scales, std::vector&& bias); 50 | 51 | MNN_PUBLIC VARP _Relu(VARP x, float slope = 0.0f); 52 | MNN_PUBLIC VARP _Relu6(VARP x); 53 | MNN_PUBLIC VARP _PRelu(VARP x, std::vector &&slopes); 54 | MNN_PUBLIC VARP _Softmax(VARP logits, int axis = -1); 55 | MNN_PUBLIC VARP _Softplus(VARP features); 56 | MNN_PUBLIC VARP _Softsign(VARP features); 57 | MNN_PUBLIC std::vector _Split(VARP value, INTS size_splits, int axis = 0); 58 | MNN_PUBLIC VARP _Slice(VARP x, VARP starts, VARP sizes); 59 | MNN_PUBLIC VARP _StridedSlice(VARP input, VARP begin, VARP end, VARP strided, 60 | int32_t beginMask, int32_t endMask, int32_t ellipsisMask, 61 | int32_t newAxisMask, int32_t shrinkAxisMask); 62 | MNN_PUBLIC VARP _Concat(VARPS values, int axis); 63 | MNN_PUBLIC VARP _Convert(VARP input, Dimensionformat format); 64 | MNN_PUBLIC VARP _Transpose(VARP x, INTS perm); 65 | MNN_PUBLIC VARP _Transpose(VARP x, VARP perm); 66 | MNN_PUBLIC VARP _ChannelShuffle(VARP x, int group); 67 | MNN_PUBLIC VARP _ChangeInputFormat(VARP input, Dimensionformat format); 68 | MNN_PUBLIC VARP _Conv2DBackPropFilter(VARP input, VARP inputGrad, INTS kernelSize, PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}); 69 | MNN_PUBLIC VARP _PoolGrad(VARP originInput, VARP originOutput, VARP inputGrad, INTS kernel, INTS stride, PoolingMode type, PaddingMode 
pad = VALID, INTS pads= {0, 0}); 70 | // FIXME: move the api to Array Ops 71 | MNN_PUBLIC VARP _ReverseSequence(VARP x, VARP y, int batchDim, int seqDim); 72 | // FIXME: move the api to Image Ops 73 | MNN_PUBLIC VARP _Crop(VARP images, VARP size, int axis, INTS offset); 74 | MNN_PUBLIC VARP _Resize(VARP images, float xScale, float yScale); 75 | MNN_PUBLIC VARP _Pad(VARP x, VARP paddings, PadValueMode mode = CONSTANT); 76 | MNN_PUBLIC VARP _ExpandDims(VARP input, int axis); 77 | MNN_PUBLIC VARP _ExpandDims(VARP input, VARP axis); 78 | 79 | MNN_PUBLIC VARP _Shape(VARP input); 80 | MNN_PUBLIC VARP _Stack(VARPS values, int axis=0); 81 | enum InterpolationMethod {BILINEAR, NEAREST}; 82 | MNN_PUBLIC VARP _CropAndResize(VARP image, VARP boxes, VARP box_ind, VARP crop_size, 83 | InterpolationMethod method, float extrapolation_value = 0.0); 84 | MNN_PUBLIC VARP _Fill(VARP dims, VARP value); 85 | MNN_PUBLIC VARP _Tile(VARP input, VARP multiples); 86 | MNN_PUBLIC VARP _Gather(VARP params, VARP indices); 87 | MNN_PUBLIC VARP _GatherV2(VARP params, VARP indices, VARP axis = nullptr); 88 | MNN_PUBLIC VARP _Squeeze(VARP input, INTS axis = {}); 89 | MNN_PUBLIC VARP _Unsqueeze(VARP input, INTS axis = {}); 90 | MNN_PUBLIC VARP _BatchToSpaceND(VARP input, VARP block_shape, VARP crops); 91 | MNN_PUBLIC VARP _GatherND(VARP params, VARP indices); 92 | MNN_PUBLIC VARP _Selu(VARP features, float scale, float alpha); 93 | MNN_PUBLIC VARP _Size(VARP input); 94 | MNN_PUBLIC VARP _Elu(VARP features, float alpha=1.0); 95 | MNN_PUBLIC VARP _MatrixBandPart(VARP input, VARP num_lower, VARP num_upper); 96 | MNN_PUBLIC std::vector _Moments(VARP x, INTS axis, VARP shift, bool keepDims); 97 | MNN_PUBLIC VARP _SetDiff1D(VARP x, VARP y); 98 | MNN_PUBLIC VARP _SpaceToDepth(VARP input, int block_size); 99 | MNN_PUBLIC VARP _SpaceToBatchND(VARP input, VARP block_shape, VARP paddings); 100 | MNN_PUBLIC VARP _ZerosLike(VARP input); 101 | MNN_PUBLIC std::vector _Unstack(VARP value, int axis=0); 102 | 
MNN_PUBLIC VARP _Rank(VARP input); 103 | MNN_PUBLIC VARP _Range(VARP start, VARP limit, VARP delta); 104 | MNN_PUBLIC VARP _DepthToSpace(VARP input, int block_size); 105 | MNN_PUBLIC VARP _PriorBox(VARP feature, VARP image, 106 | std::vector min_size, std::vector max_size, std::vectoraspect_ratio, 107 | bool flip, bool clip, std::vectorvariance, 108 | unsigned int img_h, unsigned int img_w, float step_h, float step_w, float offset = 0.5); 109 | MNN_PUBLIC VARP _Permute(VARP input, INTS dims); 110 | MNN_PUBLIC VARP _DetectionOutput(VARP location, VARP confidence, VARP priorbox, 111 | unsigned int num_classes, bool share_location, int background_label_id, 112 | float nms_threshhold, int nms_topk, int code_type, 113 | bool variance_encoded_in_target, 114 | int keep_top_k, float confidence_threshold, float visualize_threshold); 115 | MNN_PUBLIC std::vector _DetectionPostProcess(VARP encode_boxes, VARP class_predictions, VARP anchors, 116 | int num_classes, int max_detections, 117 | int max_class_per_detection, int detections_per_class, 118 | float nms_threshold, float iou_threshold, 119 | bool use_regular_nms, std::vector centersize_encoding); 120 | MNN_PUBLIC VARP _Interp(VARPS xs, float widthScale, float heightScale, int outputWidth, int outputHeight, int resizeType, bool alignCorners); 121 | 122 | MNN_PUBLIC VARP _ZeroGrad(VARP x); 123 | 124 | // Int8 Inference 125 | MNN_PUBLIC VARP _Conv(std::vector&& weight, std::vector&& bias, std::vector&& scale, VARP x, INTS channel, INTS kernelSize, 126 | PaddingMode pad, INTS stride, INTS dilate, int group, INTS pads, bool relu); 127 | MNN_PUBLIC VARP _FloatToInt8(VARP x, VARP scale, char minValue, char maxValue); 128 | MNN_PUBLIC VARP _Int8ToFloat(VARP x, VARP scale); 129 | 130 | MNN_PUBLIC VARP _Select(VARP select, VARP input0, VARP input1); 131 | 132 | } // namespace Express 133 | } // namespace MNN 134 | 135 | #endif /* NeuralNetWorkOp_HPP */ 136 | 
-------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/expr/Optimizer.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Optimizer.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/08/20. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | #ifndef Optimizer_hpp 9 | #define Optimizer_hpp 10 | #include 11 | #include 12 | 13 | namespace MNN { 14 | namespace Express { 15 | class MNN_PUBLIC Optimizer { 16 | public: 17 | enum Device { 18 | CPU = 0, 19 | GPU = 1, 20 | OTHER = 2, 21 | AUTO = 3 22 | }; 23 | struct Config { 24 | Device device = CPU; 25 | MNNForwardType forwardType = MNN_FORWARD_ALL; 26 | int numThread = 4; 27 | }; 28 | static std::shared_ptr create(Config config); 29 | struct Cost { 30 | float compute; // MFlops 31 | float memory; // MB 32 | }; 33 | class Parameters { 34 | public: 35 | Parameters(int n); 36 | virtual ~Parameters(); 37 | 38 | float* get() const { 39 | return mValue; 40 | } 41 | int size() const { 42 | return mSize; 43 | } 44 | 45 | private: 46 | float* mValue; 47 | int mSize; 48 | }; 49 | virtual std::shared_ptr onGetParameters(const std::vector& outputs) { 50 | return nullptr; 51 | } 52 | 53 | //Given paramters and measure cost, the parameters must be the same as onGetParameters 54 | virtual Cost onMeasure(const std::vector& outputs, std::shared_ptr parameters = nullptr) = 0; 55 | 56 | //Modify the output directly, the parameters must be the same as onGetParameters 57 | virtual bool onExecute(const std::vector& outputs, std::shared_ptr parameters = nullptr) = 0; 58 | 59 | Optimizer() = default; 60 | virtual ~Optimizer() = default; 61 | }; 62 | } // namespace Express 63 | } // namespace MNN 64 | #endif 65 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/plugin/PluginContext.hpp: 
-------------------------------------------------------------------------------- 1 | // 2 | // ShapeInference.h 3 | // MNN 4 | // 5 | // Created by MNN on 2020/04/05. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_PLUGIN_PLUGIN_CONTEXT_HPP_ 10 | #define MNN_PLUGIN_PLUGIN_CONTEXT_HPP_ 11 | 12 | #include 13 | #include 14 | 15 | #include // Backend 16 | #include 17 | #include 18 | 19 | namespace MNN { 20 | namespace plugin { 21 | 22 | class MNN_PUBLIC PluginContext { 23 | public: 24 | PluginContext() = delete; 25 | PluginContext(const std::vector& inputs, // NOLINT 26 | const std::vector& outputs); 27 | 28 | virtual ~PluginContext() = default; 29 | 30 | const std::vector& inputs() const { 31 | return inputs_; 32 | } 33 | const std::vector& outputs() const { 34 | return outputs_; 35 | } 36 | 37 | const Tensor* input(const int index) const; 38 | const Tensor* output(const int index) const; 39 | 40 | Tensor* output(const int index); 41 | 42 | bool hasAttr(const std::string& name) const; 43 | 44 | bool setAttr(const std::string& name, const Attribute* attr); 45 | 46 | void setAttrs(const std::unordered_map& attrs); 48 | 49 | const Attribute* getAttr(const std::string& name) const; 50 | 51 | const std::unordered_map& getAttrs() const; 52 | 53 | protected: 54 | const std::vector& inputs_; 55 | const std::vector& outputs_; 56 | std::unordered_map attrs_; 57 | }; 58 | 59 | class MNN_PUBLIC InferShapeContext : public PluginContext { 60 | public: 61 | InferShapeContext() = delete; 62 | InferShapeContext(const std::vector& inputs, // NOLINT 63 | const std::vector& outputs); 64 | 65 | virtual ~InferShapeContext() = default; 66 | }; 67 | 68 | class MNN_PUBLIC CPUKernelContext : public PluginContext { 69 | public: 70 | CPUKernelContext() = delete; 71 | CPUKernelContext(const std::string& op_type, // NOLINT 72 | Backend* backend, // NOLINT 73 | const std::vector& inputs, // NOLINT 74 | const std::vector& outputs); 75 | 76 | virtual 
~CPUKernelContext() = default; 77 | 78 | Backend* backend() const { 79 | return backend_; 80 | } 81 | 82 | const std::string& op_type() const { 83 | return op_type_; 84 | } 85 | 86 | private: 87 | const std::string op_type_ = ""; 88 | Backend* backend_ = nullptr; 89 | }; 90 | 91 | inline PluginContext::PluginContext(const std::vector& inputs, // NOLINT 92 | const std::vector& outputs) // NOLINT 93 | : inputs_(inputs), outputs_(outputs) { 94 | } 95 | 96 | inline const Tensor* PluginContext::input(const int index) const { 97 | MNN_ASSERT(index < inputs_.size()); 98 | return inputs_.at(index); 99 | } 100 | 101 | inline const Tensor* PluginContext::output(const int index) const { 102 | MNN_ASSERT(index < outputs_.size()); 103 | return outputs_.at(index); 104 | } 105 | 106 | inline Tensor* PluginContext::output(const int index) { 107 | MNN_ASSERT(index < outputs_.size()); 108 | return outputs_.at(index); 109 | } 110 | 111 | inline bool PluginContext::hasAttr(const std::string& name) const { 112 | return attrs_.count(name) > 0; 113 | } 114 | 115 | inline bool PluginContext::setAttr(const std::string& name, // NOLINT 116 | const Attribute* attr) { 117 | return attrs_.emplace(name, attr).second; 118 | } 119 | 120 | inline void PluginContext::setAttrs( // NOLINT 121 | const std::unordered_map& attrs) { 122 | attrs_ = attrs; 123 | } 124 | 125 | inline const Attribute* PluginContext::getAttr(const std::string& name) const { 126 | const auto& it = attrs_.find(name); 127 | MNN_ASSERT(it != attrs_.end()); 128 | return it->second; 129 | } 130 | 131 | inline const std::unordered_map& // NOLINT 132 | PluginContext::getAttrs() const { 133 | return attrs_; 134 | } 135 | 136 | } // namespace plugin 137 | } // namespace MNN 138 | 139 | #endif // MNN_PLUGIN_PLUGIN_CONTEXT_HPP_ 140 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNN/plugin/PluginKernel.hpp: 
-------------------------------------------------------------------------------- 1 | // 2 | // ShapeInference.h 3 | // MNN 4 | // 5 | // Created by MNN on 2020/04/05. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_PLUGIN_PLUGIN_KERNEL_HPP_ 10 | #define MNN_PLUGIN_PLUGIN_KERNEL_HPP_ 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | namespace MNN { 19 | namespace plugin { 20 | 21 | template 22 | class MNN_PUBLIC ComputeKernel { 23 | public: 24 | ComputeKernel() = default; 25 | virtual ~ComputeKernel() = default; 26 | virtual bool compute(KernelContextT* ctx) = 0; 27 | }; 28 | 29 | class MNN_PUBLIC CPUComputeKernel : public ComputeKernel { 30 | public: 31 | using ContextT = CPUKernelContext; 32 | using KernelT = CPUComputeKernel; 33 | 34 | CPUComputeKernel() = default; 35 | virtual ~CPUComputeKernel() = default; 36 | virtual bool init(CPUKernelContext* ctx) = 0; 37 | virtual bool compute(CPUKernelContext* ctx) = 0; 38 | }; 39 | 40 | template 41 | class MNN_PUBLIC ComputeKernelRegistry { 42 | public: 43 | typedef std::function Factory; 44 | static std::unordered_map* getFactoryMap(); 45 | 46 | static bool add(const std::string& name, Factory factory); 47 | 48 | static PluginKernelT* get(const std::string& name); 49 | }; 50 | 51 | template 52 | struct ComputeKernelRegistrar { 53 | ComputeKernelRegistrar(const std::string& name) { 54 | ComputeKernelRegistry::add(name, []() { // NOLINT 55 | return new PluginKernelT; // NOLINT 56 | }); 57 | } 58 | }; 59 | 60 | #define REGISTER_PLUGIN_COMPUTE_KERNEL(name, computeKernel) \ 61 | namespace { \ 62 | static auto _plugin_compute_kernel_##name##_ __attribute__((unused)) = \ 63 | ComputeKernelRegistrar(#name); \ 64 | } // namespace 65 | 66 | } // namespace plugin 67 | } // namespace MNN 68 | 69 | #endif // MNN_PLUGIN_PLUGIN_KERNEL_HPP_ 70 | -------------------------------------------------------------------------------- 
/Client-Android/app/includes/MNN/plugin/PluginShapeInference.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ShapeInference.h 3 | // MNN 4 | // 5 | // Created by MNN on 2020/04/05. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_ 10 | #define MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_ 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | namespace MNN { 19 | namespace plugin { 20 | 21 | class MNN_PUBLIC InferShapeKernel { 22 | public: 23 | virtual ~InferShapeKernel() = default; 24 | virtual bool compute(InferShapeContext* ctx) = 0; 25 | }; 26 | 27 | class MNN_PUBLIC InferShapeKernelRegister { 28 | public: 29 | // typedef InferShapeKernel* (*Factory)(); 30 | typedef std::function Factory; 31 | static std::unordered_map* getFactoryMap(); 32 | 33 | static bool add(const std::string& name, Factory factory); 34 | 35 | static InferShapeKernel* get(const std::string& name); 36 | }; 37 | 38 | template 39 | struct InferShapeKernelRegistrar { 40 | InferShapeKernelRegistrar(const std::string& name) { 41 | InferShapeKernelRegister::add(name, []() { // NOLINT 42 | return new PluginKernel; // NOLINT 43 | }); 44 | } 45 | }; 46 | 47 | #define REGISTER_PLUGIN_OP(name, inferShapeKernel) \ 48 | namespace { \ 49 | static auto _plugin_infer_shape_##name##_ __attribute__((unused)) = \ 50 | InferShapeKernelRegistrar(#name); \ 51 | } // namespace 52 | 53 | } // namespace plugin 54 | } // namespace MNN 55 | 56 | #endif // MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_ 57 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/BlockingQueue.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // BlockingQueue.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/19. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef BlockingQueue_hpp 10 | #define BlockingQueue_hpp 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace MNN { 17 | namespace Train { 18 | 19 | template 20 | class BlockingQueue { 21 | public: 22 | BlockingQueue() = default; 23 | BlockingQueue(size_t maxSize) : mMaxSize(maxSize) { 24 | } 25 | 26 | bool isFull() { 27 | return mQueue.size() == mMaxSize; 28 | } 29 | 30 | bool isEmpty() { 31 | return mQueue.empty(); 32 | } 33 | 34 | void push(T value) { 35 | { 36 | std::unique_lock lock(mMutex); 37 | mCondVar.wait(lock, [&] { return !isFull(); }); 38 | MNN_ASSERT(!isFull()); 39 | mQueue.push(std::move(value)); 40 | lock.unlock(); 41 | } 42 | mCondVar.notify_one(); 43 | } 44 | 45 | T pop() { 46 | std::unique_lock lock(mMutex); 47 | mCondVar.wait(lock, [&] { return !isEmpty(); }); 48 | MNN_ASSERT(!isEmpty()); 49 | T value = mQueue.front(); 50 | mQueue.pop(); 51 | mCondVar.notify_one(); 52 | lock.unlock(); 53 | 54 | return std::move(value); 55 | } 56 | 57 | size_t clear() { 58 | std::lock_guard lock(mMutex); 59 | const auto size = mQueue.size(); 60 | while (!isEmpty()) { 61 | mQueue.pop(); 62 | } 63 | return size; 64 | } 65 | 66 | private: 67 | size_t mMaxSize; 68 | std::queue mQueue; 69 | std::mutex mMutex; 70 | std::condition_variable_any mCondVar; 71 | }; 72 | 73 | } // namespace Train 74 | } // namespace MNN 75 | 76 | #endif // BlockingQueue_hpp 77 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/DataLoader.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // DataLoader.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/15. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef DataLoader_hpp 10 | #define DataLoader_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | #include "BlockingQueue.hpp" 16 | #include "DataLoaderConfig.hpp" 17 | #include "Example.hpp" 18 | namespace MNN { 19 | namespace Train { 20 | class BatchDataset; 21 | class Sampler; 22 | class BatchTransform; 23 | class MNN_PUBLIC DataLoader { 24 | public: 25 | DataLoader(std::shared_ptr dataset, std::shared_ptr sampler, 26 | std::shared_ptr config); 27 | /* 28 | When use Windows v141 toolset to compile class having vector of non-copyable element (std::thread, for example), 29 | copy constructor (or assignment operator) must be deleted explicity, otherwise compile will failed. 30 | */ 31 | DataLoader(const DataLoader&) = delete; 32 | DataLoader& operator = (const DataLoader&) = delete; 33 | 34 | virtual ~DataLoader() { 35 | join(); 36 | }; 37 | 38 | void prefetch(size_t nJobs); 39 | 40 | void workerThread(); 41 | 42 | void join(); 43 | 44 | std::vector next(); 45 | 46 | void reset(); 47 | 48 | void clean(); 49 | 50 | size_t iterNumber() const; 51 | size_t size() const; 52 | static DataLoader* makeDataLoader(std::shared_ptr dataset, 53 | const int batchSize, 54 | const bool stack = true, 55 | const bool shuffle = true, 56 | const int numWorkers = 0); 57 | static DataLoader* makeDataLoader(std::shared_ptr dataset, 58 | std::vector> transforms, 59 | const int batchSize, 60 | const bool shuffle = true, 61 | const int numWorkers = 0); 62 | 63 | private: 64 | struct Job { 65 | std::vector job; 66 | bool quit = false; 67 | }; 68 | std::shared_ptr mDataset; 69 | std::shared_ptr mSampler; 70 | std::shared_ptr mConfig; 71 | std::shared_ptr> mJobs; 72 | std::shared_ptr>> mDataQueue; 73 | std::vector mWorkers; 74 | }; 75 | 76 | } // namespace Train 77 | } // namespace MNN 78 | 79 | #endif // DataLoader_hpp 80 | -------------------------------------------------------------------------------- 
/Client-Android/app/includes/MNNTrain/DataLoaderConfig.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // DataLoaderConfig.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/15. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef DataLoaderConfig_hpp 10 | #define DataLoaderConfig_hpp 11 | #include 12 | namespace MNN { 13 | namespace Train { 14 | 15 | class MNN_PUBLIC DataLoaderConfig { 16 | public: 17 | DataLoaderConfig() = default; 18 | DataLoaderConfig(size_t batchSize, size_t nWorkers = 0) : batchSize(batchSize), numWorkers(nWorkers) { 19 | } 20 | 21 | size_t batchSize = 1; 22 | size_t numWorkers = 0; 23 | size_t numJobs = numWorkers * 2; 24 | bool dropLast = false; 25 | }; 26 | 27 | } // namespace Train 28 | } // namespace MNN 29 | 30 | #endif // DataLoaderConfig 31 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/Dataset.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Dataset.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/14. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef Dataset_hpp 10 | #define Dataset_hpp 11 | 12 | #include 13 | #include 14 | #include "Example.hpp" 15 | #include "DataLoader.hpp" 16 | 17 | namespace MNN { 18 | namespace Train { 19 | struct MNN_PUBLIC DatasetPtr { 20 | public: 21 | std::shared_ptr mDataset; 22 | 23 | DataLoader* createLoader( 24 | const int batchSize, 25 | const bool stack = true, 26 | const bool shuffle = true, 27 | const int numWorkers = 0); 28 | ~ DatasetPtr() = default; 29 | template 30 | T* get() const { 31 | return (T*)mDataset.get(); 32 | } 33 | }; 34 | 35 | class MNN_PUBLIC BatchDataset { 36 | public: 37 | virtual ~BatchDataset() = default; 38 | 39 | // get batch using given indices 40 | virtual std::vector getBatch(std::vector indices) = 0; 41 | 42 | // size of the dataset 43 | virtual size_t size() = 0; 44 | }; 45 | 46 | class MNN_PUBLIC Dataset : public BatchDataset { 47 | public: 48 | // return a specific example with given index 49 | virtual Example get(size_t index) = 0; 50 | 51 | std::vector getBatch(std::vector indices) { 52 | std::vector batch; 53 | batch.reserve(indices.size()); 54 | for (const auto i : indices) { 55 | batch.emplace_back(get(i)); 56 | } 57 | MNN_ASSERT(batch.size() != 0); 58 | return batch; 59 | } 60 | }; 61 | 62 | } // namespace Train 63 | } // namespace MNN 64 | 65 | #endif /* Dataset_hpp */ 66 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/Distributions.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Distributions.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/28. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef Distributions_hpp 10 | #define Distributions_hpp 11 | 12 | #include 13 | #include 14 | 15 | namespace MNN { 16 | namespace Train { 17 | 18 | class Distributions { 19 | public: 20 | static void uniform(const int count, const float min, const float max, float* r, std::mt19937 gen); 21 | static void gaussian(const int count, const float mu, const float sigma, float* r, std::mt19937 gen); 22 | }; 23 | 24 | } // namespace Train 25 | } // namespace MNN 26 | 27 | #endif // Distritutions_hpp 28 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/Example.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Example.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/14. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef Example_hpp 10 | #define Example_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | using namespace MNN::Express; 17 | 18 | namespace MNN { 19 | namespace Train { 20 | /** 21 | First: data: a vector of input tensors (for single input dataset is only one) 22 | Second: target: a vector of output tensors (for single output dataset is only one) 23 | */ 24 | typedef std::pair, std::vector> Example; 25 | 26 | } // namespace Train 27 | } // namespace MNN 28 | 29 | #endif /* Example_hpp */ 30 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/ImageDataset.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ImageDataset.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/12/30. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef ImageDataset_hpp 10 | #define ImageDataset_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | #include "Dataset.hpp" 16 | #include "Example.hpp" 17 | #include 18 | 19 | // 20 | // the ImageDataset read stored images as input data. 21 | // use 'pathToImages' and a txt file to construct a ImageDataset. 22 | // the txt file should use format as below: 23 | // image1.jpg label1,label2,... 24 | // image2.jpg label3,label4,... 25 | // ... 26 | // the ImageDataset would read images from: 27 | // pathToImages/image1.jpg 28 | // pathToImages/image2.jpg 29 | // ... 30 | // 31 | 32 | namespace MNN { 33 | namespace Train { 34 | class MNN_PUBLIC ImageDataset : public Dataset { 35 | public: 36 | class ImageConfig { 37 | public: 38 | static ImageConfig* create(CV::ImageFormat destFmt = CV::GRAY, int resizeH = 0, int resizeW = 0, 39 | std::vector s = {1, 1, 1, 1}, std::vector m = {0, 0, 0, 0}, 40 | std::vector cropFract = {1/*height*/, 1/*width*/}, const bool centerOrRandom = false/*false:center*/) { 41 | auto config = new ImageConfig; 42 | config->destFormat = destFmt; 43 | config->resizeHeight = resizeH; 44 | config->resizeWidth = resizeW; 45 | config->scale = s; 46 | config->mean = m; 47 | MNN_ASSERT(cropFract.size() == 2); 48 | MNN_ASSERT(cropFract[0] > 0 && cropFract[0] <= 1); 49 | MNN_ASSERT(cropFract[1] > 0 && cropFract[1] <= 1); 50 | config->cropFraction = cropFract; 51 | config->centerOrRandomCrop = centerOrRandom; 52 | return config; 53 | } 54 | CV::ImageFormat destFormat; 55 | int resizeHeight; 56 | int resizeWidth; 57 | std::vector scale; 58 | std::vector mean; 59 | std::vector cropFraction; 60 | bool centerOrRandomCrop; 61 | }; 62 | 63 | static DatasetPtr create(const std::string pathToImages, const std::string pathToImageTxt, 64 | const ImageConfig* cfg, bool readAllToMemory = false); 65 | static Express::VARP convertImage(const std::string& imageName, const ImageConfig& config, const 
MNN::CV::ImageProcess::Config& cvConfig); 66 | 67 | Example get(size_t index) override; 68 | 69 | size_t size() override; 70 | 71 | private: 72 | ImageDataset(){} 73 | bool mReadAllToMemory; 74 | std::vector > > mAllTxtLines; 75 | std::vector > mDataAndLabels; 76 | ImageConfig mConfig; 77 | MNN::CV::ImageProcess::Config mProcessConfig; 78 | 79 | void getAllDataAndLabelsFromTxt(const std::string pathToImages, std::string pathToImageTxt); 80 | std::pair getDataAndLabelsFrom(std::pair > dataAndLabels); 81 | }; 82 | } // namespace Train 83 | } // namespace MNN 84 | 85 | #endif // ImageDataset_hpp 86 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/LearningRateScheduler.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // LearningRateScheduler.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/12/03. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef LearningRateScheduler_hpp 10 | #define LearningRateScheduler_hpp 11 | 12 | #include 13 | #include 14 | 15 | namespace MNN { 16 | namespace Train { 17 | 18 | class MNN_PUBLIC LrScheduler { 19 | public: 20 | static float multiStep(const float baseLr, const int step, std::vector stepIterations, 21 | std::vector lrMulti); 22 | 23 | static float inv(const float baseLr, const int step, const float gamma, const float power); 24 | 25 | static float exp(const float baseLr, const int step, const float gamma); 26 | }; 27 | 28 | } // namespace Train 29 | } // namespace MNN 30 | 31 | #endif // LearningRateScheduler_hpp 32 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/Lenet.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Lenet.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2020/01/10. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef LenetModels_hpp 10 | #define LenetModels_hpp 11 | 12 | #include "Module.hpp" 13 | #include "NN.hpp" 14 | 15 | namespace MNN { 16 | namespace Train { 17 | namespace Model { 18 | 19 | class MNN_PUBLIC Lenet : public Module { 20 | public: 21 | Lenet(); 22 | 23 | virtual std::vector onForward(const std::vector& inputs) override; 24 | 25 | std::shared_ptr conv1; 26 | std::shared_ptr conv2; 27 | std::shared_ptr ip1; 28 | std::shared_ptr ip2; 29 | std::shared_ptr dropout; 30 | }; 31 | 32 | } // namespace Model 33 | } // namespace Train 34 | } // namespace MNN 35 | 36 | #endif // LenetModels_hpp 37 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/Loss.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Loss.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/29. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef Loss_hpp 10 | #define Loss_hpp 11 | 12 | #include 13 | 14 | namespace MNN { 15 | namespace Train { 16 | 17 | MNN_PUBLIC Express::VARP _CrossEntropy(Express::VARP predicts, Express::VARP oneHotTargets); 18 | 19 | MNN_PUBLIC Express::VARP _KLDivergence(Express::VARP predicts, Express::VARP oneHotTargets); 20 | 21 | MNN_PUBLIC Express::VARP _MSE(Express::VARP predicts, Express::VARP oneHotTargets); 22 | 23 | MNN_PUBLIC Express::VARP _MAE(Express::VARP predicts, Express::VARP oneHotTargets); 24 | 25 | MNN_PUBLIC Express::VARP _Hinge(Express::VARP predicts, Express::VARP oneHotTargets); 26 | 27 | MNN_PUBLIC Express::VARP _DistillLoss(Express::VARP studentLogits, Express::VARP teacherLogits, Express::VARP oneHotTargets, 28 | const float temperature, const float alpha); 29 | 30 | } // namespace Train 31 | } // namespace MNN 32 | 33 | #endif // Loss_hpp 34 | -------------------------------------------------------------------------------- 
/Client-Android/app/includes/MNNTrain/MnistDataset.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // MnistDataset.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/15. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MnistDataset_hpp 10 | #define MnistDataset_hpp 11 | 12 | #include 13 | #include "Dataset.hpp" 14 | #include "Example.hpp" 15 | 16 | namespace MNN { 17 | namespace Train { 18 | class MNN_PUBLIC MnistDataset : public Dataset { 19 | public: 20 | enum Mode { TRAIN, TEST }; 21 | 22 | Example get(size_t index) override; 23 | 24 | size_t size() override; 25 | 26 | const VARP images(); 27 | 28 | const VARP labels(); 29 | 30 | static DatasetPtr create(const std::string path, Mode mode = Mode::TRAIN); 31 | private: 32 | explicit MnistDataset(const std::string path, Mode mode = Mode::TRAIN); 33 | VARP mImages, mLabels; 34 | const uint8_t* mImagePtr = nullptr; 35 | const uint8_t* mLabelsPtr = nullptr; 36 | }; 37 | } 38 | } 39 | 40 | 41 | #endif // MnistDataset_hpp 42 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/Module.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Module.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/25. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_Train_Module_hpp 10 | #define MNN_Train_Module_hpp 11 | #include 12 | namespace MNN { 13 | namespace Train { 14 | class MNN_PUBLIC Module { 15 | public: 16 | Module() = default; 17 | virtual ~Module() = default; 18 | virtual std::vector onForward(const std::vector& inputs) = 0; 19 | Express::VARP forward(Express::VARP input); 20 | std::vector parameters() const; 21 | bool loadParameters(const std::vector& parameters); 22 | void setIsTraining(const bool isTraining); 23 | bool getIsTraining(); 24 | static std::shared_ptr transform(const std::vector& inputs, 25 | const std::vector& outputs); 26 | 27 | void clearCache(); 28 | 29 | const std::string& name() const { 30 | return mName; 31 | }; 32 | void setName(std::string name) { 33 | mName = std::move(name); 34 | } 35 | const std::string type() const { 36 | return mType; 37 | } 38 | void setType(std::string type) { 39 | mType = std::move(type); 40 | } 41 | protected: 42 | void registerModel(const std::vector>& children); 43 | void addParameter(Express::VARP parameter); 44 | virtual void onClearCache() { 45 | } 46 | 47 | private: 48 | void _collectParameters(std::vector& result) const; 49 | std::vector> mChildren; 50 | std::vector mParameters; 51 | bool mIsTraining = true; 52 | std::string mName; 53 | std::string mType; 54 | }; 55 | } // namespace Train 56 | } // namespace MNN 57 | 58 | #endif 59 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/NN.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // NN.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/25. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_Train_NN_hpp 10 | #define MNN_Train_NN_hpp 11 | #include 12 | #include "Distributions.hpp" 13 | #include "Module.hpp" 14 | #include 15 | namespace MNN { 16 | namespace Train { 17 | class Initializer; 18 | 19 | class MNN_PUBLIC NN { 20 | public: 21 | enum ActivationFunctionType { 22 | None = 0, 23 | Relu = 1, 24 | Relu6 = 2, 25 | }; 26 | enum ScaleUpdateMethod { 27 | Maximum = 0, 28 | MovingAverage = 1 29 | }; 30 | enum FeatureScaleStatMethod { 31 | PerTensor = 0, 32 | PerChannel = 1 33 | }; 34 | /* Unlike enum in class, class in class need be dllimport or dllexport explcility. 35 | Compiling in other system will not be affected. 36 | */ 37 | struct MNN_PUBLIC ConvOption { 38 | Express::INTS kernelSize = {1, 1}; 39 | Express::INTS channel = {0, 0}; 40 | Express::INTS stride = {1, 1}; 41 | Express::INTS dilate = {1, 1}; 42 | Express::PaddingMode padMode = Express::VALID; 43 | Express::INTS pads = {0, 0}; 44 | bool depthwise = false; 45 | ActivationFunctionType fusedActivationFunction = None; 46 | void reset(int size = 2); 47 | }; 48 | static Module* Conv(const ConvOption& option, bool bias = true, 49 | std::shared_ptr weightInit = nullptr, 50 | std::shared_ptr biasInit = nullptr); 51 | static Module* ConvTranspose(const ConvOption& option, bool bias = true, 52 | std::shared_ptr weightInit = nullptr, 53 | std::shared_ptr biasInit = nullptr); 54 | static Module* Linear(int l, int t, bool hasBias = true, 55 | std::shared_ptr weightInit = nullptr, 56 | std::shared_ptr biasInit = nullptr); 57 | static Module* Dropout(const float dropRatio); 58 | static Module* BatchNorm(const int channels, const int dims = 4, const float m = 0.999, 59 | const float e = 1e-5); 60 | 61 | static Module* ConvInt8(const ConvOption& option, int bits = 8, bool bias = true, 62 | std::shared_ptr weightInit = nullptr, 63 | std::shared_ptr biasInit = nullptr, 64 | FeatureScaleStatMethod featureMethod = PerChannel, 65 | 
ScaleUpdateMethod method = MovingAverage 66 | ); 67 | struct ConvParameters { 68 | ConvOption option; 69 | Express::VARP weight; 70 | Express::VARP bias; 71 | int group; 72 | std::string name; 73 | }; 74 | static Module* ConvInt8(const ConvParameters& parameters, int bits, 75 | FeatureScaleStatMethod featureMethod = PerChannel, 76 | ScaleUpdateMethod method = MovingAverage); 77 | static Module* ConvOctave(const ConvParameters& parameters, float inFactor, float outFactor); 78 | static Module* Conv(const ConvParameters& parameters); 79 | static Module* ConvBNReluFused(std::vector > modules, 80 | NN::FeatureScaleStatMethod featureScaleStatMethod = PerTensor, 81 | NN::ScaleUpdateMethod scaleUpdateMethod = MovingAverage, const int bits = 8); 82 | 83 | class Utils { 84 | public: 85 | // ConvOption, Weight, Bias, Group 86 | static ConvParameters ExtractConvolution(Express::EXPRP expr); 87 | 88 | // Extract BatchNormal and Dropout 89 | static Module* ExtractNotRunableOp(Express::EXPRP expr); 90 | }; 91 | }; 92 | 93 | } // namespace Train 94 | } // namespace MNN 95 | 96 | #endif 97 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/ParameterOptimizer.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ParameterOptimizer.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/22. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef ParameterOptimizer_hpp 10 | #define ParameterOptimizer_hpp 11 | #include 12 | #include 13 | namespace MNN { 14 | namespace Train { 15 | 16 | class MNN_PUBLIC ParameterOptimizer { 17 | public: 18 | enum RegularizationMethod { 19 | L1, 20 | L2, 21 | L1L2, 22 | }; 23 | 24 | ParameterOptimizer() = default; 25 | virtual ~ParameterOptimizer() = default; 26 | bool step(Express::VARP loss); 27 | int currentStep(); 28 | void setCurrentStep(int step); 29 | void append(const std::vector& parameters); 30 | void remove(const std::vector& parameters); 31 | 32 | virtual std::map onGetNextParameter(Express::VARP loss) = 0; 33 | const std::set& parameters() const; 34 | 35 | static ParameterOptimizer* createSGD(float lr, float momentum, float weightDecay, RegularizationMethod method); 36 | static ParameterOptimizer* createADAM(float lr, float momentum, float momentum2, float weightDecay, float eps, RegularizationMethod method); 37 | private: 38 | virtual void onAppend(Express::VARP parameter) = 0; 39 | virtual void onRemove(Express::VARP parameter) = 0; 40 | std::set mParameters; 41 | int mStep = 0; 42 | }; 43 | 44 | } // namespace Train 45 | } // namespace MNN 46 | 47 | #endif 48 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/PipelineModule.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // PipelineModule.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2020/01/09. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef PipelineModule_hpp 10 | #define PipelineModule_hpp 11 | #include "Module.hpp" 12 | #include "NN.hpp" 13 | #include 14 | namespace MNN { 15 | namespace Train { 16 | 17 | class MNN_PUBLIC PipelineModule : public Module { 18 | public: 19 | typedef std::function, std::shared_ptr>(Express::EXPRP)> Transformer; 20 | static Module* extract(std::vector inputs, std::vector outputs, bool fortrain); 21 | static bool turnQuantize(Module* module, const int bits = 8, NN::FeatureScaleStatMethod featureScaleStatMethod = NN::PerTensor, NN::ScaleUpdateMethod scaleUpdateMethod = NN::MovingAverage); 22 | void toTrainQuant(const int bits = 8, NN::FeatureScaleStatMethod featureScaleStatMethod = NN::PerTensor, 23 | NN::ScaleUpdateMethod scaleUpdateMethod = NN::MovingAverage); 24 | virtual std::vector onForward(const std::vector& inputs) override; 25 | virtual void onClearCache() override; 26 | std::vector countOutputReference(std::vector outputIndices); 27 | 28 | private: 29 | PipelineModule(std::vector inputs, std::vector outputs, 30 | const Transformer& transformFunction = {}); 31 | std::vector, std::vector, std::vector>> mSubModules; 32 | std::vector mStack; 33 | std::vector mInputIndexes; 34 | std::vector mOutputIndexes; 35 | }; 36 | } // namespace Train 37 | } // namespace MNN 38 | 39 | #endif 40 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/SGD.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // SGD.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/22. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef SGD_hpp 10 | #define SGD_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | #include "ParameterOptimizer.hpp" 16 | 17 | namespace MNN { 18 | namespace Train { 19 | 20 | class MNN_PUBLIC SGD : public ParameterOptimizer { 21 | public: 22 | virtual std::map onGetNextParameter(Express::VARP loss) override; 23 | 24 | Express::VARP regularizeParameters(Express::VARP param, Express::VARP grad); 25 | 26 | virtual Express::VARP onComputeUpdateValue(Express::VARP param, Express::VARP grad); 27 | 28 | void setLearningRate(float rate); 29 | 30 | float getMomentum(); 31 | 32 | void setMomentum(float momentum); 33 | 34 | float getWeightDecay(); 35 | 36 | void setWeightDecay(float decay); 37 | 38 | RegularizationMethod getRegularizationMethod(); 39 | 40 | void setRegularizationMethod(RegularizationMethod method); 41 | 42 | float currentLearningRate(); 43 | 44 | virtual void onAppend(Express::VARP parameters) override; 45 | 46 | virtual void onRemove(Express::VARP parameters) override; 47 | 48 | void setGradBlockName(std::string block) { 49 | mGradBlockExprName = std::move(block); 50 | } 51 | 52 | protected: 53 | float mLearningRate = 0.001f; 54 | float mMomentum = 0; 55 | float mWeightDecay = 0; 56 | RegularizationMethod mRegularizationMethod = L2; 57 | std::map mHistory; 58 | 59 | // For Cache 60 | const Express::Expr* mLoss = nullptr; 61 | int mLossFromIndex = 0; 62 | std::string mGradBlockExprName; 63 | }; 64 | 65 | } // namespace Train 66 | } // namespace MNN 67 | 68 | #endif // SGD_hpp 69 | -------------------------------------------------------------------------------- /Client-Android/app/includes/MNNTrain/Transformer.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Transformer.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/12/16. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef Transformer_hpp 10 | #define Transformer_hpp 11 | #include 12 | 13 | namespace MNN { 14 | namespace Train { 15 | class MNN_PUBLIC Transformer { 16 | public: 17 | struct TrainConfig { 18 | std::vector variableLimits; 19 | }; 20 | 21 | static std::shared_ptr turnModelToTrainable(TrainConfig config); 22 | static std::shared_ptr turnModelToInfer(); 23 | }; 24 | } // namespace Train 25 | } // namespace MNN 26 | #endif 27 | -------------------------------------------------------------------------------- /Client-Android/app/libs/arm64-v8a/libMNN.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UbiquitousLearning/End2end-Federated-Learning/e2cdcd9829779798fc56f2f63b19ee6cdc2307d0/Client-Android/app/libs/arm64-v8a/libMNN.so -------------------------------------------------------------------------------- /Client-Android/app/libs/arm64-v8a/libMNNTrain.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UbiquitousLearning/End2end-Federated-Learning/e2cdcd9829779798fc56f2f63b19ee6cdc2307d0/Client-Android/app/libs/arm64-v8a/libMNNTrain.so -------------------------------------------------------------------------------- /Client-Android/app/libs/arm64-v8a/libMNN_Express.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UbiquitousLearning/End2end-Federated-Learning/e2cdcd9829779798fc56f2f63b19ee6cdc2307d0/Client-Android/app/libs/arm64-v8a/libMNN_Express.so -------------------------------------------------------------------------------- /Client-Android/app/proguard-rules.pro: -------------------------------------------------------------------------------- 1 | # Add project specific ProGuard rules here. 
2 | # You can control the set of applied configuration files using the 3 | # proguardFiles setting in build.gradle. 4 | # 5 | # For more details, see 6 | # http://developer.android.com/guide/developing/tools/proguard.html 7 | 8 | # If your project uses WebView with JS, uncomment the following 9 | # and specify the fully qualified class name to the JavaScript interface 10 | # class: 11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview { 12 | # public *; 13 | #} 14 | 15 | # Uncomment this to preserve the line number information for 16 | # debugging stack traces. 17 | #-keepattributes SourceFile,LineNumberTable 18 | 19 | # If you keep the line number information, uncomment this to 20 | # hide the original source file name. 21 | #-renamesourcefileattribute SourceFile -------------------------------------------------------------------------------- /Client-Android/app/src/androidTest/java/com/example/websocket/ExampleInstrumentedTest.java: -------------------------------------------------------------------------------- 1 | package com.example.websocket; 2 | 3 | import android.content.Context; 4 | 5 | import androidx.test.platform.app.InstrumentationRegistry; 6 | import androidx.test.ext.junit.runners.AndroidJUnit4; 7 | 8 | import org.junit.Test; 9 | import org.junit.runner.RunWith; 10 | 11 | import static org.junit.Assert.*; 12 | 13 | /** 14 | * Instrumented test, which will execute on an Android device. 15 | * 16 | * @see Testing documentation 17 | */ 18 | @RunWith(AndroidJUnit4.class) 19 | public class ExampleInstrumentedTest { 20 | @Test 21 | public void useAppContext() { 22 | // Context of the app under test. 
23 | Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); 24 | assertEquals("com.example.websocket", appContext.getPackageName()); 25 | } 26 | } -------------------------------------------------------------------------------- /Client-Android/app/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 7 | 8 | 9 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/demo/App.java: -------------------------------------------------------------------------------- 1 | package com.demo; 2 | 3 | import android.app.Application; 4 | 5 | 6 | public class App extends Application { 7 | 8 | 9 | private static App app; 10 | 11 | @Override 12 | public void onCreate() { 13 | super.onCreate(); 14 | app = this; 15 | } 16 | 17 | 18 | public static App getApp() { 19 | return app; 20 | } 21 | 22 | @Override 23 | public void onTerminate() { 24 | super.onTerminate(); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/demo/EmptyActivity.java: -------------------------------------------------------------------------------- 1 | package com.demo; 2 | 3 | import android.os.Bundle; 4 | 5 | import androidx.appcompat.app.AppCompatActivity; 6 | 7 | public class EmptyActivity extends AppCompatActivity { 8 | 9 | @Override 10 | protected void onCreate(Bundle savedInstanceState) { 11 | super.onCreate(savedInstanceState); 12 | setContentView(R.layout.activity_empty); 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/demo/MnnTrainFragment.java: -------------------------------------------------------------------------------- 1 | package com.demo; 2 | 3 
| import android.os.Bundle; 4 | import android.view.LayoutInflater; 5 | import android.view.View; 6 | import android.view.ViewGroup; 7 | import android.widget.TextView; 8 | 9 | import androidx.fragment.app.Fragment; 10 | 11 | /** 12 | * A simple {@link Fragment} subclass. 13 | * Use the {@link MnnTrainFragment#newInstance} factory method to 14 | * create an instance of this fragment. 15 | */ 16 | public class MnnTrainFragment extends Fragment { 17 | 18 | // TODO: Rename parameter arguments, choose names that match 19 | // the fragment initialization parameters, e.g. ARG_ITEM_NUMBER 20 | private static final String ARG_PARAM1 = "param1"; 21 | private static final String ARG_PARAM2 = "param2"; 22 | 23 | // TODO: Rename and change types of parameters 24 | private String mParam1; 25 | private String mParam2; 26 | private String trainInfo; 27 | 28 | private TextView stateText; 29 | 30 | public MnnTrainFragment() { 31 | // Required empty public constructor 32 | } 33 | 34 | /** 35 | * Use this factory method to create a new instance of 36 | * this fragment using the provided parameters. 37 | * 38 | * @param param1 Parameter 1. 39 | * @param param2 Parameter 2. 40 | * @return A new instance of fragment MnnTrainFragment. 
41 | */ 42 | // TODO: Rename and change types and number of parameters 43 | public static MnnTrainFragment newInstance(String param1, String param2) { 44 | MnnTrainFragment fragment = new MnnTrainFragment(); 45 | Bundle args = new Bundle(); 46 | args.putString(ARG_PARAM1, param1); 47 | args.putString(ARG_PARAM2, param2); 48 | fragment.setArguments(args); 49 | return fragment; 50 | } 51 | 52 | @Override 53 | public void onCreate(Bundle savedInstanceState) { 54 | super.onCreate(savedInstanceState); 55 | if (getArguments() != null) { 56 | mParam1 = getArguments().getString(ARG_PARAM1); 57 | mParam2 = getArguments().getString(ARG_PARAM2); 58 | } 59 | 60 | //开启线程,动态展示设备状态 61 | 62 | // new Thread(new Runnable() { 63 | // @Override 64 | // public void run() { 65 | // try { 66 | // //需要在子线程中处理的逻辑 67 | // stateText.setText(trainInfo); 68 | // Thread.sleep(100); 69 | // } catch (InterruptedException e) { 70 | // e.printStackTrace(); 71 | // } 72 | // 73 | // } 74 | // }).start(); 75 | 76 | } 77 | 78 | @Override 79 | public View onCreateView(LayoutInflater inflater, ViewGroup container, 80 | Bundle savedInstanceState) { 81 | // Inflate the layout for this fragment 82 | 83 | 84 | 85 | return inflater.inflate(R.layout.fragment_mnn_train, container, false); 86 | } 87 | 88 | 89 | // @Override 90 | // public void onAttach(Activity activity) { 91 | // super.onAttach(activity); 92 | // trainInfo = ((MainActivity) activity).getTrainInfo();//通过强转成宿主activity,就可以获取到传递过来的数据 93 | // } 94 | 95 | } -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/MainActivity_test.java: -------------------------------------------------------------------------------- 1 | package com.example; 2 | 3 | import androidx.appcompat.app.AppCompatActivity; 4 | 5 | import android.content.Context; 6 | import android.os.Bundle; 7 | 8 | import com.demo.R; 9 | import com.example.websocket.service.MnnTrainService; 10 | 11 | public class 
MainActivity_test extends AppCompatActivity { 12 | 13 | MnnTrainService mnnTrainService = new MnnTrainService(); 14 | 15 | @Override 16 | protected void onCreate(Bundle savedInstanceState) { 17 | super.onCreate(savedInstanceState); 18 | setContentView(R.layout.activity_main); 19 | 20 | mnnTrainService.onCreate(); 21 | 22 | } 23 | 24 | 25 | } -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/nativemnn/mnn/MNNDataNative.java: -------------------------------------------------------------------------------- 1 | package com.example.nativemnn.mnn; 2 | 3 | import android.graphics.Bitmap; 4 | import android.util.Log; 5 | 6 | import com.example.nativemnn.utils.Common; 7 | 8 | public class MNNDataNative { 9 | // load libraries 10 | static void loadGpuLibrary(String name) { 11 | try { 12 | System.loadLibrary(name); 13 | } catch (Throwable ce) { 14 | Log.w(Common.TAG, "load MNNTrain " + name + " GPU so exception=%s", ce); 15 | } 16 | } 17 | // load mnn library 18 | static { 19 | System.loadLibrary("MNNTrain"); 20 | System.loadLibrary("MNN_Express"); 21 | System.loadLibrary("MNN"); 22 | System.loadLibrary("mnncore"); 23 | } 24 | 25 | /** 26 | * We call the training module of C++ to complete training task, and obtain the relevant information required by Android UI 27 | * @param modelCachePath the storage path of model in Android 28 | * @param dataCachePath the storage path of training dataset in Android 29 | * @return result of client training with total local epochs (format: "loss,trainSamples,accuracy,testSamples") 30 | */ 31 | public static native String nativeCreateDatasetFromFile(String modelCachePath, String dataCachePath); 32 | /* 33 | Return: current epoch and the loss value in this epoch (format: "epoch,loss") 34 | */ 35 | 36 | /** 37 | * 38 | * @return the local epoch index in each global epoch training, and the training loss in this local epoch 39 | */ 40 | public static native String 
getEpochAndLoss(); 41 | 42 | } 43 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/nativemnn/utils/Common.java: -------------------------------------------------------------------------------- 1 | package com.example.nativemnn.utils; 2 | 3 | import android.content.Context; 4 | 5 | import java.io.File; 6 | import java.io.FileOutputStream; 7 | import java.io.IOException; 8 | import java.io.InputStream; 9 | 10 | public class Common { 11 | public static String TAG = "MNNDemo"; 12 | 13 | public static void copyAssetResource2File(Context context, String assetsFile, String outFile) throws IOException { 14 | InputStream is = context.getAssets().open(assetsFile); 15 | File outF = new File(outFile); 16 | FileOutputStream fos = new FileOutputStream(outF); 17 | 18 | int byteCount; 19 | byte[] buffer = new byte[1024]; 20 | while ((byteCount = is.read(buffer)) != -1) { 21 | fos.write(buffer, 0, byteCount); 22 | } 23 | fos.flush(); 24 | is.close(); 25 | fos.close(); 26 | outF.setReadable(true); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/websocket/client/ClientWebSocketListener.java: -------------------------------------------------------------------------------- 1 | package com.example.websocket.client; 2 | 3 | import android.util.Log; 4 | 5 | import androidx.annotation.Nullable; 6 | 7 | import org.jetbrains.annotations.NotNull; 8 | 9 | import okhttp3.OkHttpClient; 10 | import okhttp3.Request; 11 | import okhttp3.Response; 12 | import okhttp3.WebSocket; 13 | import okhttp3.WebSocketListener; 14 | import okio.ByteString; 15 | 16 | class ClientWebSocketListener extends WebSocketListener { 17 | 18 | @Override 19 | public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response){ 20 | 21 | } 22 | 23 | @Override 24 | public void onClosing(@NotNull WebSocket webSocket, int code, @NotNull String 
reason){ 25 | 26 | } 27 | 28 | @Override 29 | public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) { 30 | 31 | } 32 | 33 | 34 | @Override 35 | public void onMessage (@NotNull WebSocket webSocket, @NotNull String text){ 36 | 37 | } 38 | 39 | @Override 40 | public void onMessage(@NotNull WebSocket webSocket, ByteString bytes) { 41 | 42 | } 43 | 44 | @Override 45 | public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, @Nullable Response response) { 46 | 47 | } 48 | 49 | 50 | 51 | } 52 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/websocket/constants/Constants.java: -------------------------------------------------------------------------------- 1 | package com.example.websocket.constants; 2 | 3 | public class Constants { 4 | 5 | public static final String TAG = "websocketclient"; 6 | public static final String SERVER_URL = "ws://10.28.241.114:8080/ws"; 7 | //public static final String dir = Environment.getExternalStorageDirectory().toString(); 8 | //public static final String TRAIN_MODEL_FILE_PATH = dir+"/mnn_data/mnist.snapshot.mnn"; 9 | public static final String TRAIN_MODEL_FILE_PATH = "/data/local/tmp/mnn/mnist.snapshot.mnn"; 10 | public static final String TRAIN_DATA_FILE_PATH = "/data/local/tmp/mnist_data"; 11 | 12 | } 13 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/websocket/service/ClientWebSocketService.java: -------------------------------------------------------------------------------- 1 | package com.example.websocket.service; 2 | 3 | import android.app.Service; 4 | import android.content.Intent; 5 | import android.os.IBinder; 6 | 7 | 8 | public class ClientWebSocketService extends Service { 9 | 10 | /** 绑定的客户端接口 */ 11 | 12 | public ClientWebSocketService(){ 13 | 14 | } 15 | 16 | @Override 17 | public void onCreate() { 18 | 
super.onCreate(); 19 | } 20 | 21 | @Override 22 | public int onStartCommand(Intent intent, int flags, int startId) { 23 | return super.onStartCommand(intent, flags, startId); 24 | } 25 | 26 | 27 | @Override 28 | public void onDestroy() { 29 | 30 | } 31 | 32 | @Override 33 | public IBinder onBind(Intent intent) { 34 | // TODO: Return the communication channel to the service. 35 | throw new UnsupportedOperationException("Not yet implemented"); 36 | } 37 | 38 | 39 | 40 | 41 | } 42 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/websocket/service/MnnTrainService.java: -------------------------------------------------------------------------------- 1 | package com.example.websocket.service; 2 | 3 | import android.app.Service; 4 | import android.content.Intent; 5 | import android.os.IBinder; 6 | 7 | import com.example.nativemnn.mnn.MNNDataNative; 8 | import static com.example.websocket.constants.Constants.TRAIN_DATA_FILE_PATH; 9 | import static com.example.websocket.constants.Constants.TRAIN_MODEL_FILE_PATH; 10 | 11 | public class MnnTrainService extends Service { 12 | private MNNDataNative mnnDataNative; 13 | 14 | public MnnTrainService(){ 15 | 16 | } 17 | 18 | @Override 19 | public void onCreate() { 20 | super.onCreate(); 21 | 22 | mnnTrainNative(); 23 | 24 | } 25 | 26 | @Override 27 | public int onStartCommand(Intent intent, int flags, int startId) { 28 | return super.onStartCommand(intent, flags, startId); 29 | } 30 | 31 | 32 | @Override 33 | public void onDestroy() { 34 | 35 | } 36 | 37 | @Override 38 | public IBinder onBind(Intent intent) { 39 | // TODO: Return the communication channel to the service. 
40 | throw new UnsupportedOperationException("Not yet implemented"); 41 | } 42 | 43 | public void mnnTrainNative(){ 44 | String result = mnnDataNative.nativeCreateDatasetFromFile(TRAIN_MODEL_FILE_PATH,TRAIN_DATA_FILE_PATH); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/websocket/utils/CommonUtil.java: -------------------------------------------------------------------------------- 1 | package com.example.websocket.utils; 2 | 3 | class CommonUtil { 4 | //byte 数组与 int 的相互转换 5 | public static int byteArrayToInt(byte[] b) { 6 | return b[3] & 0xFF | 7 | (b[2] & 0xFF) << 8 | 8 | (b[1] & 0xFF) << 16 | 9 | (b[0] & 0xFF) << 24; 10 | } 11 | 12 | public static byte[] intToByteArray(int a) { 13 | return new byte[] { 14 | (byte) ((a >> 24) & 0xFF), 15 | (byte) ((a >> 16) & 0xFF), 16 | (byte) ((a >> 8) & 0xFF), 17 | (byte) (a & 0xFF) 18 | }; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/websocket/utils/DeviceUtil.java: -------------------------------------------------------------------------------- 1 | package com.example.websocket.utils; 2 | 3 | import android.content.Context; 4 | import android.net.wifi.WifiInfo; 5 | import android.net.wifi.WifiManager; 6 | 7 | import com.example.MainActivity_test; 8 | 9 | public class DeviceUtil { 10 | 11 | public String getIpFromWifi(Context context) { 12 | //获取wifi服务 13 | WifiManager wifiManager = (WifiManager) context.getSystemService(Context.WIFI_SERVICE); 14 | //判断wifi是否开启 15 | if (!wifiManager.isWifiEnabled()) { 16 | wifiManager.setWifiEnabled(true); 17 | } 18 | WifiInfo wifiInfo = wifiManager.getConnectionInfo(); 19 | int ipAddress = wifiInfo.getIpAddress(); 20 | String ip = intToIp(ipAddress); 21 | return ip; 22 | } 23 | 24 | //获取Wifi ip 地址 25 | private String intToIp(int ipAddress) { 26 | return (ipAddress & 0xFF) + "." 
+ 27 | ((ipAddress >> 8) & 0xFF) + "." + 28 | ((ipAddress >> 16) & 0xFF) + "." + 29 | (ipAddress >> 24 & 0xFF); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/java/com/example/websocket/utils/FileUtil.java: -------------------------------------------------------------------------------- 1 | package com.example.websocket.utils; 2 | 3 | import android.util.Log; 4 | 5 | import java.io.File; 6 | import java.io.FileInputStream; 7 | import java.io.FileOutputStream; 8 | import java.io.IOException; 9 | import java.nio.ByteBuffer; 10 | import java.nio.channels.FileChannel; 11 | 12 | import static com.example.websocket.constants.Constants.TAG; 13 | 14 | public class FileUtil { 15 | 16 | public static byte[] getFileContent(String filePath) throws IOException { 17 | File file = new File(filePath); 18 | 19 | long fileSize = file.length(); 20 | if (fileSize > Integer.MAX_VALUE) { 21 | Log.d(TAG,"file too big..."); 22 | return null; 23 | } 24 | 25 | FileInputStream fi = new FileInputStream(file); 26 | byte[] buffer = new byte[(int) fileSize]; 27 | int offset = 0; 28 | int numRead = 0; 29 | 30 | while (offset < buffer.length 31 | && (numRead = fi.read(buffer, offset, buffer.length - offset)) >= 0) { 32 | offset += numRead; 33 | } 34 | 35 | // 确保所有数据均被读取 36 | if (offset != buffer.length) { 37 | throw new IOException("Could not completely read file " 38 | + file.getName()); 39 | } 40 | fi.close(); 41 | return buffer; 42 | } 43 | 44 | public static void saveFileToPath(String path, byte[] data)throws Exception { 45 | if(data != null){ 46 | Log.d(TAG,"save file to path:"+path); 47 | FileOutputStream out = new FileOutputStream(path);//指定写到哪个路径中 48 | FileChannel fileChannel = out.getChannel(); 49 | fileChannel.write(ByteBuffer.wrap(data)); //将字节流写入文件中 50 | fileChannel.force(true);//强制刷新 51 | fileChannel.close(); 52 | } 53 | } 54 | 55 | } 56 | 
-------------------------------------------------------------------------------- /Client-Android/app/src/main/jni/mnndatanative.cpp: -------------------------------------------------------------------------------- 1 | // Created by Jimmy Yuan on 2021/05/19. 2 | // @Copyright Copyright © 2021, BUPT Holding Limited 3 | // @Content CNN model training 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #define LOG_TAG "test====" 28 | #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) 29 | #define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__) 30 | #define LOGD(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) 31 | 32 | // Global variable to get the current epoch value and current loss in this epoch 33 | static jint curEpoch = 0; 34 | static jfloat curLoss = 0.0; 35 | 36 | 37 | extern "C" 38 | JNIEXPORT jstring JNICALL 39 | Java_com_example_nativemnn_mnn_MNNDataNative_getEpochAndLoss(JNIEnv *env, jclass clazz) { 40 | std::string result = std::to_string(curEpoch) + "," + std::to_string(curLoss); 41 | return env->NewStringUTF(result.data()); 42 | } 43 | 44 | 45 | // @Return: trainSamples, testSamples, LOSS, ACC 46 | extern "C" 47 | JNIEXPORT jstring JNICALL 48 | Java_com_example_nativemnn_mnn_MNNDataNative_nativeCreateDatasetFromFile(JNIEnv *env, jclass clazz, 49 | jstring modelCachePath, jstring dataCachePath) { 50 | LOGE("======TEST======"); 51 | const char *modelPath = env->GetStringUTFChars(modelCachePath, 0); 52 | const char *dataPath = env->GetStringUTFChars(dataCachePath, 0); 53 | 54 | float accuracy = 0.0; 55 | float LOSS = 0.0; 56 | int trainSamples = 0; 57 | int testSamples = 0; 58 | 59 | auto exe = Executor::getGlobalExecutor(); 60 | 
MNN::BackendConfig config; 61 | exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4); 62 | std::shared_ptr sgd(new MNN::Train::SGD); 63 | 64 | // init lenet 65 | std::shared_ptr model(new MNN::Train::Model::Lenet); 66 | 67 | // model->parameters() to get model parameter 68 | auto para_0 = model->parameters(); 69 | { 70 | // Load model snapshot 71 | auto para = Variable::load(modelPath); 72 | int times = para.size()/8; 73 | for(int i = 0; i < para_0.size(); i++) 74 | para_0[i] = para[i * times]; 75 | model->loadParameters(para_0); 76 | } 77 | 78 | sgd->append(model->parameters()); 79 | // sgd->setMomentum(0.9f); 80 | // sgd->setMomentum2(0.99f); 81 | // sgd->setWeightDecay(0.0005f); 82 | 83 | // train data setting 84 | auto datasetPtr = MNN::Train::MnistDataset::create(dataPath, MNN::Train::MnistDataset::Mode::TRAIN); 85 | const size_t batchSize = 64; 86 | bool shuffle = true; 87 | const size_t numWorkers = 0; 88 | auto dataLoader = std::shared_ptr(datasetPtr.createLoader(batchSize, true, shuffle, numWorkers)); 89 | size_t total_size = dataLoader->size(); 90 | size_t iterations = dataLoader->iterNumber(); 91 | iterations = iterations / 150; 92 | trainSamples = iterations * batchSize; 93 | 94 | // test data setting 95 | auto testDataset = MNN::Train::MnistDataset::create("/data/local/tmp/mnist_data", MNN::Train::MnistDataset::Mode::TEST); 96 | const size_t testBatchSize = 20; 97 | const size_t testNumWorkers = 0; 98 | shuffle = false; 99 | auto testDataLoader = std::shared_ptr(testDataset.createLoader(testBatchSize, true, shuffle, testNumWorkers)); 100 | size_t testIterations = testDataLoader->iterNumber(); 101 | testIterations = testIterations / 100; 102 | testSamples = testIterations * testBatchSize; 103 | 104 | // start training 105 | for (int epoch = 0; epoch < 5; ++epoch){ 106 | curEpoch = (jint)epoch; 107 | model->clearCache(); 108 | exe->gc(Executor::FULL); 109 | exe->resetProfile(); 110 | { 111 | dataLoader->reset(); 112 | model->setIsTraining(true); 113 
| int lastIndex = 0; 114 | int moveBatchSize = 0; 115 | for (int i = 0; i < iterations; i++) { 116 | auto trainData = dataLoader->next(); 117 | auto example = trainData[0]; 118 | auto cast = _Cast(example.first[0]); 119 | example.first[0] = cast * _Const(1.0f / 255.0f); 120 | moveBatchSize += example.first[0]->getInfo()->dim[0]; 121 | auto newTarget = _OneHot(_Cast(example.second[0]), _Scalar(10), _Scalar(1.0f), 122 | _Scalar(0.0f)); 123 | auto predict = model->forward(example.first[0]); 124 | auto loss = MNN::Train::_CrossEntropy(predict, newTarget); 125 | auto lossvalue = loss->readMap(); 126 | LOSS = *lossvalue; 127 | curLoss = LOSS; 128 | float rate = MNN::Train::LrScheduler::inv(0.01, epoch * iterations + i, 0.0001, 0.75); 129 | sgd->setLearningRate(rate); 130 | sgd->step(loss); 131 | } 132 | } 133 | 134 | Variable::save(model->parameters(), modelPath); 135 | 136 | int correct = 0; 137 | testDataLoader->reset(); 138 | model->setIsTraining(false); 139 | int moveBatchSize = 0; 140 | // start testing 141 | for (int i = 0; i < testIterations; i++) { 142 | auto data = testDataLoader->next(); 143 | auto example = data[0]; 144 | moveBatchSize += example.first[0]->getInfo()->dim[0]; 145 | auto cast = _Cast(example.first[0]); 146 | example.first[0] = cast * _Const(1.0f / 255.0f); 147 | auto predict = model->forward(example.first[0]); 148 | predict = _ArgMax(predict, 1); 149 | auto accu = _Cast(_Equal(predict, _Cast(example.second[0]))).sum({}); 150 | correct += accu->readMap()[0]; 151 | } 152 | auto accu = (float)correct / (float)testSamples; 153 | accuracy = accu; 154 | exe->dumpProfile(); 155 | 156 | } 157 | 158 | std::string result = std::to_string(LOSS) + "," + std::to_string(trainSamples) + 159 | "," + std::to_string(accuracy) + "," + std::to_string(testSamples); 160 | 161 | env->ReleaseStringUTFChars(modelCachePath, modelPath); 162 | env->ReleaseStringUTFChars(dataCachePath, dataPath); 163 | 164 | return env->NewStringUTF(result.data()); 165 | } 
-------------------------------------------------------------------------------- /Client-Android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml: -------------------------------------------------------------------------------- 1 | 7 | 8 | 9 | 15 | 18 | 21 | 22 | 23 | 24 | 30 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/res/drawable/ic_launcher_background.xml: -------------------------------------------------------------------------------- 1 | 2 | 7 | 10 | 15 | 20 | 25 | 30 | 35 | 40 | 45 | 50 | 55 | 60 | 65 | 70 | 75 | 80 | 85 | 90 | 95 | 100 | 105 | 110 | 115 | 120 | 125 | 130 | 135 | 140 | 145 | 150 | 155 | 160 | 165 | 170 | 171 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/res/layout/activity_empty.xml: -------------------------------------------------------------------------------- 1 | 2 | 8 | 9 | -------------------------------------------------------------------------------- /Client-Android/app/src/main/res/layout/activity_main.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 10 | 11 | 24 | 25 | 26 | 32 | 33 |
34 |
35 |
36 |
37 |

Aggregation Server Web

38 |
39 |
40 |

41 |

42 |
43 | 44 | 45 |
46 |
47 |
48 |
49 |

50 |
51 |
52 |
53 |
54 |
55 |
56 |

Client Message:

57 |
58 |
59 |
60 | 61 | 344 | 345 | 346 | 347 | -------------------------------------------------------------------------------- /data/mnist.snapshot.mnn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UbiquitousLearning/End2end-Federated-Learning/e2cdcd9829779798fc56f2f63b19ee6cdc2307d0/data/mnist.snapshot.mnn -------------------------------------------------------------------------------- /data/mnist_data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UbiquitousLearning/End2end-Federated-Learning/e2cdcd9829779798fc56f2f63b19ee6cdc2307d0/data/mnist_data/.DS_Store -------------------------------------------------------------------------------- /data/mnist_data/t10k-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UbiquitousLearning/End2end-Federated-Learning/e2cdcd9829779798fc56f2f63b19ee6cdc2307d0/data/mnist_data/t10k-images-idx3-ubyte -------------------------------------------------------------------------------- /data/mnist_data/t10k-labels-idx1-ubyte: -------------------------------------------------------------------------------- 1 | '                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
                                                                                                                                                                                        -------------------------------------------------------------------------------- /data/mnist_data/train-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UbiquitousLearning/End2end-Federated-Learning/e2cdcd9829779798fc56f2f63b19ee6cdc2307d0/data/mnist_data/train-images-idx3-ubyte -------------------------------------------------------------------------------- /data/mnist_data/train-labels-idx1-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UbiquitousLearning/End2end-Federated-Learning/e2cdcd9829779798fc56f2f63b19ee6cdc2307d0/data/mnist_data/train-labels-idx1-ubyte --------------------------------------------------------------------------------