├── .gitignore ├── .gitlab-ci.yml ├── LICENSE ├── README.md ├── app ├── .gitignore ├── build.gradle ├── debug.properties ├── proguard-rules.pro └── src │ ├── androidTest │ └── java │ │ └── io │ │ └── whz │ │ └── synapse │ │ └── ExampleInstrumentedTest.java │ ├── debug │ ├── AndroidManifest.xml │ └── java │ │ └── io │ │ └── whz │ │ └── synapse │ │ └── component │ │ ├── DebugApp.java │ │ └── WrapperActivity.java │ ├── main │ ├── AndroidManifest.xml │ ├── assets │ │ └── demo.model │ ├── java │ │ └── io │ │ │ └── whz │ │ │ └── synapse │ │ │ ├── component │ │ │ ├── AboutDialog.java │ │ │ ├── App.java │ │ │ ├── BaseActivity.java │ │ │ ├── MainActivity.java │ │ │ ├── MainService.java │ │ │ ├── ModelDetailActivity.java │ │ │ ├── NeuralModelActivity.java │ │ │ └── PlayActivity.java │ │ │ ├── element │ │ │ ├── AutoFitWidthCardView.java │ │ │ ├── AutoFitWidthLineChart.java │ │ │ ├── ChannelCreator.java │ │ │ ├── DigitView.java │ │ │ ├── Dir.java │ │ │ ├── FigureProvider.java │ │ │ ├── Global.java │ │ │ ├── IThreadExecutor.java │ │ │ ├── Scheduler.java │ │ │ ├── Singleton.java │ │ │ └── VerticalGap.java │ │ │ ├── matrix │ │ │ ├── Matrix.java │ │ │ └── MatrixChecker.java │ │ │ ├── neural │ │ │ ├── ActivateFunction.java │ │ │ ├── DataSet.java │ │ │ ├── MNISTUtil.java │ │ │ ├── NeuralNetwork.java │ │ │ └── TrainCallback.java │ │ │ ├── pojo │ │ │ ├── constant │ │ │ │ ├── PreferenceCons.java │ │ │ │ └── TrackCons.java │ │ │ ├── dao │ │ │ │ └── DBModel.java │ │ │ ├── event │ │ │ │ ├── MANEvent.java │ │ │ │ ├── MSNEvent.java │ │ │ │ ├── ModelDeletedEvent.java │ │ │ │ ├── NormalEvent.java │ │ │ │ ├── TrackEvent.java │ │ │ │ ├── TrainEvent.java │ │ │ │ └── TypeEvent.java │ │ │ ├── multiple │ │ │ │ ├── binder │ │ │ │ │ ├── PlayViewBinder.java │ │ │ │ │ ├── TrainedModelViewBinder.java │ │ │ │ │ ├── TrainingModelViewBinder.java │ │ │ │ │ └── WelcomeViewBinder.java │ │ │ │ └── item │ │ │ │ │ ├── PlayItem.java │ │ │ │ │ ├── TrainedModelItem.java │ │ │ │ │ ├── TrainingModelItem.java │ │ │ │ 
│ └── WelcomeItem.java │ │ │ └── neural │ │ │ │ ├── Batch.java │ │ │ │ ├── Digit.java │ │ │ │ ├── Figure.java │ │ │ │ └── Model.java │ │ │ ├── track │ │ │ ├── AbsTrackHandler.java │ │ │ ├── ActivityLifecycleTracker.java │ │ │ ├── AmplitudeTrackHandler.java │ │ │ ├── DebugTrackHandler.java │ │ │ ├── ExceptionHelper.java │ │ │ ├── FirebaseTrackHandler.java │ │ │ ├── ITracker.java │ │ │ ├── TimeHelper.java │ │ │ └── Tracker.java │ │ │ ├── transition │ │ │ ├── FabTransform.java │ │ │ └── GravityArcMotion.java │ │ │ └── util │ │ │ ├── DbHelper.java │ │ │ ├── FileUtil.java │ │ │ ├── MatrixUtil.java │ │ │ ├── Precondition.java │ │ │ ├── ProcessUtil.java │ │ │ ├── StringFormatUtil.java │ │ │ └── Versatile.java │ └── res │ │ ├── anim │ │ ├── item_animation_from_bottom.xml │ │ └── layout_animation_from_bottom.xml │ │ ├── drawable-nodpi │ │ └── marker.webp │ │ ├── drawable-xxxhdpi │ │ ├── blue_ripple.webp │ │ ├── notify_icon.webp │ │ ├── red_paper.webp │ │ ├── red_sun.webp │ │ └── stack_rectangle.webp │ │ ├── drawable │ │ ├── bg_splash.xml │ │ ├── bg_white_fillet.xml │ │ ├── dialog_background.xml │ │ ├── ic_add_white_24dp.xml │ │ ├── ic_arrow_forward_24dp.xml │ │ ├── ic_block_24dp.xml │ │ ├── ic_change_24dp.xml │ │ ├── ic_close_24dp.xml │ │ ├── ic_favorite_24dp.xml │ │ ├── ic_github_code_24.xml │ │ ├── ic_play_24dp.xml │ │ ├── ic_refresh_24dp.xml │ │ └── ic_share_24dp.xml │ │ ├── layout │ │ ├── ac_detail_mark_view.xml │ │ ├── activity_main.xml │ │ ├── activity_model_detail.xml │ │ ├── activity_neural_model.xml │ │ ├── activity_play.xml │ │ ├── dialog_about.xml │ │ ├── dialog_model_list.xml │ │ ├── hidden_layer_input.xml │ │ ├── item_paly.xml │ │ ├── item_trained.xml │ │ ├── item_training.xml │ │ └── item_welcome.xml │ │ ├── menu │ │ ├── ac_main_menu.xml │ │ ├── ac_model_detail_menu.xml │ │ └── ac_play_menu.xml │ │ ├── mipmap-hdpi │ │ └── ic_launcher.png │ │ ├── mipmap-xhdpi │ │ └── ic_launcher.png │ │ ├── mipmap-xxhdpi │ │ └── ic_launcher.png │ │ ├── mipmap-xxxhdpi │ │ └── 
ic_launcher.png │ │ └── values │ │ ├── colors.xml │ │ ├── dimens.xml │ │ ├── ids.xml │ │ ├── strings.xml │ │ ├── styles.xml │ │ └── themes.xml │ ├── release │ └── java │ │ └── io.whz.synapse.component │ │ └── WrapperActivity.java │ └── test │ └── java │ └── io │ └── whz │ └── synapse │ ├── ExampleUnitTest.java │ └── matrix │ ├── MatrixCheckerTest.java │ └── MatrixTest.java ├── build.gradle ├── gradle.properties ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── publicity └── ad.png └── settings.gradle /.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | .gradle 3 | /local.properties 4 | /.idea 5 | .DS_Store 6 | /build 7 | /captures 8 | .externalNativeBuild 9 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | image: huazhouwang/android-26 2 | 3 | before_script: 4 | - chmod +x ./gradlew 5 | 6 | stages: 7 | - test 8 | - build 9 | 10 | test: 11 | stage: test 12 | script: 13 | - ./gradlew testDebug 14 | artifacts: 15 | paths: 16 | - app/build/reports 17 | 18 | build: 19 | stage: build 20 | script: 21 | - ./gradlew assembleDebug 22 | artifacts: 23 | paths: 24 | - app/build/outputs/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Synapse 2 | Synapse is a beautiful, fun application that lets you train a neural network with stochastic gradient descent (SGD) to recognize MNIST handwritten digits directly on your device. It does not depend on any deep learning library; it is a pure Java implementation, intended for learning purposes only. 3 | 4 | We encourage you to take a look at [Neural Networks and Deep Learning](http://neuralnetworksanddeeplearning.com/index.html), written by [Michael Nielsen](http://michaelnielsen.org/). 
5 | 6 | ## Enjoy Yourself 7 | 8 | ![](publicity/ad.png) 9 | 10 | If your device doesn't have Google Play, you can download it from [this site](http://fir.im/6vta). 11 | 12 | ## License 13 | 14 | Copyright 2016 HuazhouWang. 15 | 16 | Licensed under the Apache License, Version 2.0 (the "License"); 17 | you may not use this file except in compliance with the License. 18 | You may obtain a copy of the License at 19 | 20 | http://www.apache.org/licenses/LICENSE-2.0 21 | 22 | Unless required by applicable law or agreed to in writing, software 23 | distributed under the License is distributed on an "AS IS" BASIS, 24 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 | See the License for the specific language governing permissions and 26 | limitations under the License. 27 | -------------------------------------------------------------------------------- /app/.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | google-services.json 3 | release.properties 4 | -------------------------------------------------------------------------------- /app/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'com.android.application' 2 | apply plugin: 'org.greenrobot.greendao' 3 | 4 | final boolean hasGSJson = false // hasGoogleServicesJson(); while false: TRACK_ENABLE is off, debug builds get the ".debug" id suffix, and release assembly is refused 5 | loadExtProperties(hasGSJson) 6 | 7 | android { 8 | compileSdkVersion 26 9 | buildToolsVersion "26.0.0" 10 | defaultConfig { 11 | applicationId "io.whz.synapse" 12 | minSdkVersion 21 13 | targetSdkVersion 26 14 | versionCode 2 15 | versionName "1.1" 16 | testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" 17 | 18 | javaCompileOptions { 19 | annotationProcessorOptions { 20 | arguments = [eventBusIndex: 'io.whz.synapse.EventBusIndex'] 21 | } 22 | } 23 | 24 | buildConfigField "boolean", "TRACK_ENABLE", "$hasGSJson" 25 | buildConfigField "String", "AMPLITUDE_ID", 
"\"$gradle.AMPLITUDE_ID\"" 26 | } 27 | 28 | buildTypes { 29 | debug { 30 | if (!hasGSJson) { 31 | applicationIdSuffix ".debug" 32 | } 33 | } 34 | 35 | release { 36 | minifyEnabled true 37 | shrinkResources true 38 | proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' 39 | } 40 | } 41 | } 42 | 43 | greendao { 44 | schemaVersion 1 45 | } 46 | 47 | android.applicationVariants.all { v -> 48 | if (v.buildType.name == "release"){ 49 | v.assemble.doFirst { 50 | if (!hasGSJson) { 51 | throw new IllegalArgumentException("Please enable Google Service Json!") 52 | } 53 | } 54 | } 55 | } 56 | 57 | dependencies { 58 | compile fileTree(dir: 'libs', include: ['*.jar']) 59 | 60 | androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { 61 | exclude group: 'com.android.support', module: 'support-annotations' 62 | }) 63 | compile 'com.android.support:appcompat-v7:26.0.2' 64 | compile 'com.android.support:design:26.0.2' 65 | compile 'com.android.support:cardview-v7:26.0.2' 66 | compile 'me.drakeet.multitype:multitype:3.3.0' 67 | compile 'org.greenrobot:greendao:3.2.2' 68 | compile 'org.greenrobot:eventbus:3.0.0' 69 | compile 'com.github.PhilJay:MPAndroidChart:v3.0.2' 70 | compile 'com.google.firebase:firebase-core:11.4.2' 71 | compile 'com.amplitude:android-sdk:2.15.0' 72 | 73 | debugCompile 'com.jakewharton.scalpel:scalpel:1.1.2' 74 | debugCompile 'com.squareup.leakcanary:leakcanary-android:1.5.4' 75 | 76 | testCompile 'junit:junit:4.12' 77 | annotationProcessor 'org.greenrobot:eventbus-annotation-processor:3.0.1' 78 | } 79 | 80 | if (hasGSJson) { 81 | println("Import google services plugin") 82 | apply plugin: 'com.google.gms.google-services' 83 | } 84 | // True when google-services.json exists in this module's directory (its call site above is currently commented out). 85 | def hasGoogleServicesJson() { 86 | final File file = new File(projectDir, "google-services.json") 87 | return file.exists() 88 | } 89 | // Copies every key of release.properties (when isRelease) or debug.properties from this module's directory into gradle.ext; does nothing if the file is missing. 90 | def loadExtProperties(isRelease) { 91 | final File file = new File(projectDir, 92 | isRelease ? 
"release.properties" : "debug.properties") 93 | 94 | if (!file.exists()) { 95 | return 96 | } 97 | 98 | final Properties prop = new Properties() 99 | file.withInputStream { prop.load(it) } // withInputStream closes the stream once the properties are loaded 100 | 101 | for (String key : prop.keySet()) { 102 | gradle.ext."$key" = prop.get(key) 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /app/debug.properties: -------------------------------------------------------------------------------- 1 | AMPLITUDE_ID=null -------------------------------------------------------------------------------- /app/proguard-rules.pro: -------------------------------------------------------------------------------- 1 | # Add project specific ProGuard rules here. 2 | # By default, the flags in this file are appended to flags specified 3 | # in /Users/wanghuazhou/Library/Android/sdk/tools/proguard/proguard-android.txt 4 | # You can edit the include path and order by changing the proguardFiles 5 | # directive in build.gradle. 6 | # 7 | # For more details, see 8 | # http://developer.android.com/guide/developing/tools/proguard.html 9 | 10 | # Add any project specific keep options here: 11 | 12 | # If your project uses WebView with JS, uncomment the following 13 | # and specify the fully qualified class name to the JavaScript interface 14 | # class: 15 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview { 16 | # public *; 17 | #} 18 | 19 | # Uncomment this to preserve the line number information for 20 | # debugging stack traces. 21 | #-keepattributes SourceFile,LineNumberTable 22 | 23 | # If you keep the line number information, uncomment this to 24 | # hide the original source file name. 
25 | #-renamesourcefileattribute SourceFile 26 | # NOTE(review): build.gradle pairs this file with proguard-android.txt, which sets -dontoptimize; the optimization options below and the Log stripping further down only take effect with proguard-android-optimize.txt - confirm intent. 27 | ##--- For: Android defaults --- 28 | -optimizationpasses 5 # number of optimization passes 29 | -allowaccessmodification # allow access modifiers of classes and members to be broadened during optimization 30 | -dontusemixedcaseclassnames # never generate mixed-case class names 31 | -dontskipnonpubliclibraryclasses # also read non-public classes in library jars (this does not control obfuscation) 32 | -dontpreverify # skip preverification, which Android does not need 33 | -verbose # print processing details 34 | -ignorewarnings # ignore warnings so they do not fail the build 35 | -optimizations !code/simplification/arithmetic,!code/simplification/cast,!field/*,!class/merging/* # optimizations to disable 36 | 37 | -keepattributes *Annotation* 38 | -keep public class com.google.vending.licensing.ILicensingService 39 | -keep public class com.android.vending.licensing.ILicensingService 40 | -keepclasseswithmembernames class * { # keep the names of native methods and of the classes declaring them 41 | native <methods>; 42 | } 43 | 44 | -keepclassmembers public class * extends android.view.View { 45 | void set*(***); 46 | *** get*(); 47 | } 48 | 49 | -keepclassmembers class * extends android.app.Activity { 50 | public void *(android.view.View); 51 | } 52 | 53 | -keepclassmembers enum * { # keep enum values()/valueOf() 54 | public static **[] values(); 55 | public static ** valueOf(java.lang.String); 56 | } 57 | 58 | -keep class * implements android.os.Parcelable { # keep Parcelable classes and their CREATOR fields 59 | public static final android.os.Parcelable$Creator *; 60 | } 61 | 62 | -keepclassmembers class **.R$* { # keep the static fields of generated R classes 63 | public static <fields>; 64 | } 65 | 66 | -dontwarn android.support.** 67 | ##--- End Android defaults --- 68 | 69 | ##--- For: components that must not be renamed --- 70 | -keep public class * extends android.app.Activity 71 | -keep public class * extends android.app.Fragment 72 | -keep public class * extends android.app.Application 73 | -keep public class * extends android.app.Service 74 | -keep public class * extends android.content.BroadcastReceiver 75 | -keep public class * extends android.content.ContentProvider 76 | -keep public class * extends android.app.backup.BackupAgentHelper 77 | -keep public class * extends android.preference.Preference 78 | 79 | ##--- For:android-support-v4 --- 80 | -dontwarn 
android.support.v4.** 81 | -keep class android.support.v4.** { *; } 82 | -keep interface android.support.v4.app.** { *; } 83 | -keep class * extends android.support.v4.** { *; } 84 | -keep public class * extends android.support.v4.** 85 | -keep public class * extends android.support.v4.widget 86 | -keep class * extends android.support.v4.app.** {*;} 87 | -keep class * extends android.support.v4.view.** {*;} 88 | 89 | ##--- For:Serializable --- 90 | -keep class * implements java.io.Serializable {*;} 91 | -keepnames class * implements java.io.Serializable 92 | -keepclassmembers class * implements java.io.Serializable {*;} 93 | # NOTE(review): -assumenosideeffects only removes calls when optimization is enabled (see the note near the top of these rules). 94 | ##--- For:Remove log --- 95 | -assumenosideeffects class android.util.Log { 96 | public static boolean isLoggable(java.lang.String, int); 97 | public static int v(...); 98 | public static int i(...); 99 | public static int w(...); 100 | public static int d(...); 101 | public static int e(...); 102 | } 103 | 104 | -keepattributes *Annotation* # keep annotations 105 | -keepattributes Signature # keep generic signatures (enable this if you hit ClassCastExceptions) 106 | 107 | ### Event Bus 108 | -keepattributes *Annotation* 109 | -keepclassmembers class ** { 110 | @org.greenrobot.eventbus.Subscribe <methods>; 111 | } 112 | -keep enum org.greenrobot.eventbus.ThreadMode { *; } 113 | 114 | # Only required if you use AsyncExecutor 115 | -keepclassmembers class * extends org.greenrobot.eventbus.util.ThrowableFailureEvent { 116 | <init>(java.lang.Throwable); 117 | } 118 | 119 | ### GreenDAO 3 120 | -keepclassmembers class * extends org.greenrobot.greendao.AbstractDao { 121 | public static java.lang.String TABLENAME; 122 | } 123 | -keep class **$Properties 124 | 125 | # If you do not use SQLCipher: 126 | -dontwarn org.greenrobot.greendao.database.** 127 | # If you do not use RxJava: 128 | -dontwarn rx.** 129 | 130 | ### Amplitude 131 | -keepattributes Signature 132 | -keepattributes *Annotation* 133 | -keep class okhttp3.** { *; } 134 | -keep interface okhttp3.** { *; } 135 | -dontwarn okhttp3.** 136 | -dontwarn okio.** 
-------------------------------------------------------------------------------- /app/src/androidTest/java/io/whz/synapse/ExampleInstrumentedTest.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse; 2 | 3 | import android.content.Context; 4 | import android.support.test.InstrumentationRegistry; 5 | import android.support.test.runner.AndroidJUnit4; 6 | 7 | import org.junit.Test; 8 | import org.junit.runner.RunWith; 9 | 10 | import static org.junit.Assert.*; 11 | 12 | /** 13 | * Instrumentation test, which will execute on an Android device. 14 | * 15 | * @see Testing documentation 16 | */ 17 | @RunWith(AndroidJUnit4.class) 18 | public class ExampleInstrumentedTest { 19 | @Test 20 | public void useAppContext() throws Exception { 21 | // The target package is the build's applicationId including any suffix (debug builds may append ".debug"), so compare against BuildConfig rather than a hard-coded name. 22 | Context appContext = InstrumentationRegistry.getTargetContext(); 23 | 24 | assertEquals(BuildConfig.APPLICATION_ID, appContext.getPackageName()); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /app/src/debug/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 10 | 11 | -------------------------------------------------------------------------------- /app/src/debug/java/io/whz/synapse/component/DebugApp.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.component; 2 | 3 | import com.squareup.leakcanary.LeakCanary; 4 | 5 | public class DebugApp extends App { 6 | 7 | @Override 8 | public void onCreate() { 9 | super.onCreate(); 10 | 11 | if (!LeakCanary.isInAnalyzerProcess(this)) { 12 | LeakCanary.install(this); 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /app/src/debug/java/io/whz/synapse/component/WrapperActivity.java: -------------------------------------------------------------------------------- 
1 | package io.whz.synapse.component; 2 | 3 | import android.graphics.drawable.Drawable; 4 | import android.support.annotation.CallSuper; 5 | import android.support.annotation.LayoutRes; 6 | import android.support.v4.content.ContextCompat; 7 | import android.view.LayoutInflater; 8 | import android.view.Menu; 9 | import android.view.MenuItem; 10 | import android.view.ViewGroup; 11 | import android.widget.Toast; 12 | 13 | import com.jakewharton.scalpel.ScalpelFrameLayout; 14 | 15 | import io.whz.synapse.R; 16 | 17 | public class WrapperActivity extends BaseActivity { 18 | private Drawable mDrawableBackUp; 19 | private boolean mScalpelEnable = false; 20 | private MenuItem mScalpelMenu; 21 | private ScalpelFrameLayout mScalpelLayout; 22 | // Inflates every content layout into a ScalpelFrameLayout (layer interaction initially off) and installs that frame as the content view. 23 | @Override 24 | public void setContentView(@LayoutRes int layoutResID) { 25 | mScalpelLayout = new ScalpelFrameLayout(this); 26 | 27 | mScalpelLayout.setLayerInteractionEnabled(false); 28 | mScalpelLayout.setDrawViews(true); 29 | mScalpelLayout.setDrawIds(true); 30 | mScalpelLayout.setChromeColor(ContextCompat.getColor(this, R.color.white$1)); 31 | mScalpelLayout.setChromeShadowColor(ContextCompat.getColor(this, R.color.red$1)); 32 | 33 | LayoutInflater.from(this).inflate(layoutResID, mScalpelLayout, true); 34 | 35 | final ViewGroup.LayoutParams lp = new ViewGroup.LayoutParams(ViewGroup.LayoutParams.MATCH_PARENT, 36 | ViewGroup.LayoutParams.MATCH_PARENT); 37 | 38 | super.setContentView(mScalpelLayout, lp); 39 | } 40 | // Adds the Scalpel toggle item to the options menu. 41 | @CallSuper 42 | @Override 43 | public boolean onCreateOptionsMenu(Menu menu) { 44 | mScalpelMenu = menu.add(Menu.NONE, R.id.scalpel_menu, Menu.NONE, "Enable Scalpel"); 45 | 46 | return true; 47 | } 48 | // Toggles Scalpel layer interaction (swapping the window background while enabled); the Scalpel item is consumed here, every other item goes to the superclass. 49 | @CallSuper 50 | @Override 51 | public boolean onOptionsItemSelected(MenuItem item) { 52 | if (item.getItemId() == R.id.scalpel_menu) { 53 | 54 | mScalpelEnable = !mScalpelEnable; 55 | 56 | if (mScalpelEnable) { 57 | mDrawableBackUp = getWindow().getDecorView().getBackground(); 58 | 
getWindow().getDecorView().setBackgroundResource(R.color.black$2); 59 | 60 | mScalpelLayout.setLayerInteractionEnabled(true); 61 | mScalpelMenu.setTitle("Disable Scalpel"); 62 | } else { 63 | getWindow().getDecorView().setBackground(mDrawableBackUp); 64 | 65 | mScalpelLayout.setLayerInteractionEnabled(false); 66 | mScalpelMenu.setTitle("Enable Scalpel"); 67 | } 68 | 69 | Toast.makeText(this, "Submit change", Toast.LENGTH_SHORT) 70 | .show(); 71 | 72 | return true; // handled: report the item as consumed 73 | } 74 | 75 | return super.onOptionsItemSelected(item); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /app/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 22 | 23 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 41 | 42 | 48 | 49 | 55 | 56 | 60 | 61 | -------------------------------------------------------------------------------- /app/src/main/assets/demo.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huazhouwang/Synapse/cc536d98a284a7c5113e90abecdd1ea67a7531a2/app/src/main/assets/demo.model -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/component/AboutDialog.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.component; 2 | 3 | import android.app.Activity; 4 | import android.app.Dialog; 5 | import android.content.Intent; 6 | import android.content.pm.PackageManager; 7 | import android.net.Uri; 8 | import android.os.Bundle; 9 | import android.support.annotation.NonNull; 10 | import android.support.v4.app.DialogFragment; 11 | import android.support.v4.app.ShareCompat; 12 | import android.support.v7.app.AlertDialog; 13 | import android.view.View; 14 | import android.view.ViewGroup; 15 | 16 | import io.whz.synapse.BuildConfig; 17 | import io.whz.synapse.R; 18 | import 
io.whz.synapse.pojo.constant.TrackCons; 19 | import io.whz.synapse.track.Tracker; 20 | // About dialog offering GitHub, rate and share actions; each action is also logged through Tracker. 21 | public class AboutDialog extends DialogFragment implements View.OnClickListener { 22 | private static final String APP_IN_GOOGLE_PLAY = 23 | "https://play.google.com/store/apps/details?id=" + BuildConfig.APPLICATION_ID; 24 | 25 | @NonNull 26 | @Override 27 | public Dialog onCreateDialog(Bundle savedInstanceState) { 28 | final AlertDialog.Builder builder = new AlertDialog.Builder(getContext()); 29 | 30 | builder.setTitle(R.string.text_about) 31 | .setMessage(R.string.text_dialog_about_msg) 32 | .setView(inflateCustom()) 33 | .setPositiveButton(R.string.text_dialog_about_positive, null); 34 | 35 | return builder.create(); 36 | } 37 | 38 | private View inflateCustom() { 39 | final View view = View.inflate(getContext(), R.layout.dialog_about, null); 40 | 41 | view.findViewById(R.id.github).setOnClickListener(this); 42 | view.findViewById(R.id.rate).setOnClickListener(this); 43 | view.findViewById(R.id.share).setOnClickListener(this); 44 | 45 | final ViewGroup.LayoutParams lp = new ViewGroup.LayoutParams( 46 | ViewGroup.LayoutParams.MATCH_PARENT, ViewGroup.LayoutParams.WRAP_CONTENT); 47 | 48 | view.setLayoutParams(lp); 49 | 50 | return view; 51 | } 52 | 53 | @Override 54 | public void onClick(View view) { 55 | final int id = view.getId(); 56 | 57 | switch (id) { 58 | case R.id.github: 59 | handleGitHubAction(); 60 | break; 61 | 62 | case R.id.rate: 63 | handleRateAction(); 64 | break; 65 | 66 | case R.id.share: 67 | handleShareAction(); 68 | break; 69 | 70 | default: 71 | break; 72 | } 73 | } 74 | 75 | private void handleGitHubAction() { 76 | final Activity activity = getActivity(); 77 | 78 | if (activity == null || activity.isFinishing()) { 79 | return; 80 | } 81 | 82 | final Intent intent = new Intent(Intent.ACTION_VIEW, Uri.parse("https://github.com/huazhouwang/Synapse")); 83 | final PackageManager manager = activity.getPackageManager(); 84 | 85 | if 
(intent.resolveActivity(manager) != null) { 86 | activity.startActivity(intent); 87 | } 88 | 89 | Tracker.getInstance() 90 | .logEvent(TrackCons.About.CLICK_GITHUB); 91 | } 92 | 93 | private void handleRateAction() { 94 | final Activity activity = getActivity(); 95 | 96 | if (activity == null || activity.isFinishing()) { 97 | return; 98 | } 99 | 100 | final Intent intent = new Intent(Intent.ACTION_VIEW, Uri.parse(APP_IN_GOOGLE_PLAY)); 101 | final PackageManager manager = activity.getPackageManager(); 102 | 103 | if (intent.resolveActivity(manager) != null) { 104 | activity.startActivity(intent); 105 | } 106 | 107 | Tracker.getInstance() 108 | .logEvent(TrackCons.About.CLICK_RATE); 109 | } 110 | // Opens a share chooser with a link to the app's Play Store page. 111 | private void handleShareAction() { 112 | final Activity activity = getActivity(); 113 | 114 | if (activity == null || activity.isFinishing()) { // like the other handlers: getString() and the chooser need an attached, live Activity 115 | return; 116 | } 117 | 118 | ShareCompat.IntentBuilder.from(activity) 119 | .setChooserTitle(R.string.text_share_chooser_title) 120 | .setSubject(getString(R.string.text_share_subject)) 121 | .setText(getString(R.string.text_share_text) + " " + APP_IN_GOOGLE_PLAY) 122 | .setType("text/plain") 123 | .startChooser(); 124 | 125 | Tracker.getInstance() 126 | .logEvent(TrackCons.About.CLICK_SHARE); 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/component/App.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.component; 2 | 3 | import android.app.Application; 4 | import android.content.SharedPreferences; 5 | import android.os.Build; 6 | import android.support.annotation.Nullable; 7 | 8 | import org.greenrobot.eventbus.EventBus; 9 | 10 | import io.whz.synapse.EventBusIndex; 11 | import io.whz.synapse.element.ChannelCreator; 12 | import io.whz.synapse.element.Global; 13 | import io.whz.synapse.pojo.constant.TrackCons; 14 | import io.whz.synapse.pojo.dao.DaoMaster; 15 | import io.whz.synapse.pojo.dao.DaoSession; 16 | import io.whz.synapse.track.ExceptionHelper; 17 | import io.whz.synapse.track.TimeHelper; 18 | 
import io.whz.synapse.track.Tracker; 19 | 20 | public class App extends Application { 21 | public static final String TAG = "Synapse"; 22 | private static final String DB_NAME = "global-db"; 23 | private static final String PREFERENCE_NAME = "global-preferences"; 24 | 25 | private final Global mGlobal = Global.getInstance(); 26 | 27 | @Override 28 | public void onCreate() { 29 | super.onCreate(); 30 | 31 | TimeHelper.getInstance() 32 | .start(TrackCons.APP.INITIALIZE); 33 | 34 | configEventBus(); 35 | configPreferences(); 36 | configGreenDao(); 37 | initTrackEngines(); 38 | 39 | createNotificationChannel(); 40 | hookUncaughtExceptionHandler(); 41 | 42 | Tracker.getInstance() 43 | .event(TrackCons.APP.INITIALIZE) 44 | .put(TrackCons.Key.TIME_USED, TimeHelper.getInstance().stop(TrackCons.APP.INITIALIZE)) 45 | .log(); 46 | } 47 | 48 | private void hookUncaughtExceptionHandler() { 49 | final Thread.UncaughtExceptionHandler handler = new GlobalExceptionHandler( 50 | Thread.getDefaultUncaughtExceptionHandler()); 51 | 52 | Thread.setDefaultUncaughtExceptionHandler(handler); 53 | } 54 | 55 | private void initTrackEngines() { 56 | Tracker.getInstance() 57 | .initialize(getApplicationContext(), 58 | mGlobal.getBus()); 59 | } 60 | 61 | private void createNotificationChannel() { 62 | if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { 63 | ChannelCreator.createChannel(this.getApplicationContext()); 64 | } 65 | } 66 | 67 | private void configGreenDao() { 68 | final DaoMaster.DevOpenHelper helper = new DaoMaster.DevOpenHelper( // NOTE(review): greenDAO's DevOpenHelper drops all tables on schema upgrade - switch to a migrating OpenHelper before bumping schemaVersion 69 | getApplicationContext(), DB_NAME); 70 | 71 | final DaoSession session = new DaoMaster(helper.getWritableDb()).newSession(); 72 | mGlobal.setSession(session); 73 | } 74 | 75 | private void configPreferences() { 76 | final SharedPreferences preferences = getApplicationContext().getSharedPreferences(PREFERENCE_NAME, MODE_PRIVATE); 77 | mGlobal.setPreference(preferences); 78 | } 79 | 80 | private void configEventBus() { 81 | final EventBus bus = 
EventBus.builder() 82 | .addIndex(new EventBusIndex()) 83 | .installDefaultEventBus(); 84 | 85 | mGlobal.setBus(bus); 86 | } 87 | 88 | private static class GlobalExceptionHandler implements Thread.UncaughtExceptionHandler { 89 | @Nullable 90 | private final Thread.UncaughtExceptionHandler mDefault; 91 | 92 | GlobalExceptionHandler(@Nullable Thread.UncaughtExceptionHandler handler) { 93 | mDefault = handler; 94 | } 95 | // Reports the crash through ExceptionHelper, then hands it to the handler that was installed before this one. 96 | @Override 97 | public void uncaughtException(Thread thread, Throwable throwable) { 98 | try { 99 | ExceptionHelper.getInstance().caught(throwable); 100 | } finally { // delegate even if reporting itself fails, so the previous default handler still sees the crash 101 | if (mDefault != null) { 102 | mDefault.uncaughtException(thread, throwable); 103 | } 104 | } 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/component/BaseActivity.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.component; 2 | 3 | import android.support.v7.app.AppCompatActivity; 4 | 5 | import io.whz.synapse.util.Versatile; 6 | 7 | public class BaseActivity extends AppCompatActivity { 8 | @Override 9 | protected void onDestroy() { 10 | super.onDestroy(); 11 | 12 | Versatile.removeActivityFromTransitionManager(this); 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/element/AutoFitWidthCardView.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.element; 2 | 3 | import android.content.Context; 4 | import android.support.v7.widget.CardView; 5 | import android.util.AttributeSet; 6 | 7 | public class AutoFitWidthCardView extends CardView { 8 | private static final double SCALE = 16D / 9D; // width:height; computed in double because the float 16F/9F rounds up and makes (int) (width / SCALE) one pixel short for multiples of 16 9 | 10 | public AutoFitWidthCardView(Context context) { 11 | super(context); 12 | } 13 | 14 | public AutoFitWidthCardView(Context context, AttributeSet attrs) { 15 | super(context, attrs); 16 | } 17 | 18 | public 
AutoFitWidthCardView(Context context, AttributeSet attrs, int defStyleAttr) { 19 | super(context, attrs, defStyleAttr); 20 | } 21 | 22 | @Override 23 | protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) { 24 | final int width = MeasureSpec.getSize(widthMeasureSpec); 25 | final int height = (int) (width / SCALE); 26 | 27 | super.onMeasure( 28 | MeasureSpec.makeMeasureSpec(width, MeasureSpec.EXACTLY), 29 | MeasureSpec.makeMeasureSpec(height, MeasureSpec.EXACTLY)); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/element/AutoFitWidthLineChart.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.element; 2 | 3 | import android.content.Context; 4 | import android.util.AttributeSet; 5 | 6 | import com.github.mikephil.charting.charts.LineChart; 7 | 8 | public class AutoFitWidthLineChart extends LineChart { 9 | private static final double SCALE = 16D / 9D; // width:height; computed in double because the float 16F/9F rounds up and makes (int) (width / SCALE) one pixel short for multiples of 16 10 | 11 | public AutoFitWidthLineChart(Context context) { 12 | super(context); 13 | } 14 | 15 | public AutoFitWidthLineChart(Context context, AttributeSet attrs) { 16 | super(context, attrs); 17 | } 18 | 19 | public AutoFitWidthLineChart(Context context, AttributeSet attrs, int defStyle) { 20 | super(context, attrs, defStyle); 21 | } 22 | 23 | @Override 24 | protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) { 25 | final int width = MeasureSpec.getSize(widthMeasureSpec); 26 | final int height = (int) (width / SCALE); 27 | 28 | super.onMeasure( 29 | MeasureSpec.makeMeasureSpec(width, MeasureSpec.EXACTLY), 30 | MeasureSpec.makeMeasureSpec(height, MeasureSpec.EXACTLY)); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/element/ChannelCreator.java: -------------------------------------------------------------------------------- 1 | package 
io.whz.synapse.element; 2 | 3 | import android.app.NotificationChannel; 4 | import android.app.NotificationManager; 5 | import android.content.Context; 6 | import android.os.Build; 7 | import android.support.annotation.NonNull; 8 | import android.support.annotation.RequiresApi; 9 | import android.support.v4.app.NotificationCompat; 10 | 11 | import io.whz.synapse.util.Precondition; 12 | 13 | public class ChannelCreator { 14 | public static final String CHANNEL_ID = "io.whz.androidneuralnetwork.notification"; 15 | private static final String CHANNEL_NAME = "Synapse Channel"; 16 | // Registers the CHANNEL_ID notification channel (default importance, public on the lock screen); safe to call on every start. 17 | @RequiresApi(api = Build.VERSION_CODES.O) 18 | public static void createChannel(@NonNull Context context) { 19 | Precondition.checkNotNull(context); 20 | 21 | final NotificationChannel channel = new NotificationChannel(CHANNEL_ID, CHANNEL_NAME, 22 | NotificationManager.IMPORTANCE_DEFAULT); 23 | channel.setLockscreenVisibility(NotificationCompat.VISIBILITY_PUBLIC); 24 | 25 | final NotificationManager manager = getManager(context); 26 | // No delete first: createNotificationChannel() on an existing id only updates it, and a deleted channel re-created with the same id is restored with its old settings, so deleting gains nothing. 27 | manager.createNotificationChannel(channel); 28 | } 29 | 30 | private static NotificationManager getManager(@NonNull Context context) { 31 | return (NotificationManager) context.getSystemService(Context.NOTIFICATION_SERVICE); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/element/DigitView.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.element; 2 | 3 | import android.content.Context; 4 | import android.graphics.Bitmap; 5 | import android.graphics.Canvas; 6 | import android.graphics.Color; 7 | import android.graphics.DashPathEffect; 8 | import android.graphics.Matrix; 9 | import android.graphics.Paint; 10 | import android.graphics.Path; 11 | import android.graphics.Rect; 12 | import android.support.annotation.NonNull; 13 | import 
android.support.annotation.Nullable; 14 | import android.util.AttributeSet; 15 | import android.view.MotionEvent; 16 | import android.view.View; 17 | 18 | import io.whz.synapse.R; 19 | import io.whz.synapse.neural.MNISTUtil; 20 | 21 | public class DigitView extends View { 22 | private static final int OVERLAY_TIME = 1; 23 | private static final int DIGIT_SIDE = 28; 24 | 25 | private final Bitmap mDigitBitmap; 26 | private final Canvas mDigitCanvas; 27 | private final Matrix mNarrowMatrix; 28 | private final Matrix mEnlargeMatrix; 29 | private final Paint mBgPaint; 30 | private final Paint mFgPaint; 31 | private final Paint mRicePaint; 32 | 33 | private final Pair mOldPair; 34 | private final Pair mNewPair; 35 | private final Path mRicePath; 36 | 37 | private boolean mIsTrash; 38 | 39 | { 40 | mDigitBitmap = Bitmap.createBitmap(DIGIT_SIDE, DIGIT_SIDE, Bitmap.Config.ARGB_8888); 41 | mDigitCanvas = new Canvas(mDigitBitmap); 42 | mNarrowMatrix = new Matrix(); 43 | mEnlargeMatrix = new Matrix(); 44 | 45 | mBgPaint = new Paint(); 46 | mBgPaint.setColor(Color.WHITE); 47 | 48 | mFgPaint = new Paint(); 49 | mFgPaint.setColor(Color.BLACK); 50 | mFgPaint.setAntiAlias(true); 51 | mFgPaint.setStrokeWidth(2.5F); 52 | 53 | mRicePaint = new Paint(); 54 | mRicePaint.setColor(Color.BLACK); 55 | mRicePaint.setAlpha(85); 56 | mRicePaint.setPathEffect(new DashPathEffect(new float[]{15, 15, 15, 15}, 1)); 57 | mRicePaint.setAntiAlias(true); 58 | mRicePaint.setStyle(Paint.Style.STROKE); 59 | mRicePaint.setStrokeWidth(3); 60 | 61 | mOldPair = new Pair(); 62 | mNewPair = new Pair(); 63 | 64 | mRicePath = new Path(); 65 | 66 | mIsTrash = false; 67 | } 68 | 69 | public DigitView(Context context) { 70 | super(context); 71 | } 72 | 73 | public DigitView(Context context, @Nullable AttributeSet attrs) { 74 | super(context, attrs); 75 | } 76 | 77 | public DigitView(Context context, @Nullable AttributeSet attrs, int defStyleAttr) { 78 | super(context, attrs, defStyleAttr); 79 | } 80 | 81 | public 
DigitView(Context context, @Nullable AttributeSet attrs, int defStyleAttr, int defStyleRes) { 82 | super(context, attrs, defStyleAttr, defStyleRes); 83 | } 84 | 85 | @Override 86 | protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) { 87 | final int defaultSide = getResources().getDimensionPixelOffset(R.dimen.digit_view_default_side); 88 | 89 | final int tmpW = measureHelper(widthMeasureSpec, defaultSide); 90 | final int tmpH = measureHelper(heightMeasureSpec, defaultSide); 91 | final int finalSide = Math.min(tmpW, tmpH); 92 | 93 | prepare(finalSide); 94 | setMeasuredDimension(finalSide, finalSide); 95 | } 96 | 97 | private void prepare(float finalSize) { 98 | final float scale = finalSize / DIGIT_SIDE; 99 | 100 | mEnlargeMatrix.reset(); 101 | mNarrowMatrix.reset(); 102 | 103 | mEnlargeMatrix.setScale(scale, scale); 104 | mEnlargeMatrix.invert(mNarrowMatrix); 105 | 106 | mRicePath.reset(); 107 | mRicePath.moveTo(0F, 0F); 108 | mRicePath.lineTo(finalSize, finalSize); 109 | 110 | mRicePath.moveTo(0F, finalSize); 111 | mRicePath.lineTo(finalSize, 0F); 112 | 113 | final float halfSize = finalSize / 2; 114 | 115 | mRicePath.moveTo(0F, halfSize); 116 | mRicePath.lineTo(finalSize, halfSize); 117 | 118 | mRicePath.moveTo(halfSize, 0F); 119 | mRicePath.lineTo(halfSize, finalSize); 120 | 121 | mRicePath.close(); 122 | } 123 | 124 | private int measureHelper(int spec, int defaultSide) { 125 | final int size = MeasureSpec.getSize(spec); 126 | return Math.max(defaultSide, size); 127 | } 128 | 129 | @Override 130 | protected void onFinishInflate() { 131 | super.onFinishInflate(); 132 | reset(); 133 | } 134 | 135 | public void reset() { 136 | final Rect rect = new Rect(0, 0, mDigitBitmap.getWidth(), mDigitBitmap.getHeight()); 137 | mDigitCanvas.drawRect(rect, mBgPaint); 138 | 139 | postInvalidate(); 140 | } 141 | 142 | public void reset(int[] pixels) { 143 | mDigitBitmap.setPixels(pixels, 0, mDigitBitmap.getWidth(), 0, 0, mDigitBitmap.getWidth(), 
mDigitBitmap.getHeight()); 144 | 145 | postInvalidate(); 146 | } 147 | 148 | @Override 149 | protected void onDraw(Canvas canvas) { 150 | canvas.drawBitmap(mDigitBitmap, mEnlargeMatrix, mFgPaint); 151 | canvas.drawPath(mRicePath, mRicePaint); 152 | } 153 | 154 | @Override 155 | public boolean onTouchEvent(MotionEvent event) { 156 | final int action = event.getAction() & MotionEvent.ACTION_MASK; 157 | 158 | switch (action) { 159 | case MotionEvent.ACTION_DOWN: 160 | handleDownAction(event); 161 | return true; 162 | 163 | case MotionEvent.ACTION_UP: 164 | case MotionEvent.ACTION_MOVE: 165 | handleMoveAction(event); 166 | return true; 167 | 168 | default: 169 | return false; 170 | } 171 | } 172 | 173 | private void handleMoveAction(@NonNull MotionEvent motionEvent) { 174 | mOldPair.copy(mNewPair); 175 | mNewPair.pos(motionEvent.getX(), motionEvent.getY()); 176 | convert(mNewPair); 177 | 178 | drawLine(mOldPair, mNewPair); 179 | postInvalidate(); 180 | } 181 | 182 | /** 183 | * In order to deeper the color of the digit 184 | */ 185 | private void drawLine(@NonNull Pair oldPair, @NonNull Pair newPair) { 186 | for (int i = 0; i < OVERLAY_TIME; ++i) { 187 | mDigitCanvas.drawLine(oldPair.getX(), oldPair.getY(), 188 | newPair.getX(), newPair.getY(), mFgPaint); 189 | } 190 | } 191 | 192 | /** 193 | * get darkness, 0.0 for white and 1.0 for black pixel 194 | */ 195 | public io.whz.synapse.matrix.Matrix getDarkness() { 196 | final int side = DIGIT_SIDE; 197 | final int[] pixels = new int[side * side]; 198 | 199 | mDigitBitmap.getPixels(pixels, 0, side, 0, 0, side, side); 200 | 201 | final double[] doubles = MNISTUtil.convertBitmap2Darkness(pixels); 202 | 203 | return io.whz.synapse.matrix.Matrix.array(doubles, DIGIT_SIDE * DIGIT_SIDE); 204 | } 205 | 206 | public void markTrash() { 207 | mIsTrash = true; 208 | } 209 | 210 | private void handleDownAction(@NonNull MotionEvent motionEvent) { 211 | if (mIsTrash) { 212 | reset(); 213 | mIsTrash = false; 214 | } 215 | 216 | 
mOldPair.reset(); 217 | 218 | mNewPair.pos(motionEvent.getX(), motionEvent.getY()); 219 | convert(mNewPair); 220 | } 221 | 222 | private void convert(@NonNull Pair pair) { 223 | mNarrowMatrix.mapPoints(pair.getItems()); 224 | } 225 | 226 | private static final class Pair { 227 | private final float[] items = new float[2]; 228 | 229 | void pos(float x, float y) { 230 | items[0] = x; 231 | items[1] = y; 232 | } 233 | 234 | void copy(@NonNull Pair pair) { 235 | this.items[0] = pair.getX(); 236 | this.items[1] = pair.getY(); 237 | } 238 | 239 | void reset() { 240 | pos(0F, 0F); 241 | } 242 | 243 | float[] getItems() { 244 | return items; 245 | } 246 | 247 | float getX() { 248 | return this.items[0]; 249 | } 250 | 251 | float getY() { 252 | return this.items[1]; 253 | } 254 | } 255 | } 256 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/element/Dir.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.element; 2 | 3 | import android.support.annotation.NonNull; 4 | 5 | import java.io.File; 6 | 7 | public class Dir { 8 | private static final String DOWNLOAD = "download"; 9 | private static final String DECOMPRESS = "decompress"; 10 | private static final String MNIST = "mnist"; 11 | private static final String TRAIN = "train"; 12 | private static final String TEST = "test"; 13 | 14 | public final File root; 15 | public final File download; 16 | public final File decompress; 17 | public final File mnist; 18 | public final File train; 19 | public final File test; 20 | 21 | public Dir(@NonNull File root) { 22 | this.root = root; 23 | 24 | this.download = new File(root, DOWNLOAD); 25 | this.decompress = new File(root, DECOMPRESS); 26 | this.mnist = new File(root, MNIST); 27 | this.train = new File(this.mnist, TRAIN); 28 | this.test = new File(this.mnist, TEST); 29 | } 30 | } 31 | 
/app/src/main/java/io/whz/synapse/element/FigureProvider.java:
--------------------------------------------------------------------------------
package io.whz.synapse.element;

import android.support.annotation.NonNull;
import android.support.annotation.Nullable;

import java.io.File;
import java.util.concurrent.ThreadLocalRandom;

import io.whz.synapse.neural.MNISTUtil;
import io.whz.synapse.pojo.neural.Figure;
import io.whz.synapse.util.Precondition;

/**
 * Loads the figures stored in one file (via {@code MNISTUtil.readFigures}) and hands out
 * randomly chosen ones.
 */
public class FigureProvider {
    private final File mFigureFile;
    @Nullable private Figure[] mFigures;

    public FigureProvider(@NonNull File file) {
        Precondition.checkNotNull(file);

        mFigureFile = file;
    }

    /** Reads (or re-reads) all figures from the backing file. */
    public void load() {
        mFigures = MNISTUtil.readFigures(mFigureFile);
    }

    /**
     * @return a uniformly random loaded figure, or {@code null} if nothing is loaded
     *         (load() not called yet, or it produced a null/empty result)
     */
    @Nullable
    public Figure next() {
        // Read the field once so the emptiness check and the indexing use the same array.
        final Figure[] figures = mFigures;

        if (figures == null || figures.length == 0) {
            return null;
        }

        // ThreadLocalRandom must be obtained on the thread that uses it. Caching the
        // instance from the constructor's thread (as before) is explicitly unsupported.
        return figures[ThreadLocalRandom.current().nextInt(figures.length)];
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/element/Global.java:
--------------------------------------------------------------------------------
package io.whz.synapse.element;

import android.content.SharedPreferences;
import android.net.Uri;
import android.support.annotation.NonNull;

import org.greenrobot.eventbus.EventBus;

import java.io.File;

import io.whz.synapse.pojo.dao.DaoSession;
import io.whz.synapse.util.Precondition;


/**
 * Process-wide registry of shared objects: the MNIST download source plus write-once slots
 * for the event bus, DAO session, directory layout and shared preferences. Each slot is
 * set once (see {@link Singleton#setAndLock}) and read through its getter.
 */
public class Global {
    private final Uri mBaseDownloadUri;
    private final String[] mDataSet;

    private final Singleton<EventBus> mBus;
    private final Singleton<DaoSession> mSession;
    private final Singleton<Dir> mDirs;
    private final Singleton<SharedPreferences> mPreference;

    private Global() {
        mBaseDownloadUri = Uri.parse("http://yann.lecun.com/exdb/mnist");
        mDataSet = new String[]{
                "train-images-idx3-ubyte.gz",
                "train-labels-idx1-ubyte.gz",
                "t10k-images-idx3-ubyte.gz",
                "t10k-labels-idx1-ubyte.gz",
        };

        mBus = new Singleton<>();
        mSession = new Singleton<>();
        mDirs = new Singleton<>();
        mPreference = new Singleton<>();
    }

    /** @return the MNIST archive file names, relative to {@link #getBaseDownloadUri()}. */
    public String[] getDataSet() {
        return mDataSet;
    }

    public Uri getBaseDownloadUri() {
        return mBaseDownloadUri;
    }

    public void setPreference(@NonNull SharedPreferences preference) {
        mPreference.setAndLock(Precondition.checkNotNull(preference));
    }

    public SharedPreferences getPreference() {
        return mPreference.get();
    }

    public void setBus(@NonNull EventBus bus) {
        mBus.setAndLock(Precondition.checkNotNull(bus));
    }

    public EventBus getBus() {
        return mBus.get();
    }

    public void setSession(@NonNull DaoSession session) {
        // Null-checked like the other setters.
        mSession.setAndLock(Precondition.checkNotNull(session));
    }

    public DaoSession getSession() {
        return mSession.get();
    }

    /** Builds the {@link Dir} layout under {@code root} and locks it in. */
    public void setRootDir(@NonNull File root) {
        // new File((File) null, child) silently yields a relative path, so reject null here.
        final Dir dirs = new Dir(Precondition.checkNotNull(root));
        mDirs.setAndLock(dirs);
    }

    public Dir getDirs() {
        return mDirs.get();
    }

    public boolean isDirSet() {
        return mDirs.isSet();
    }

    public static Global getInstance() {
        return Holder.sInstance;
    }

    /** Lazy holder: the instance is created on first access to {@link #getInstance()}. */
    private interface Holder {
        Global sInstance = new Global();
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/element/IThreadExecutor.java:
--------------------------------------------------------------------------------
package io.whz.synapse.element;

/** Minimal "run this somewhere" abstraction implemented by {@link Scheduler}. */
interface IThreadExecutor {
    void execute(Runnable runnable);
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/element/Scheduler.java:
--------------------------------------------------------------------------------
package io.whz.synapse.element;

import android.os.Handler;
import android.os.Looper;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Named execution targets:
 * {@link #Main} posts to the main Looper; {@link #Secondary} runs on a shared fixed-size
 * thread pool.
 */
public enum Scheduler implements IThreadExecutor {
    Main(new MainThread()), Secondary(new SecondaryThread());

    private final IThreadExecutor mExecutor;

    Scheduler(IThreadExecutor executor) {
        mExecutor = executor;
    }

    @Override
    public void execute(Runnable runnable) {
        mExecutor.execute(runnable);
    }

    /** Fixed pool of (available processors + 1) threads. */
    private static class SecondaryThread implements IThreadExecutor {
        private static final int sCoreThreadNum = Runtime.getRuntime().availableProcessors() + 1;
        private final ExecutorService mExecutorService = Executors.newFixedThreadPool(sCoreThreadNum);

        @Override
        public void execute(Runnable runnable) {
            mExecutorService.execute(runnable);
        }
    }

    /** Posts runnables to the main thread's Looper. */
    private static class MainThread implements IThreadExecutor {
        private final Handler handler = new Handler(Looper.getMainLooper());

        @Override
        public void execute(Runnable runnable) {
            handler.post(runnable);
        }
    }

}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/element/Singleton.java:
--------------------------------------------------------------------------------
package io.whz.synapse.element;

import android.support.annotation.NonNull;

import java.util.concurrent.atomic.AtomicReference;

import io.whz.synapse.component.App;
import io.whz.synapse.util.Precondition;

/**
 * Write-once holder. The first {@link #setAndLock} wins; later attempts are reported but
 * ignored.
 *
 * @param <T> type of the held instance
 */
public class Singleton<T> {
    private static final String TAG = App.TAG + "-Singleton";
    private final AtomicReference<T> mReference;

    public Singleton() {
        mReference = new AtomicReference<>();
    }

    /**
     * Stores {@code object} if nothing has been stored yet.
     *
     * @return this holder, for chaining
     */
    public Singleton<T> setAndLock(@NonNull T object) {
        if (!mReference.compareAndSet(null, object)) {
            // Deliberately non-fatal: the attempt is logged and the first value is kept.
            new UnsupportedOperationException("Already locked, can't set new instance again")
                    .printStackTrace();
        }

        return this;
    }

    public boolean isSet() {
        return mReference.get() != null;
    }

    /** @return the stored instance; fails the Precondition check if nothing is stored yet */
    public T get() {
        return Precondition.checkNotNull(mReference.get(), "You should bind before get instance");
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/element/VerticalGap.java:
--------------------------------------------------------------------------------
package io.whz.synapse.element;

import android.graphics.Rect;
import android.support.v7.widget.RecyclerView;
import android.view.View;

/** RecyclerView decoration that adds {@code space} pixels above every item but the first. */
public class VerticalGap extends RecyclerView.ItemDecoration {
    private final int mSpace;

    public VerticalGap(int space) {
        mSpace = space;
    }

    @Override
    public void getItemOffsets(Rect outRect, View view, RecyclerView parent, RecyclerView.State state) {
        if (parent.getChildAdapterPosition(view) != 0) {
            outRect.top = mSpace;
        }
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/matrix/Matrix.java:
--------------------------------------------------------------------------------
package io.whz.synapse.matrix;

import android.support.annotation.NonNull;

import java.io.Serializable;
import java.util.Random;

import static io.whz.synapse.matrix.MatrixChecker.checkDimensions;
import static io.whz.synapse.matrix.MatrixChecker.checkExpression;
import static io.whz.synapse.matrix.MatrixChecker.checkInnerDimensions;
import static io.whz.synapse.matrix.MatrixChecker.checkNotNull;
import static io.whz.synapse.matrix.MatrixChecker.verifyArrays;

/**
 * Dense row-major matrix of doubles; element (i, j) is stored at {@code i * col + j}.
 *
 * <p>Naming convention:
 * <ul>
 *   <li>Plain operations ({@code plus}, {@code minus}, {@code times}) return a new matrix
 *       and leave their operands untouched.</li>
 *   <li>The {@code ...To} variants write into the first operand and return it.</li>
 *   <li>{@code times(Matrix)} is the element-wise product; {@code dot} is the matrix
 *       product.</li>
 * </ul>
 *
 * <p>NOTE(review): there is no explicit serialVersionUID. If instances are persisted with
 * Java serialization, structural changes to this class can break reading existing data.
 */
public class Matrix implements Serializable {
    private final int mRow;
    private final int mCol;
    private final double[] mArray;

    /** Wraps {@code array} without copying; callers pass an array they own. */
    private Matrix(double[] array, int row, int col) {
        mRow = row;
        mCol = col;
        mArray = array;
    }

    /** @return a new {@code {rows, cols}} array */
    public int[] shape() {
        return new int[]{mRow, mCol};
    }

    /**
     * @return the backing row-major array itself, not a copy. Writes through it modify this
     *         matrix; the in-place operations rely on this.
     */
    public double[] getArray() {
        return mArray;
    }

    /** @return a freshly allocated {@code [row][col]} copy of the elements */
    public double[][] getArrays() {
        final double[][] arrays = new double[mRow][mCol];

        for (int i = 0; i < mRow; ++i) {
            for (int j = 0; j < mCol; ++j) {
                arrays[i][j] = mArray[i * mCol + j];
            }
        }

        return arrays;
    }

    /** @return a deep copy with the same shape */
    public Matrix copy() {
        // Construct directly. Going through array(double[], int) cloned the data a second
        // time and rejected matrices with zero rows (e.g. zeros(0, n)).
        return new Matrix(mArray.clone(), mRow, mCol);
    }

    public int getRow() {
        return mRow;
    }

    public int getCol() {
        return mCol;
    }

    /** Bounds-checked element read. */
    public double get(int i, int j) {
        checkExpression(i >= 0 && i < mRow && j >= 0 && j < mCol,
                "Row or col is out of bounds");

        return mArray[i * mCol + j];
    }

    /** Bounds-checked element write. */
    public void set(int i, int j, double value) {
        checkExpression(i >= 0 && i < mRow && j >= 0 && j < mCol,
                "Row or col is out of bounds");

        mArray[i * mCol + j] = value;
    }

    /** Element-wise product into a new matrix. */
    public Matrix times(@NonNull Matrix matrix) {
        checkNotNull(matrix);

        return Matrix.times(this, matrix);
    }

    /** Scalar product into a new matrix. */
    public Matrix times(double num) {
        return Matrix.times(this, num);
    }

    /** Scalar product, in place. */
    public Matrix timesTo(double num) {
        return Matrix.timesTo(this, num);
    }

    /** Element-wise product, in place. */
    public Matrix timesTo(@NonNull Matrix matrix) {
        checkNotNull(matrix);

        return Matrix.timesTo(this, matrix);
    }

    public Matrix plus(@NonNull Matrix matrix) {
        checkNotNull(matrix);

        return Matrix.plus(this, matrix);
    }

    public Matrix plusTo(@NonNull Matrix matrix) {
        checkNotNull(matrix);

        return Matrix.plusTo(this, matrix);
    }

    public Matrix minus(@NonNull Matrix matrix) {
        checkNotNull(matrix);

        return Matrix.minus(this, matrix);
    }

    public Matrix minusTo(@NonNull Matrix matrix) {
        checkNotNull(matrix);

        return Matrix.minusTo(this, matrix);
    }

    /** Matrix product {@code this x matrix} into a new matrix. */
    public Matrix dot(@NonNull Matrix matrix) {
        checkNotNull(matrix);

        return Matrix.dot(this, matrix);
    }

    public Matrix transpose() {
        return Matrix.transpose(this);
    }

    /** Builds a matrix from rectangular 2-D data; the data is copied. */
    public static Matrix array(@NonNull double[][] arrays) {
        verifyArrays(arrays);

        final int row = arrays.length;
        final int col = arrays[0].length;
        final double[] array = new double[row * col];

        for (int i = 0; i < row; ++i) {
            System.arraycopy(arrays[i], 0, array, i * col, col);
        }

        return Matrix.array(array, row);
    }

    /**
     * Builds a {@code row x (array.length / row)} matrix from row-major data; the data is
     * copied. {@code array.length} must be a multiple of {@code row}.
     */
    public static Matrix array(@NonNull double[] array, int row) {
        checkNotNull(array);
        checkExpression(row > 0,
                "Row should be positive");

        final int col = array.length / row;
        checkExpression(row * col == array.length,
                "Array length must be a multiple of row");

        return new Matrix(array.clone(), row, col);
    }

    /** @return a {@code row x col} matrix of independent standard-normal samples */
    public static Matrix randn(int row, int col) {
        checkDimensions(row, col);

        final Matrix matrix = Matrix.zeros(row, col);
        final double[] array = matrix.getArray();
        // Random() picks a fresh, per-call distinct seed. Seeding with currentTimeMillis()
        // gave matrices created within the same millisecond (e.g. several layers built
        // back to back) the same Gaussian sequence.
        final Random random = new Random();

        for (int i = 0, len = array.length; i < len; ++i) {
            array[i] = random.nextGaussian();
        }

        return matrix;
    }

    public static Matrix plus(@NonNull Matrix a, @NonNull Matrix b) {
        checkNotNull(a);
        checkNotNull(b);
        checkDimensions(a, b);

        return plusTo(a.copy(), b);
    }

    public static Matrix plusTo(@NonNull Matrix a, @NonNull Matrix b) {
        checkNotNull(a);
        checkNotNull(b);
        checkDimensions(a, b);

        final double[] aArr = a.getArray();
        final double[] bArr = b.getArray();

        for (int i = 0, len = aArr.length; i < len; ++i) {
            aArr[i] += bArr[i];
        }

        return a;
    }

    public static Matrix minus(@NonNull Matrix a, @NonNull Matrix b) {
        checkNotNull(a);
        checkNotNull(b);
        checkDimensions(a, b);

        return minusTo(a.copy(), b);
    }

    public static Matrix minusTo(@NonNull Matrix a, @NonNull Matrix b) {
        checkNotNull(a);
        checkNotNull(b);
        checkDimensions(a, b);

        final double[] aArr = a.getArray();
        final double[] bArr = b.getArray();

        for (int i = 0, len = aArr.length; i < len; ++i) {
            aArr[i] -= bArr[i];
        }

        return a;
    }

    public static Matrix times(@NonNull Matrix a, @NonNull Matrix b) {
        checkNotNull(a);
        checkNotNull(b);
        checkDimensions(a, b);

        return timesTo(a.copy(), b);
    }

    public static Matrix times(@NonNull Matrix matrix, double num) {
        checkNotNull(matrix);

        return timesTo(matrix.copy(), num);
    }

    public static Matrix timesTo(@NonNull Matrix matrix, double num) {
        checkNotNull(matrix);

        final double[] array = matrix.getArray();

        for (int i = 0, len = array.length; i < len; ++i) {
            array[i] *= num;
        }

        return matrix;
    }

    public static Matrix timesTo(@NonNull Matrix a, @NonNull Matrix b) {
        checkNotNull(a);
        checkNotNull(b);
        checkDimensions(a, b);

        final double[] aArr = a.getArray();
        final double[] bArr = b.getArray();

        for (int i = 0, len = aArr.length; i < len; ++i) {
            aArr[i] *= bArr[i];
        }

        return a;
    }

    /** Matrix product; requires {@code a.getCol() == b.getRow()}. */
    public static Matrix dot(@NonNull Matrix a, @NonNull Matrix b) {
        checkNotNull(a);
        checkNotNull(b);
        checkInnerDimensions(a, b);

        final Matrix c = zeros(a.getRow(), b.getCol());
        final double[] cArr = c.getArray();

        final double[] aArr = a.getArray();
        final double[] bArr = b.getArray();

        final int aRow = a.getRow();
        final int bCol = b.getCol();
        final int aCol = a.getCol();
        double tmpNum;

        for (int i = 0; i < aRow; ++i) {
            for (int j = 0; j < bCol; ++j) {
                tmpNum = 0;

                for (int k = 0; k < aCol; ++k) {
                    tmpNum += aArr[aCol * i + k] * bArr[bCol * k + j];
                }

                cArr[i * bCol + j] = tmpNum;
            }
        }

        return c;
    }

    public static Matrix transpose(@NonNull Matrix a) {
        checkNotNull(a);

        final int aRow = a.getRow();
        final int aCol = a.getCol();
        final Matrix b = zeros(aCol, aRow);

        final double[] aArr = a.getArray();
        final double[] bArr = b.getArray();

        for (int i = 0; i < aRow; ++i) {
            for (int j = 0; j < aCol; ++j) {
                bArr[j * aRow + i] = aArr[i * aCol + j];
            }
        }

        return b;
    }

    /** @return a zero matrix with the same shape as {@code matrix} */
    public static Matrix zeroLike(Matrix matrix) {
        checkNotNull(matrix);

        return zeros(matrix.shape());
    }

    /** @param shape exactly two values: rows and columns */
    public static Matrix zeros(@NonNull int... shape) {
        checkNotNull(shape);
        checkExpression(shape.length == 2, "Shape is incorrect");

        return new Matrix(new double[shape[0] * shape[1]],
                shape[0], shape[1]);
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/matrix/MatrixChecker.java:
--------------------------------------------------------------------------------
package io.whz.synapse.matrix;

import android.support.annotation.NonNull;

import io.whz.synapse.util.Precondition;

/** Argument checks shared by {@link Matrix}; failures go through {@code Precondition}. */
class MatrixChecker {
    /** Requires a non-null, non-empty, rectangular 2-D array. */
    static void verifyArrays(@NonNull double[][] arrays) {
        checkNotNull(arrays);

        final int m = arrays.length;
        checkExpression(m > 0,
                "Row should be positive");

        final int n = arrays[0].length;
        for (int i = 1; i < m; ++i) {
            checkExpression(arrays[i].length == n,
                    "All rows must have the same length");
        }
    }

    static void checkNotNull(Object o) {
        Precondition.checkNotNull(o);
    }

    static void checkExpression(boolean expression, String message) {
        Precondition.checkArgument(expression, message);
    }

    static void checkExpression(boolean expression) {
        checkExpression(expression,
                "Expression Fail");
    }

    static void checkDimensions(int row, int col) {
        checkExpression(row > 0 && col > 0,
                "Row and column should be positive");
    }

    /** Requires both matrices to have identical shapes (element-wise operations). */
    static void checkDimensions(@NonNull Matrix a, @NonNull Matrix b) {
        final int[] aDim = a.shape();
        final int[] bDim = b.shape();

        checkExpression(aDim[0] == bDim[0] && aDim[1] == bDim[1],
                "Matrix dimensions must agree");
    }

    /** Requires {@code a.cols == b.rows} (matrix product). */
    static void checkInnerDimensions(@NonNull Matrix a, @NonNull Matrix b) {
        checkExpression(a.getCol() == b.getRow(),
                "Matrix inner dimensions must agree");
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/neural/ActivateFunction.java:
--------------------------------------------------------------------------------
package io.whz.synapse.neural;

import android.support.annotation.CheckResult;
import android.support.annotation.NonNull;

import io.whz.synapse.matrix.Matrix;
import io.whz.synapse.util.Precondition;

/** Element-wise logistic activation used by the network. */
class ActivateFunction {
    /**
     * Applies the logistic sigmoid {@code 1 / (1 + e^-x)} to every element.
     *
     * @param matrix input; left unmodified
     * @return a new matrix of the same shape
     */
    @CheckResult
    static Matrix sigmoid(@NonNull Matrix matrix) {
        Precondition.checkNotNull(matrix);

        final Matrix copy = matrix.copy();
        final double[] doubles = copy.getArray();

        for (int i = 0, len = doubles.length; i < len; ++i) {
            final double cur = -doubles[i];
            doubles[i] = 1D / (1D + Math.exp(cur));
        }

        return copy;
    }

    /**
     * Sigmoid derivative expressed through the sigmoid's output: {@code a * (1 - a)}.
     *
     * @param activation values that are already sigmoid outputs (not raw inputs); left
     *                   unmodified
     * @return a new matrix of the same shape
     */
    @CheckResult
    static Matrix sigmoidPrime(@NonNull Matrix activation) {
        Precondition.checkNotNull(activation);

        final Matrix copy = activation.copy();
        final double[] doubles = copy.getArray();

        for (int i = 0, len = doubles.length; i < len; ++i) {
            final double cur = doubles[i];
            doubles[i] = cur * (1 - cur);
        }

        return copy;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/neural/DataSet.java:
--------------------------------------------------------------------------------
package io.whz.synapse.neural;

import android.support.annotation.NonNull;
import android.support.annotation.Nullable;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import io.whz.synapse.component.App;
import io.whz.synapse.matrix.Matrix;
import io.whz.synapse.pojo.neural.Batch;
import io.whz.synapse.pojo.neural.Digit;
import io.whz.synapse.util.Precondition;

/**
 * Streams labelled digits from a list of batch files as mini-batches.
 *
 * <p>Files are read lazily, one at a time, via {@code MNISTUtil.readBatches}; the digits of
 * each file are shuffled when it is loaded. {@link #shuffle()} reorders the files and
 * restarts the pass.
 *
 * <p>Each input is {@code Matrix.array(digit.colors, 784)} (presumably a 784x1 column,
 * assuming 784 colour values per digit). Each target is a 10x1 one-hot column.
 */
public class DataSet {
    private static final String TAG = App.TAG + "-DataSet";
    private static final int DEFAULT_MIN_BATCH = 20;
    private static final int DIGIT_COUNT = 10;
    private static final int PIXEL_COUNT = 784;

    private final int mMiniBatch;
    private final List<File> mBatchFiles = new ArrayList<>();
    /** Files not yet consumed in the current pass. */
    private final List<File> mCurFiles = new ArrayList<>();

    /** Digits of the file currently being consumed. */
    private Digit[] mCurDigits;
    /** Digits of {@link #mCurDigits} not yet handed out. */
    private int mRemain;
    /** Index of the next digit to hand out from {@link #mCurDigits}. */
    private int mIndex;

    /** Uses the default mini-batch size of 20. */
    public DataSet(@NonNull File... batchFiles) {
        this(DEFAULT_MIN_BATCH, batchFiles);
    }

    /**
     * @param miniBatch  maximum number of digits per {@link #nextBatch()}; must be positive
     * @param batchFiles files readable by {@code MNISTUtil.readBatches}
     */
    public DataSet(int miniBatch, @NonNull File... batchFiles) {
        // The check used to read `miniBatch <= 0`, rejecting every valid size and accepting
        // only invalid ones.
        Precondition.checkArgument(miniBatch > 0, "MiniBatch must be positive");
        Precondition.checkNotNull(batchFiles);

        mMiniBatch = miniBatch;
        mBatchFiles.addAll(Arrays.asList(batchFiles));

        reset();
    }

    /** Restarts the pass over {@link #mBatchFiles} in their current order. */
    private void reset() {
        mCurFiles.clear();
        mCurFiles.addAll(mBatchFiles);
        mCurDigits = null;
        mRemain = 0;
        mIndex = 0;
    }

    /**
     * Collects up to {@code miniBatch} digits, crossing file boundaries as needed.
     *
     * @return the next batch (the last one may be smaller), or {@code null} once every file
     *         is exhausted
     */
    @Nullable
    public Batch nextBatch() {
        final List<Digit> batch = new ArrayList<>(mMiniBatch);
        int need = mMiniBatch;

        while (need > 0) {
            // Advance to the next file that yields at least one digit.
            while (mRemain <= 0 && !mCurFiles.isEmpty()) {
                final File file = mCurFiles.remove(0);
                mCurDigits = nextDigits(file);

                if (mCurDigits != null && mCurDigits.length != 0) {
                    mRemain = mCurDigits.length;
                    mIndex = 0;
                }
            }

            if (mRemain <= 0) {
                // No file left to read; return whatever was collected.
                break;
            }

            final int size = Math.min(need, mRemain);
            batch.addAll(Arrays.asList(mCurDigits).subList(mIndex, mIndex + size));
            mIndex += size;
            need -= size;
            mRemain -= size;
        }

        return convert2Matrix(batch);
    }

    /** @return the batch as matrices, or {@code null} when {@code digits} is empty */
    @Nullable
    private Batch convert2Matrix(@NonNull List<Digit> digits) {
        if (digits.isEmpty()) {
            return null;
        }

        final int len = digits.size();

        final Matrix[] inputs = new Matrix[len];
        final Matrix[] targets = new Matrix[len];
        Digit digit;

        for (int i = 0; i < len; ++i) {
            digit = digits.get(i);

            inputs[i] = Matrix.array(digit.colors, PIXEL_COUNT);
            targets[i] = oneHot(digit.label);
        }

        return new Batch(inputs, targets);
    }

    /** @return a DIGIT_COUNT x 1 column with 1.0 at {@code actual} and 0.0 elsewhere */
    private Matrix oneHot(int actual) {
        final double[] doubles = new double[DIGIT_COUNT];
        doubles[actual] = 1D;

        return Matrix.array(doubles, DIGIT_COUNT);
    }

    /** Reads one batch file; returns its digits shuffled, or {@code null} if reading failed. */
    @Nullable
    private Digit[] nextDigits(@NonNull File file) {
        final Digit[] res = MNISTUtil.readBatches(file);

        return res == null ? null : normalize(res);
    }

    /** Shuffles {@code digits} in place and returns the same array. */
    private Digit[] normalize(@NonNull Digit[] digits) {
        // Arrays.asList is a write-through view, so this shuffles the array itself.
        Collections.shuffle(Arrays.asList(digits));

        return digits;
    }

    /** Randomises the file order and restarts from the first file. */
    public void shuffle() {
        Collections.shuffle(mBatchFiles);
        reset();
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/neural/NeuralNetwork.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.neural; 2 | 3 | import android.support.annotation.NonNull; 4 | import android.support.v4.util.Pair; 5 | 6 | import org.greenrobot.greendao.annotation.NotNull; 7 | 8 | import java.util.Arrays; 9 | 10 | import io.whz.synapse.component.App; 11 | import io.whz.synapse.matrix.Matrix; 12 | import io.whz.synapse.pojo.neural.Batch; 13 | import io.whz.synapse.util.MatrixUtil; 14 | import io.whz.synapse.util.Precondition; 15 | 16 | public class NeuralNetwork { 17 | private static final String TAG = App.TAG + "-NeuralNetwork"; 18 | 19 | public static final int INPUT_LAYER_NUMBER = 784; 20 | public static final int OUTPUT_LAYER_NUMBER = 10; 21 | 22 | private final Matrix[] mWeights; 23 | private final
Matrix[] mBiases; 24 | 25 | public NeuralNetwork(@NonNull int... hiddenLayerSizes) { 26 | Precondition.checkNotNull(hiddenLayerSizes, "Hidden layers must not be null!"); 27 | 28 | final int hiddenLen; 29 | Precondition.checkArgument((hiddenLen = hiddenLayerSizes.length) != 0, "Size of Hidden layers must not be zero!"); 30 | 31 | final int[] totalSizes = new int[hiddenLen + 2]; 32 | System.arraycopy(hiddenLayerSizes, 0, totalSizes, 1, hiddenLen); 33 | totalSizes[0] = INPUT_LAYER_NUMBER; 34 | totalSizes[totalSizes.length - 1] = OUTPUT_LAYER_NUMBER; 35 | 36 | mBiases = newBiasesMatrix(totalSizes); 37 | mWeights = newWeightsMatrices(totalSizes); 38 | } 39 | 40 | public NeuralNetwork(@NonNull Matrix[] weights, @NonNull Matrix[] biases) { 41 | Precondition.checkNotNull(weights); 42 | Precondition.checkNotNull(biases); 43 | 44 | mWeights = weights; 45 | mBiases = biases; 46 | } 47 | 48 | private Matrix[] newWeightsMatrices(int[] totalSize) { 49 | final int len = totalSize.length - 1; 50 | 51 | final int[] rows = new int[len]; 52 | System.arraycopy(totalSize, 1, rows, 0, len); 53 | 54 | final int[] cols = new int[len]; 55 | System.arraycopy(totalSize, 0, cols, 0, len); 56 | 57 | return MatrixUtil.randns(rows, cols); 58 | } 59 | 60 | private Matrix[] newBiasesMatrix(int[] totalSize) { 61 | final int len = totalSize.length - 1; 62 | 63 | final int[] rows = new int[len]; 64 | System.arraycopy(totalSize, 1, rows, 0, rows.length); 65 | 66 | final int[] cols = new int[len]; 67 | Arrays.fill(cols, 1); 68 | 69 | return MatrixUtil.randns(rows, cols); 70 | } 71 | 72 | public Matrix[] getBiases() { 73 | return mBiases; 74 | } 75 | 76 | public Matrix[] getWeights() { 77 | return mWeights; 78 | } 79 | 80 | public void train(int epochs, double learningRate, 81 | @NonNull DataSet trainDataSet, @NotNull DataSet validateDataSet, 82 | @NonNull DataSet testDataSet, @NonNull TrainCallback callback) { 83 | Precondition.checkArgument(epochs > 0, "Epochs must greater than 0"); 84 | 
Precondition.checkArgument(learningRate > 0D, "Learning rate must greater than 0"); 85 | Precondition.checkNotNull(trainDataSet); 86 | 87 | new TrainRunnable(epochs, learningRate, trainDataSet, 88 | validateDataSet, testDataSet, mBiases, mWeights, callback) 89 | .run(); 90 | } 91 | 92 | private static class TrainRunnable implements Runnable { 93 | private final int mEpochs; 94 | private final double mLearningRate; 95 | private final DataSet mTraining; 96 | private final DataSet mValidation; 97 | private final DataSet mTest; 98 | private final Matrix[] mBiases; 99 | private final Matrix[] mWeights; 100 | private TrainCallback mCallback; 101 | 102 | private TrainRunnable(int epochs, double learningRate, DataSet training, 103 | DataSet validation, DataSet test, Matrix[] biases, Matrix[] weights, 104 | TrainCallback callback) { 105 | mEpochs = epochs; 106 | mLearningRate = learningRate; 107 | mTraining = training; 108 | mValidation = validation; 109 | mTest = test; 110 | mBiases = biases; 111 | mWeights = weights; 112 | mCallback = callback; 113 | } 114 | 115 | @Override 116 | public void run() { 117 | mCallback.onStart(); 118 | 119 | for (int i = 1; i <= mEpochs; ++i) { 120 | mTraining.shuffle(); 121 | sgd(); 122 | 123 | final double rate = evaluate(mValidation); 124 | 125 | if (!mCallback.onUpdate(i, rate)) { 126 | mCallback.onEvaluate(); 127 | return; 128 | } 129 | } 130 | 131 | mCallback.onEvaluate(); 132 | 133 | final double rate = evaluate(mTest); 134 | mCallback.onComplete(rate); 135 | } 136 | 137 | private void sgd() { 138 | Batch batch; 139 | 140 | while ((batch = mTraining.nextBatch()) != null) { 141 | updateMiniBatch(batch.inputs, batch.targets); 142 | } 143 | } 144 | 145 | private void updateMiniBatch(Matrix[] inputs, Matrix[] targets) { 146 | final Matrix[] batchWeights = MatrixUtil.zerosLike(mWeights); 147 | final Matrix[] batchBiases = MatrixUtil.zerosLike(mBiases); 148 | final int len = inputs.length; 149 | 150 | for (int i = 0; i < len; ++i) { 151 | 
feed(inputs[i], targets[i], batchWeights, batchBiases); 152 | } 153 | 154 | update(batchWeights, batchBiases, len); 155 | } 156 | 157 | private void update(Matrix[] batchWeights, Matrix[] batchBiases, int count) { 158 | final int len = mWeights.length; 159 | final double tmp = mLearningRate / count; 160 | 161 | for (int i = 0; i < len; ++i) { 162 | mWeights[i].minusTo(batchWeights[i].times(tmp)); 163 | mBiases[i].minusTo(batchBiases[i].times(tmp)); 164 | } 165 | } 166 | 167 | private void feed(Matrix input, Matrix target, Matrix[] batchWeights, Matrix[] batchBiases) { 168 | final Matrix[] activations = forwardPropagation(input); 169 | final int aLen = activations.length; 170 | final int bLen = batchWeights.length; 171 | 172 | final Matrix error = activations[aLen - 1].minus(target); 173 | Matrix delta = error.times(ActivateFunction.sigmoidPrime(activations[aLen - 1])); 174 | 175 | batchBiases[bLen - 1].plusTo(delta); 176 | batchWeights[bLen - 1].plusTo(delta.dot(activations[aLen - 2].transpose())); 177 | 178 | for (int i = 2; i < aLen; ++i) { 179 | final Matrix prime = ActivateFunction.sigmoidPrime(activations[aLen - i]); 180 | delta = mWeights[bLen - i + 1].transpose() 181 | .dot(delta) 182 | .times(prime); 183 | 184 | batchBiases[bLen - i].plusTo(delta); 185 | batchWeights[bLen - i].plusTo(delta.dot(activations[aLen - i - 1].transpose())); 186 | } 187 | } 188 | 189 | private Matrix[] forwardPropagation(@NonNull Matrix input) { 190 | final int len = mWeights.length; 191 | final Matrix[] activations = new Matrix[len + 1]; 192 | activations[0] = input; 193 | 194 | for (int i = 0; i < len; ++i) { 195 | final Matrix matrix = mWeights[i] 196 | .dot(activations[i]) 197 | .plus(mBiases[i]); 198 | 199 | activations[i + 1] = ActivateFunction.sigmoid(matrix); 200 | } 201 | 202 | return activations; 203 | } 204 | 205 | private double evaluate(@NonNull DataSet dataSet) { 206 | Batch batch; 207 | int correct = 0; 208 | int total = 0; 209 | 210 | dataSet.shuffle(); 211 | 212 | 
while ((batch = dataSet.nextBatch()) != null) { 213 | final Matrix[] inputs = batch.inputs; 214 | final Matrix[] targets = batch.targets; 215 | final int len = inputs.length; 216 | total += len; 217 | 218 | for (int i = 0; i < len; ++i) { 219 | final Matrix output = feedForward(mWeights, mBiases, inputs[i]); 220 | 221 | if (MatrixUtil.argmax(output) == MatrixUtil.index(targets[i])) { 222 | ++correct; 223 | } 224 | } 225 | } 226 | 227 | return total != 0 ? (double) correct / total : 0D; 228 | } 229 | } 230 | 231 | public Pair predict(@NonNull Matrix input) { 232 | return MatrixUtil.findMax(feedForward(mWeights, mBiases, input)); 233 | } 234 | 235 | private static Matrix feedForward(@NonNull Matrix[] weights, @NonNull Matrix[] biases, 236 | @NonNull Matrix input) { 237 | final int len = weights.length; 238 | Matrix res = input; 239 | 240 | for (int i = 0; i < len; ++i) { 241 | res = ActivateFunction.sigmoid(weights[i].dot(res).plus(biases[i])); 242 | } 243 | 244 | return res; 245 | } 246 | } 247 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/neural/TrainCallback.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.neural; 2 | 3 | public interface TrainCallback { 4 | void onStart(); 5 | 6 | boolean onUpdate(int progress, double accurate); 7 | 8 | void onEvaluate(); 9 | 10 | void onComplete(double evaluate); 11 | } 12 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/pojo/constant/PreferenceCons.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.pojo.constant; 2 | 3 | public interface PreferenceCons { 4 | String IS_DATA_SET_READY = "is_downloaded_data_set"; 5 | String IS_FIST_PLAY = "is_fist_play"; 6 | String SHOULD_ADD_DEMO_MODEL = "should_add_demo_model"; 7 | } 8 | 
-------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/pojo/constant/TrackCons.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.pojo.constant; 2 | 3 | import android.support.annotation.NonNull; 4 | 5 | import io.whz.synapse.BuildConfig; 6 | import io.whz.synapse.component.AboutDialog; 7 | import io.whz.synapse.component.App; 8 | import io.whz.synapse.component.MainActivity; 9 | import io.whz.synapse.component.MainService; 10 | import io.whz.synapse.component.ModelDetailActivity; 11 | import io.whz.synapse.component.NeuralModelActivity; 12 | import io.whz.synapse.component.PlayActivity; 13 | 14 | /** Analytics event names. Every nested interface except Key builds its names with concat(): a per-component INDEX (the lower-cased simple class name, or "lifecycle"), itself prefixed with "debug" in debug builds, followed by the action. Key holds plain parameter names (presumably Bundle keys for TrackEvent - TODO confirm). */ public class TrackCons { 15 | private static final String SPLIT = "_"; 16 | /* Build-type prefix: "debug" in debug builds, "" in release. concat() adds no separator for a leading empty segment, so release names carry no prefix. */ private static final String INDEX = BuildConfig.DEBUG ? "debug" : ""; 17 | 18 | public interface Main { 19 | String INDEX = concat(TrackCons.INDEX, MainActivity.class.getSimpleName().toLowerCase()); 20 | 21 | String CLICK_DOWNLOAD = concat(INDEX, "click_download"); 22 | String CLICK_FAB = concat(INDEX, "click_fab"); 23 | String CLICK_PLAY = concat(INDEX, "click_play"); 24 | String CLICK_TRAINED = concat(INDEX, "click_trained"); 25 | String CLICK_TRAINING = concat(INDEX, "click_training"); 26 | String CLICK_ABOUT = concat(INDEX, "click_about"); 27 | 28 | String SCROLL_BOTTOM = concat(INDEX, "scroll_bottom"); 29 | } 30 | 31 | public interface Play { 32 | String INDEX = concat(TrackCons.INDEX, PlayActivity.class.getSimpleName().toLowerCase()); 33 | 34 | String CLICK_MODEL_SELECTION = concat(INDEX, "click_model_selection"); 35 | String CLICK_VIEW_DETAIL = concat(INDEX, "click_view_detail"); 36 | 37 | String PLAY_MNIST = concat(INDEX, "play_mnist"); 38 | String PLAY_HAND_WRITE = concat(INDEX, "play_hand_write"); 39 | } 40 | 41 | public interface Model { 42 | String INDEX = concat(TrackCons.INDEX, NeuralModelActivity.class.getSimpleName().toLowerCase()); 43 | 44 | String
CLICK_ADD_NEW_LAYER = concat(INDEX, "click_add_new_layer"); 45 | String CLICK_LAYER_DELETE = concat(INDEX, "click_layer_delete"); 46 | String CLICK_TRAIN = concat(INDEX, "click_train"); 47 | } 48 | 49 | public interface Detail { 50 | String INDEX = concat(TrackCons.INDEX, ModelDetailActivity.class.getSimpleName().toLowerCase()); 51 | 52 | String CLICK_INTERRUPT = concat(INDEX, "click_interrupt"); 53 | String CLICK_DELETE = concat(INDEX, "click_delete"); 54 | String CLICK_PLAY = concat(INDEX, "click_play"); 55 | } 56 | 57 | public interface About { 58 | String INDEX = concat(TrackCons.INDEX, AboutDialog.class.getSimpleName().toLowerCase()); 59 | 60 | String CLICK_GITHUB = concat(INDEX, "click_github"); 61 | String CLICK_RATE = concat(INDEX, "click_rate"); 62 | String CLICK_SHARE = concat(INDEX, "click_share"); 63 | } 64 | 65 | public interface Service { 66 | String INDEX = concat(TrackCons.INDEX, MainService.class.getSimpleName().toLowerCase()); 67 | 68 | String DOWNLOAD = concat(INDEX, "download"); 69 | String DECOMPRESS = concat(INDEX, "decompress"); 70 | String TRAIN = concat(INDEX, "train"); 71 | String INTERRUPT_REQ = concat(INDEX, "interrupt_req"); 72 | String INTERRUPT = concat(INDEX, "interrupt"); 73 | } 74 | 75 | public interface APP { 76 | String INDEX = concat(TrackCons.INDEX, App.class.getSimpleName().toLowerCase()); 77 | 78 | String INITIALIZE = concat(INDEX, "initialize"); 79 | String CAUGHT = concat(INDEX, "catch_exception"); 80 | } 81 | 82 | public interface Key { 83 | String SUCCESS = "success"; 84 | String MSG = "msg"; 85 | String TIME_USED = "time_used"; 86 | } 87 | 88 | public interface Lifecycle { 89 | String INDEX = concat(TrackCons.INDEX, "lifecycle"); 90 | 91 | String ENTER = concat(INDEX, "enter"); 92 | String LEAVE = concat(INDEX, "leave"); 93 | } 94 | 95 | /** Joins the segments with SPLIT ("_"). A separator is appended only once the result is non-empty, so leading empty segments add nothing, while an empty segment after a non-empty one still yields a doubled "_". @param array segments to join, in order @return the joined name */ public static String concat(@NonNull String...
array) { 96 | final StringBuilder builder = new StringBuilder(); 97 | 98 | for (int i = 0, len = array.length; i < len; ++i) { 99 | if (builder.length() != 0) { 100 | builder.append(SPLIT); 101 | } 102 | 103 | builder.append(array[i]); 104 | } 105 | 106 | return builder.toString(); 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /app/src/main/java/io/whz/synapse/pojo/dao/DBModel.java: -------------------------------------------------------------------------------- 1 | package io.whz.synapse.pojo.dao; 2 | 3 | import org.greenrobot.greendao.annotation.Entity; 4 | import org.greenrobot.greendao.annotation.Id; 5 | import org.greenrobot.greendao.annotation.Unique; 6 | import org.greenrobot.greendao.annotation.Generated; 7 | 8 | /** greenDAO entity holding one saved neural-network model plus its training metadata. Members annotated @Generated are produced by the greenDAO generator and carry a hash of their generated code, so change them only through the generator (or mark them @Keep). */ @Entity 9 | public class DBModel { 10 | @Id(autoincrement = true) 11 | private Long id; 12 | 13 | @Unique 14 | private String name; 15 | private long createdTime; 16 | private double learningRate; 17 | private int epochs; 18 | private int dataSize; 19 | private long timeUsed; 20 | private double evaluate; 21 | /* The *Bytes fields are serialized array data; the encoding (presumably hidden sizes, per-epoch accuracies, bias and weight matrices) is defined by whatever converts DBModel to and from the in-memory model - TODO confirm. */ private byte[] hiddenSizeBytes; 22 | private byte[] accuracyBytes; 23 | private byte[] biasBytes; 24 | private byte[] weightBytes; 25 | @Generated(hash = 156277296) 26 | public DBModel(Long id, String name, long createdTime, double learningRate, 27 | int epochs, int dataSize, long timeUsed, double evaluate, 28 | byte[] hiddenSizeBytes, byte[] accuracyBytes, byte[] biasBytes, 29 | byte[] weightBytes) { 30 | this.id = id; 31 | this.name = name; 32 | this.createdTime = createdTime; 33 | this.learningRate = learningRate; 34 | this.epochs = epochs; 35 | this.dataSize = dataSize; 36 | this.timeUsed = timeUsed; 37 | this.evaluate = evaluate; 38 | this.hiddenSizeBytes = hiddenSizeBytes; 39 | this.accuracyBytes = accuracyBytes; 40 | this.biasBytes = biasBytes; 41 | this.weightBytes = weightBytes; 42 | } 43 | @Generated(hash = 1265045159) 44 | public DBModel() { 45 | } 46 | public Long getId() {
47 | return this.id; 48 | } 49 | public void setId(Long id) { 50 | this.id = id; 51 | } 52 | public String getName() { 53 | return this.name; 54 | } 55 | public void setName(String name) { 56 | this.name = name; 57 | } 58 | public long getCreatedTime() { 59 | return this.createdTime; 60 | } 61 | public void setCreatedTime(long createdTime) { 62 | this.createdTime = createdTime; 63 | } 64 | public double getLearningRate() { 65 | return this.learningRate; 66 | } 67 | public void setLearningRate(double learningRate) { 68 | this.learningRate = learningRate; 69 | } 70 | public int getEpochs() { 71 | return this.epochs; 72 | } 73 | public void setEpochs(int epochs) { 74 | this.epochs = epochs; 75 | } 76 | public int getDataSize() { 77 | return this.dataSize; 78 | } 79 | public void setDataSize(int dataSize) { 80 | this.dataSize = dataSize; 81 | } 82 | public long getTimeUsed() { 83 | return this.timeUsed; 84 | } 85 | public void setTimeUsed(long timeUsed) { 86 | this.timeUsed = timeUsed; 87 | } 88 | public double getEvaluate() { 89 | return this.evaluate; 90 | } 91 | public void setEvaluate(double evaluate) { 92 | this.evaluate = evaluate; 93 | } 94 | public byte[] getHiddenSizeBytes() { 95 | return this.hiddenSizeBytes; 96 | } 97 | public void setHiddenSizeBytes(byte[] hiddenSizeBytes) { 98 | this.hiddenSizeBytes = hiddenSizeBytes; 99 | } 100 | public byte[] getAccuracyBytes() { 101 | return this.accuracyBytes; 102 | } 103 | public void setAccuracyBytes(byte[] accuracyBytes) { 104 | this.accuracyBytes = accuracyBytes; 105 | } 106 | public byte[] getBiasBytes() { 107 | return this.biasBytes; 108 | } 109 | public void setBiasBytes(byte[] biasBytes) { 110 | this.biasBytes = biasBytes; 111 | } 112 | public byte[] getWeightBytes() { 113 | return this.weightBytes; 114 | } 115 | public void setWeightBytes(byte[] weightBytes) { 116 | this.weightBytes = weightBytes; 117 | } 118 | } 119 | --------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/event/MANEvent.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.event;

import android.support.annotation.IntDef;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

/**
 * Main Activity Normal Event: a message for the main screen. The {@code what} codes are
 * distinct single bits listed in {@link Event}; some carry a payload (for example
 * {@link #JUMP_TO_TRAINED} is posted with the model id).
 */
public class MANEvent extends TypeEvent {
    public static final int CLICK_DOWNLOAD = 0x01;
    public static final int DOWNLOAD_COMPLETE = 0x01 << 1;
    public static final int DECOMPRESS_COMPLETE = 0x01 << 2;
    public static final int REJECT_MSG = 0x01 << 3;
    public static final int JUMP_TO_PLAY = 0x01 << 4;
    public static final int JUMP_TO_TRAINED = 0x01 << 5;
    public static final int JUMP_TO_TRAINING = 0x01 << 6;

    /** Typedef of the valid {@code what} codes; checked by lint at call sites. */
    @IntDef({
            CLICK_DOWNLOAD, DOWNLOAD_COMPLETE, DECOMPRESS_COMPLETE,
            REJECT_MSG, JUMP_TO_PLAY, JUMP_TO_TRAINED, JUMP_TO_TRAINING
    })
    @Retention(RetentionPolicy.SOURCE)
    public @interface Event {
    }

    /**
     * @param what one of the {@link Event} codes
     * @param obj  event payload
     */
    public MANEvent(@Event int what, T obj) {
        super(what, obj);
    }

    /** @param what one of the {@link Event} codes; the payload is {@code null} */
    public MANEvent(@Event int what) {
        super(what);
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/event/MSNEvent.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.event;

import android.support.annotation.IntDef;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

/**
 * Event carrying one of the {@link Event} bit codes (download finished, decompression
 * finished, training interrupted) plus an optional payload.
 */
public class MSNEvent extends TypeEvent {
    public static final int DOWNLOAD_COMPLETE = 0x01;
    public static final int DECOMPRESS_COMPLETE = 0x01 << 1;
    public static final int TRAIN_INTERRUPT = 0x01 << 2;

    /** Typedef of the valid {@code what} codes, mirroring MANEvent and TrainEvent. */
    @IntDef({DOWNLOAD_COMPLETE, DECOMPRESS_COMPLETE, TRAIN_INTERRUPT})
    @Retention(RetentionPolicy.SOURCE)
    public @interface Event {
    }

    /**
     * @param what one of the {@link Event} codes
     * @param obj  event payload
     */
    public MSNEvent(@Event int what, T obj) {
        super(what, obj);
    }

    /** @param what one of the {@link Event} codes; the payload is {@code null} */
    public MSNEvent(@Event int what) {
        super(what);
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/event/ModelDeletedEvent.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.event;

/** Event signalling that a model was deleted; it carries a {@code null} payload. */
public class ModelDeletedEvent extends NormalEvent {
    public ModelDeletedEvent() {
        super(null);
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/event/NormalEvent.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.event;

/**
 * Base class for events whose {@code what} code is always the fixed value {@code 0xFF};
 * subclasses are presumably told apart by their type rather than by {@code what}.
 */
abstract class NormalEvent extends TypeEvent {
    private static final int WHAT = 0xFF;

    NormalEvent(T obj) {
        super(WHAT, obj);
    }

    private NormalEvent() {
        super(WHAT);
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/event/TrackEvent.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.event;

import android.os.Bundle;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;

import io.whz.synapse.util.Precondition;

/**
 * Analytics event: a non-null event id (presumably one of the TrackCons names) plus an
 * optional parameter bundle.
 */
public class TrackEvent {
    @NonNull
    public final String id;
    @Nullable
    public final Bundle bundle;

    /**
     * @param id     event name; must not be null
     * @param bundle optional event parameters
     */
    public TrackEvent(@NonNull String id, @Nullable Bundle bundle) {
        Precondition.checkNotNull(id);

        this.id = id;
        this.bundle = bundle;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/event/TrainEvent.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.event;

import android.support.annotation.IntDef;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

/** Training lifecycle notification; {@code what} is one of the {@link Type} bit codes. */
public class TrainEvent extends TypeEvent {
    public static final int START = 0x01;
    public static final int UPDATE = 0x01 << 1;
    public static final int EVALUATE = 0x01 << 2;
    public static final int COMPLETE = 0x01 << 3;
    public static final int ERROR = 0x01 << 4;
    public static final int INTERRUPTED = 0x01 << 5;

    @Retention(RetentionPolicy.SOURCE)
    @IntDef({START, UPDATE, EVALUATE, COMPLETE, ERROR, INTERRUPTED})
    public @interface Type {}

    public TrainEvent(@Type int what, T obj) {
        super(what, obj);
    }

    public TrainEvent(@Type int what) {
        super(what);
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/event/TypeEvent.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.event;

/** Base event: an integer {@code what} code plus an optional payload, both final. */
public abstract class TypeEvent {
    /** Event code; its meaning is defined by the concrete subclass. */
    public final int what;
    /** Payload; {@code null} when created through {@link #TypeEvent(int)}. */
    public final T obj;

    public TypeEvent(int what, T obj) {
        this.what = what;
        this.obj = obj;
    }

    public TypeEvent(int what) {
        this.what = what;
        this.obj = null;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/multiple/binder/PlayViewBinder.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.multiple.binder;

import android.support.annotation.NonNull;
import android.support.v7.widget.RecyclerView;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;

import io.whz.synapse.R;
import io.whz.synapse.element.Global;
import io.whz.synapse.pojo.event.MANEvent;
import io.whz.synapse.pojo.multiple.item.PlayItem;
import me.drakeet.multitype.ItemViewBinder;

/**
 * Binds the "play" entry card. Tapping it posts {@link MANEvent#JUMP_TO_PLAY} on the
 * global event bus.
 */
public class PlayViewBinder extends ItemViewBinder
        implements View.OnClickListener {

    @NonNull
    @Override
    protected PlayViewHolder onCreateViewHolder(@NonNull LayoutInflater layoutInflater, @NonNull ViewGroup viewGroup) {
        final PlayViewHolder holder = PlayViewHolder.newInstance(layoutInflater, viewGroup);

        // A single binder-wide listener suffices: the click does not depend on the item.
        holder.itemView.setOnClickListener(this);

        return holder;
    }

    /** Nothing to bind: the item carries no data. */
    @Override
    protected void onBindViewHolder(@NonNull PlayViewHolder playViewHolder, @NonNull PlayItem playItem) {}

    @Override
    public void onClick(View view) {
        Global.getInstance()
                .getBus()
                .post(new MANEvent<>(MANEvent.JUMP_TO_PLAY));
    }

    static class PlayViewHolder extends RecyclerView.ViewHolder {
        PlayViewHolder(View itemView) {
            super(itemView);
        }

        static PlayViewHolder newInstance(@NonNull LayoutInflater layoutInflater, @NonNull ViewGroup viewGroup) {
            // "item_paly" (sic) is the layout resource's actual file name.
            final View view = layoutInflater.inflate(R.layout.item_paly, viewGroup, false);
            return new PlayViewHolder(view);
        }
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/multiple/binder/TrainedModelViewBinder.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.multiple.binder;

import android.content.Context;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import android.support.v4.content.ContextCompat;
import android.support.v7.widget.RecyclerView;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.TextView;

import com.github.mikephil.charting.charts.LineChart;
import com.github.mikephil.charting.components.YAxis;
import com.github.mikephil.charting.data.Entry;
import com.github.mikephil.charting.data.LineData;
import com.github.mikephil.charting.data.LineDataSet;

import java.util.List;

import io.whz.synapse.R;
import io.whz.synapse.element.Global;
import io.whz.synapse.pojo.neural.Model;
import io.whz.synapse.pojo.event.MANEvent;
import io.whz.synapse.pojo.multiple.item.TrainedModelItem;
import io.whz.synapse.util.StringFormatUtil;
import me.drakeet.multitype.ItemViewBinder;

/**
 * Binds a finished model to a card showing its hyper-parameters, final accuracy and a
 * filled line chart of per-epoch accuracy. Tapping the card posts
 * {@link MANEvent#JUMP_TO_TRAINED} with the model id.
 */
public class TrainedModelViewBinder extends ItemViewBinder
        implements View.OnClickListener {
    /** Chart line/fill colours; a model picks one by its id modulo the palette size. */
    public static final int[] FG = new int[]{
            R.color.item_chart_fg$1,
            R.color.item_chart_fg$2,
            R.color.item_chart_fg$3,
            R.color.item_chart_fg$4,
            R.color.item_chart_fg$5,
            R.color.item_chart_fg$6,
            R.color.item_chart_fg$7,
            R.color.item_chart_fg$8,
            R.color.item_chart_fg$9,
            R.color.item_chart_fg$10
    };

    /** Chart background colours, paired index-for-index with {@link #FG}. */
    public static final int[] BG = new int[]{
            R.color.item_chart_bg$1,
            R.color.item_chart_bg$2,
            R.color.item_chart_bg$3,
            R.color.item_chart_bg$4,
            R.color.item_chart_bg$5,
            R.color.item_chart_bg$6,
            R.color.item_chart_bg$7,
            R.color.item_chart_bg$8,
            R.color.item_chart_bg$9,
            R.color.item_chart_bg$10
    };

    @NonNull
    @Override
    protected TrainedModelViewHolder onCreateViewHolder(@NonNull LayoutInflater layoutInflater, @NonNull ViewGroup viewGroup) {
        final TrainedModelViewHolder holder = TrainedModelViewHolder.newInstance(layoutInflater, viewGroup);

        holder.itemView.setOnClickListener(this);
        prepareChart(holder.chart);

        return holder;
    }

    /** Posts JUMP_TO_TRAINED with the model id stored as the item view's tag. */
    @Override
    public void onClick(View view) {
        final Long id = (Long) view.getTag();

        if (id == null) {
            return;
        }

        Global.getInstance()
                .getBus()
                .post(new MANEvent<>(MANEvent.JUMP_TO_TRAINED, id));
    }

    /** Strips the chart down to a non-interactive sparkline: no axes, legend, or padding. */
    private void prepareChart(@NonNull LineChart chart) {
        chart.getDescription().setEnabled(false);
        chart.setTouchEnabled(false);
        chart.setDragEnabled(false);
        chart.setScaleEnabled(false);
        chart.setHighlightPerDragEnabled(false);
        chart.setPinchZoom(false);
        chart.setDrawGridBackground(true);
        chart.getLegend().setEnabled(false);
        chart.getAxisRight().setEnabled(false);
        chart.getAxisLeft().setEnabled(false);
        chart.getXAxis().setEnabled(false);
        chart.setViewPortOffsets(0, 0, 0, 0);
    }

    @Override
    protected void onBindViewHolder(@NonNull TrainedModelViewHolder holder, @NonNull TrainedModelItem item) {
        // Create the chart data set once per holder; on re-bind only swap its entries.
        if (holder.data == null) {
            holder.data = prepareInitData(holder.chart, item.getEntries());
        } else {
            holder.data.setValues(item.getEntries());
            holder.chart.getData().notifyDataChanged();
            holder.chart.notifyDataSetChanged();
        }

        final Model model = item.getModel();
        final Long id = model.getId();

        // onClick() reads this tag back to know which model was tapped.
        holder.itemView.setTag(id);
        renderModel(holder, model);
        // getId() is a boxed Long (onClick already treats it as nullable): never unbox null.
        changeStyle(id == null ? 0L : id, holder.chart, holder.data);

        holder.chart.invalidate();
    }

    /** Colours the line, fill and background with the palette entry chosen by {@code id}. */
    private void changeStyle(long id, LineChart chart, LineDataSet set) {
        // Math.abs keeps the palette index in range even for a negative id.
        final int index = (int) Math.abs(id % FG.length);
        final Context context = chart.getContext();

        final int fg = ContextCompat.getColor(context, FG[index]);
        set.setColor(fg);
        set.setFillColor(fg);

        chart.setGridBackgroundColor(ContextCompat.getColor(context, BG[index]));
    }

    /** Creates the filled, value-less accuracy data set and attaches it to {@code chart}. */
    private LineDataSet prepareInitData(@NonNull LineChart chart, @NonNull List entries) {
        final LineDataSet set = new LineDataSet(entries, "Accuracy");

        set.setMode(LineDataSet.Mode.HORIZONTAL_BEZIER);
        set.setAxisDependency(YAxis.AxisDependency.LEFT);
        set.setLineWidth(2F);
        set.setDrawCircleHole(false);
        set.setDrawCircles(false);
        set.setHighlightEnabled(false);
        set.setDrawFilled(true);

        final LineData group = new LineData(set);
        group.setDrawValues(false);

        chart.setData(group);

        return set;
    }

    /** Writes the model's hyper-parameters, time used and final accuracy into the text views. */
    private void renderModel(TrainedModelViewHolder holder, Model model) {
        holder.name.setText(model.getName());
        holder.layers.setText(StringFormatUtil.formatLayerSizes(model.getHiddenSizes()));
        holder.epochs.setText(String.format("E: %s", model.getEpochs()));
        holder.learningRate.setText(String.format("R: %s", model.getLearningRate()));
        holder.dataSize.setText(String.format("D: %s", model.getDataSize()));
        // Was formatTimeUsed(getDataSize()), which displayed the data size as a duration.
        holder.timeUsed.setText(String.format("T: %s", StringFormatUtil.formatTimeUsed(model.getTimeUsed())));
        holder.evaluate.setText(String.format("%s%%", (int)(model.getEvaluate() * 100)));
    }

    static class TrainedModelViewHolder extends RecyclerView.ViewHolder {
        final LineChart chart;
        final TextView name;
        final TextView layers;
        final TextView epochs;
        final TextView learningRate;
        final TextView dataSize;
        final TextView timeUsed;
        final TextView evaluate;

        // Chart data set, created lazily on the first bind and reused afterwards.
        @Nullable LineDataSet data;

        TrainedModelViewHolder(View itemView) {
            super(itemView);

            chart = itemView.findViewById(R.id.line_chart);
            name = itemView.findViewById(R.id.name);
            layers = itemView.findViewById(R.id.layers);
            epochs = itemView.findViewById(R.id.epochs);
            learningRate = itemView.findViewById(R.id.learning_rate);
            dataSize = itemView.findViewById(R.id.data_size);
            timeUsed = itemView.findViewById(R.id.time_used);
            evaluate = itemView.findViewById(R.id.accuracy);
        }

        private static TrainedModelViewHolder newInstance(@NonNull LayoutInflater layoutInflater,
                                                          @NonNull ViewGroup viewGroup) {
            final View view = layoutInflater.inflate(R.layout.item_trained, viewGroup, false);

            return new TrainedModelViewHolder(view);
        }
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/multiple/binder/TrainingModelViewBinder.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.multiple.binder;

import android.support.annotation.NonNull;
import android.support.v7.widget.RecyclerView;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.ProgressBar;
import android.widget.TextView;

import io.whz.synapse.R;
import io.whz.synapse.element.Global;
import io.whz.synapse.pojo.event.MANEvent;
import io.whz.synapse.pojo.multiple.item.TrainingModelItem;
import io.whz.synapse.pojo.neural.Model;
import io.whz.synapse.util.StringFormatUtil;
import me.drakeet.multitype.ItemViewBinder;

/**
 * Binds the card of the model that is currently training: hyper-parameters, current
 * epoch, a progress bar and the latest validation accuracy. Tapping the card posts
 * {@link MANEvent#JUMP_TO_TRAINING}.
 */
public class TrainingModelViewBinder extends ItemViewBinder
        implements View.OnClickListener {

    @NonNull
    @Override
    protected TrainingModelViewHolder onCreateViewHolder(@NonNull LayoutInflater layoutInflater, @NonNull ViewGroup viewGroup) {
        final TrainingModelViewHolder holder = TrainingModelViewHolder.newInstance(layoutInflater, viewGroup);

        holder.itemView.setOnClickListener(this);

        return holder;
    }

    @Override
    protected void onBindViewHolder(@NonNull TrainingModelViewHolder trainingModelViewHolder, @NonNull TrainingModelItem trainingModelItem) {
        renderModel(trainingModelViewHolder, trainingModelItem.getModel());
    }

    /** Fills every view of the holder from the model's current training state. */
    private void renderModel(TrainingModelViewHolder holder, Model model) {
        final int step = model.getStepEpoch();

        holder.name.setText(model.getName());
        holder.step.setText(String.format("S: %s", step));
        holder.layers.setText(StringFormatUtil.formatLayerSizes(model.getHiddenSizes()));
        holder.epochs.setText(String.format("E: %s", model.getEpochs()));
        holder.learningRate.setText(String.format("R: %s", model.getLearningRate()));
        holder.dataSize.setText(String.format("D: %s", model.getDataSize()));

        // Indeterminate before the first epoch finishes and once the last one has finished
        // (presumably while the final evaluation runs); otherwise show the share of epochs done.
        if (step == 0 || step == model.getEpochs()) {
            holder.progress.setIndeterminate(true);
        } else {
            holder.progress.setIndeterminate(false);
            holder.progress.setProgress(step * 100 / model.getEpochs());
        }

        final double[] accuracies = model.getAccuracies();
        // Accuracy of the latest finished epoch (index step - 1), clamped to the entries that
        // were actually recorded; previously step == 0 or step > length threw an
        // ArrayIndexOutOfBoundsException. A negative index means there is nothing to show yet.
        final int latest = accuracies == null ? -1 : Math.min(step, accuracies.length) - 1;

        if (latest < 0) {
            holder.accuracy.setText("--%");
        } else {
            holder.accuracy.setText(String.format("%s%%", (int) (accuracies[latest] * 100)));
        }
    }

    @Override
    public void onClick(View view) {
        Global.getInstance()
                .getBus()
                .post(new MANEvent<>(MANEvent.JUMP_TO_TRAINING));
    }

    static class TrainingModelViewHolder extends RecyclerView.ViewHolder {
        final TextView name;
        final TextView step;
        final TextView layers;
        final TextView epochs;
        final TextView learningRate;
        final TextView dataSize;
        final TextView accuracy;
        final ProgressBar progress;

        TrainingModelViewHolder(View itemView) {
            super(itemView);
            name = itemView.findViewById(R.id.name);
            step = itemView.findViewById(R.id.step);
            layers = itemView.findViewById(R.id.layers);
            epochs = itemView.findViewById(R.id.epochs);
            learningRate = itemView.findViewById(R.id.learning_rate);
            dataSize = itemView.findViewById(R.id.data_size);
            accuracy = itemView.findViewById(R.id.accuracy);
            progress = itemView.findViewById(R.id.progress);
        }

        static TrainingModelViewHolder newInstance(@NonNull LayoutInflater layoutInflater, @NonNull ViewGroup viewGroup) {
            final View view = layoutInflater.inflate(R.layout.item_training, viewGroup, false);

            return new TrainingModelViewHolder(view);
        }
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/multiple/binder/WelcomeViewBinder.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.multiple.binder;

import android.content.res.Resources;
import android.support.annotation.NonNull;
import android.support.transition.TransitionManager;
import android.support.v4.content.res.ResourcesCompat;
import android.support.v7.widget.RecyclerView;
import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.TextView;

import org.greenrobot.eventbus.EventBus;

import io.whz.synapse.R;
import io.whz.synapse.pojo.event.MANEvent;
import io.whz.synapse.pojo.multiple.item.WelcomeItem;
import me.drakeet.multitype.ItemViewBinder;

/**
 * Binds a {@link WelcomeItem} to the welcome card: a "download" label whose text, colour and
 * clickability follow the item's state, plus a progress view shown only while WAITING.
 */
public class WelcomeViewBinder extends ItemViewBinder<WelcomeItem, WelcomeViewBinder.WelcomeHolder>
        implements View.OnClickListener {

    @NonNull
    @Override
    protected WelcomeHolder onCreateViewHolder(@NonNull LayoutInflater layoutInflater, @NonNull ViewGroup viewGroup) {
        final WelcomeHolder holder = WelcomeHolder.newInstance(layoutInflater, viewGroup);

        // This binder is the single click listener for every holder's download label.
        holder.download.setOnClickListener(this);

        return holder;
    }

    @Override
    protected void onBindViewHolder(@NonNull WelcomeHolder holder, @NonNull WelcomeItem dataSet) {
        final Resources resources = holder.download.getResources();
        // Animate the text / colour / visibility changes applied below.
        TransitionManager.beginDelayedTransition((ViewGroup) holder.itemView);

        switch (dataSet.state()) {
            case WelcomeItem.READY:
                holder.download.setText(R.string.data_ready);
                holder.download.setTextColor(ResourcesCompat.getColor(resources, R.color.data_ready_text, null));
                // Nothing left to download once the data is ready.
                holder.download.setClickable(false);
                holder.progress.setVisibility(View.GONE);
                break;

            case WelcomeItem.WAITING:
                holder.download.setTextColor(ResourcesCompat.getColor(resources, R.color.data_waiting_text, null));
                holder.download.setText(R.string.data_waiting);
                holder.download.setClickable(true);
                holder.progress.setVisibility(View.VISIBLE);
                break;

            case WelcomeItem.UNREADY:
                holder.download.setTextColor(ResourcesCompat.getColor(resources, R.color.data_unready_text, null));
                holder.download.setText(R.string.data_unready);
                holder.download.setClickable(true);
                holder.progress.setVisibility(View.GONE);
                break;
        }
    }

    /** Posts a CLICK_DOWNLOAD {@link MANEvent} when the download label is tapped. */
    @Override
    public void onClick(View view) {
        // NOTE(review): TrainingModelViewBinder posts through Global.getInstance().getBus();
        // confirm that bus is EventBus.getDefault(), otherwise this event may miss its subscriber.
        EventBus.getDefault()
                .post(new MANEvent<>(MANEvent.CLICK_DOWNLOAD));
    }

    /** Holds the download label and the progress view of the welcome card. */
    static class WelcomeHolder extends RecyclerView.ViewHolder {
        final TextView download;
        final View progress;

        WelcomeHolder(View itemView) {
            super(itemView);

            download = itemView.findViewById(R.id.download);
            progress = itemView.findViewById(R.id.progress);
        }

        /** Inflates R.layout.item_welcome (not attached to {@code viewGroup}) and wraps it. */
        static WelcomeHolder newInstance(@NonNull LayoutInflater layoutInflater, @NonNull ViewGroup viewGroup) {
            final View v = layoutInflater.inflate(R.layout.item_welcome, viewGroup, false);
            return new WelcomeHolder(v);
        }
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/multiple/item/PlayItem.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.multiple.item;

/** Stateless list item; presumably rendered by PlayViewBinder. */
public class PlayItem {
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/multiple/item/TrainedModelItem.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.multiple.item;

import android.support.annotation.NonNull;

import com.github.mikephil.charting.data.Entry;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import io.whz.synapse.pojo.neural.Model;
import io.whz.synapse.util.Precondition;

/**
 * List item wrapping a trained {@link Model}; exposes its accuracy history as chart entries.
 */
public class TrainedModelItem {
    private final Model mModel;
    // Built lazily on the first getEntries() call and cached afterwards.
    private List<Entry> mEntries;

    /**
     * @param model the trained model; must not be null
     */
    public TrainedModelItem(@NonNull Model model) {
        mModel = Precondition.checkNotNull(model);
    }

    /**
     * Converts the model's accuracies into chart entries with x = array index and
     * y = accuracy. Returns an empty list when the model has no accuracies.
     */
    private List<Entry> format() {
        final double[] accuracies = getModel().getAccuracies();
        final List<Entry> list = new ArrayList<>();

        // Model.accuracies is a nullable field, so guard instead of throwing NPE.
        if (accuracies == null) {
            return list;
        }

        for (int i = 0; i < accuracies.length; ++i) {
            list.add(new Entry(i, (float) accuracies[i]));
        }

        return list;
    }

    public Model getModel() {
        return mModel;
    }

    /** Returns the cached chart entries, building them on first use. */
    public List<Entry> getEntries() {
        if (mEntries == null) {
            mEntries = format();
        }

        return mEntries;
    }

    /** Two items are equal when their models have the same id (two null ids also match). */
    @Override
    public boolean equals(Object obj) {
        return (obj instanceof TrainedModelItem)
                && Objects.equals(((TrainedModelItem) obj).getModel().getId(), this.getModel().getId());
    }

    /** Consistent with {@link #equals(Object)}: hashes the model id only. */
    @Override
    public int hashCode() {
        return Objects.hashCode(getModel().getId());
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/multiple/item/TrainingModelItem.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.multiple.item;

import android.support.annotation.NonNull;

import io.whz.synapse.pojo.neural.Model;

/**
 * List item wrapping a model that is being trained, together with the epoch step it had
 * when this item was created.
 */
public class TrainingModelItem {
    private final Model mModel;
    // Step captured at construction; presumably the model keeps advancing while training.
    private final int mStepSnapShot;

    public TrainingModelItem(@NonNull Model model) {
        mModel = model;
        mStepSnapShot = model.getStepEpoch();
    }

    public Model getModel() {
        return mModel;
    }

    /** Items are equal when their step snapshots match; the models are not compared. */
    @Override
    public boolean equals(Object obj) {
        return (obj instanceof TrainingModelItem)
                && ((TrainingModelItem) obj).mStepSnapShot == this.mStepSnapShot;
    }

    /** Consistent with {@link #equals(Object)}: hashes the step snapshot only. */
    @Override
    public int hashCode() {
        return mStepSnapShot;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/multiple/item/WelcomeItem.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.multiple.item;

import android.support.annotation.IntDef;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

/**
 * List item for the welcome card. It carries one of {@link #UNREADY}, {@link #WAITING} or
 * {@link #READY}.
 */
public class WelcomeItem {
    @Retention(RetentionPolicy.SOURCE)
    @IntDef({UNREADY, WAITING, READY})
    public @interface State {}

    public static final int UNREADY = 0x01;
    public static final int WAITING = 0x01 << 1;
    public static final int READY = 0x01 << 2;

    @State
    private final int mState;

    public WelcomeItem(@State int state) {
        mState = state;
    }

    @State
    public int state() {
        return mState;
    }

    /** Two welcome items are equal when they carry the same state. */
    @Override
    public boolean equals(Object obj) {
        return (obj instanceof WelcomeItem)
                && ((WelcomeItem) obj).state() == this.state();
    }

    /** Consistent with {@link #equals(Object)}: equal states give equal hashes. */
    @Override
    public int hashCode() {
        return mState;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/neural/Batch.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.neural;

import io.whz.synapse.matrix.Matrix;

/**
 * Holder for a batch of input matrices and their target matrices.
 * NOTE(review): presumably inputs[i] pairs with targets[i] — confirm with the producer.
 */
public class Batch {
    public final Matrix[] inputs;
    public final Matrix[] targets;

    public Batch(Matrix[] inputs, Matrix[] targets) {
        this.inputs = inputs;
        this.targets = targets;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/neural/Digit.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.neural;

import android.support.annotation.NonNull;

/** A labelled digit sample whose pixel values are stored as doubles. */
public class Digit {
    public final double[] colors;
    public final int label;

    public Digit(int label, @NonNull double[] colors) {
        this.colors = colors;
        this.label = label;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/neural/Figure.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.neural;

import android.support.annotation.NonNull;

/** A labelled sample whose raw pixel data is kept as bytes. */
public class Figure {
    public final byte[] bytes;
    public final int label;

    public Figure(int label, @NonNull byte[] bytes) {
        this.bytes = bytes;
        this.label = label;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/pojo/neural/Model.java:
--------------------------------------------------------------------------------
package io.whz.synapse.pojo.neural;

import java.io.Serializable;

import io.whz.synapse.matrix.Matrix;

/**
 * Serializable neural-network model. It holds training settings (learning rate, epochs,
 * data size, hidden layer sizes), progress and results (step epoch, accuracies, evaluation,
 * time used), and the learned biases and weights.
 */
public class Model implements Serializable {
    public static final long serialVersionUID = 0xAAFF;

    private Long id;
    private String name;
    private long createdTime;
    private double learningRate;
    private int epochs;
    private int dataSize;
    private long timeUsed;
    private double evaluate;
    private int stepEpoch;
    private double[] accuracies;
    private int[] hiddenSizes;
    private Matrix[] biases;
    private Matrix[] weights;

    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public long getCreatedTime() {
        return createdTime;
    }

    public void setCreatedTime(long createdTime) {
        this.createdTime = createdTime;
    }

    public double getLearningRate() {
        return learningRate;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    public int getEpochs() {
        return epochs;
    }

    public void setEpochs(int epochs) {
        this.epochs = epochs;
    }

    public int getDataSize() {
        return dataSize;
    }

    public void setDataSize(int dataSize) {
        this.dataSize = dataSize;
    }

    public long getTimeUsed() {
        return timeUsed;
    }

    public void setTimeUsed(long timeUsed) {
        this.timeUsed = timeUsed;
    }

    public double getEvaluate() {
        return evaluate;
    }

    public void setEvaluate(double evaluate) {
        this.evaluate = evaluate;
    }

    public int getStepEpoch() {
        return stepEpoch;
    }

    public void setStepEpoch(int stepEpoch) {
        this.stepEpoch = stepEpoch;
    }

    public int[] getHiddenSizes() {
        return hiddenSizes;
    }

    public void setHiddenSizes(int[] hiddenSizes) {
        this.hiddenSizes = hiddenSizes;
    }

    public Matrix[] getBiases() {
        return biases;
    }

    public void setBiases(Matrix[] biases) {
        this.biases = biases;
    }

    public Matrix[] getWeights() {
        return weights;
    }

    public void setWeights(Matrix[] weights) {
        this.weights = weights;
    }

    /** May be null; callers should check before indexing. */
    public double[] getAccuracies() {
        return accuracies;
    }

    public void setAccuracies(double[] accuracies) {
        this.accuracies = accuracies;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/track/AbsTrackHandler.java:
--------------------------------------------------------------------------------
package io.whz.synapse.track;

import android.content.Context;
import android.support.annotation.NonNull;

import org.greenrobot.eventbus.EventBus;
import org.greenrobot.eventbus.Subscribe;
import org.greenrobot.eventbus.ThreadMode;

import io.whz.synapse.pojo.event.TrackEvent;

/** Base class for analytics back-ends that consume {@link TrackEvent}s from an {@link EventBus}. */
public abstract class AbsTrackHandler {
    /** Subscribes this handler to {@code bus}. */
    void register(@NonNull EventBus bus) {
        bus.register(this);
    }

    AbsTrackHandler(@NonNull Context context) {}

    /**
     * Handles one tracking event on a background thread.
     * <p>
     * Overrides must repeat {@code @Subscribe(threadMode = ThreadMode.BACKGROUND)}. Method
     * annotations are not inherited, and EventBus's reflection-based lookup ignores abstract
     * methods, so the annotation here alone does not register a subclass.
     */
    @Subscribe(threadMode = ThreadMode.BACKGROUND)
    public abstract void onTrackEvent(TrackEvent event);
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/track/ActivityLifecycleTracker.java:
--------------------------------------------------------------------------------
package io.whz.synapse.track;

import android.app.Activity;
import android.app.Application;
import android.os.Bundle;
import android.support.annotation.NonNull;

import java.util.Locale;

import static io.whz.synapse.pojo.constant.TrackCons.Lifecycle.ENTER;
import static io.whz.synapse.pojo.constant.TrackCons.Lifecycle.LEAVE;
import static io.whz.synapse.pojo.constant.TrackCons.concat;

/** Logs an ENTER event when an activity starts and a LEAVE event when it stops. */
class ActivityLifecycleTracker implements Application.ActivityLifecycleCallbacks {
    private final Tracker mTrack;

    ActivityLifecycleTracker(@NonNull Tracker track) {
        mTrack = track;
    }

    /**
     * Lower-cased simple class name of {@code activity}. Locale.ROOT keeps event ids the
     * same on every device; with the default locale, e.g. Turkish casing would turn 'I' into
     * a dotless 'ı'.
     */
    private static String screenName(@NonNull Activity activity) {
        return activity.getClass().getSimpleName().toLowerCase(Locale.ROOT);
    }

    @Override
    public void onActivityCreated(Activity activity, Bundle bundle) {}

    @Override
    public void onActivityStarted(Activity activity) {
        mTrack.logEvent(concat(ENTER, screenName(activity)));
    }

    @Override
    public void onActivityResumed(Activity activity) {}

    @Override
    public void onActivityPaused(Activity activity) {}

    @Override
    public void onActivityStopped(Activity activity) {
        mTrack.logEvent(concat(LEAVE, screenName(activity)));
    }

    @Override
    public void onActivitySaveInstanceState(Activity activity, Bundle bundle) {}

    @Override
    public void onActivityDestroyed(Activity activity) {}
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/track/AmplitudeTrackHandler.java:
--------------------------------------------------------------------------------
package io.whz.synapse.track;

import android.content.Context;
import android.os.Bundle;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;

import com.amplitude.api.Amplitude;
import com.amplitude.api.AmplitudeClient;

import org.greenrobot.eventbus.Subscribe;
import org.greenrobot.eventbus.ThreadMode;
import org.json.JSONException;
import org.json.JSONObject;

import io.whz.synapse.BuildConfig;
import io.whz.synapse.pojo.event.TrackEvent;

/** Forwards tracking events to Amplitude, converting the event bundle to JSON. */
class AmplitudeTrackHandler extends AbsTrackHandler {
    private static final String ORIGINAL_OBJECT = "AmplitudeTrackHandler:OriginalObject";

    private final AmplitudeClient mClient;

    AmplitudeTrackHandler(@NonNull Context context) {
        super(context);

        mClient = Amplitude.getInstance();
        mClient.initialize(context, BuildConfig.AMPLITUDE_ID);
    }

    @Subscribe(threadMode = ThreadMode.BACKGROUND)
    @Override
    public void onTrackEvent(TrackEvent event) {
        mClient.logEvent(event.id, toJsonObject(event.bundle));
    }

    /**
     * Copies every bundle entry into a JSONObject. If a value is rejected, the loop stops and
     * the whole bundle is stored under {@link #ORIGINAL_OBJECT} as a fallback.
     *
     * @return null when {@code bundle} is null
     */
    @Nullable
    private JSONObject toJsonObject(@Nullable Bundle bundle) {
        if (bundle == null) {
            return null;
        }

        final JSONObject object = new JSONObject();

        try {
            for (String key : bundle.keySet()) {
                object.put(key, bundle.get(key));
            }
        } catch (JSONException e) {
            e.printStackTrace();

            try {
                object.put(ORIGINAL_OBJECT, bundle);
            } catch (JSONException e1) {
                e1.printStackTrace();
            }
        }

        return object;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/track/DebugTrackHandler.java:
--------------------------------------------------------------------------------
package io.whz.synapse.track;

import android.content.Context;
import android.support.annotation.NonNull;
import android.util.Log;

import org.greenrobot.eventbus.Subscribe;
import org.greenrobot.eventbus.ThreadMode;

import io.whz.synapse.pojo.event.TrackEvent;

/** Writes tracking events to logcat instead of sending them anywhere. */
class DebugTrackHandler extends AbsTrackHandler {
    private static final String TAG = "DebugTrackHandler";

    DebugTrackHandler(@NonNull Context context) {
        super(context);
    }

    @Subscribe(threadMode = ThreadMode.BACKGROUND)
    @Override
    public void onTrackEvent(TrackEvent event) {
        // Separate id and payload; previously they were concatenated as "ID: xEvent: ...".
        Log.i(TAG, "ID: " + event.id +
                (event.bundle == null ? "" : ", Event: " + event.bundle));
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/track/ExceptionHelper.java:
--------------------------------------------------------------------------------
package io.whz.synapse.track;

import android.support.annotation.NonNull;

import io.whz.synapse.BuildConfig;

import static io.whz.synapse.pojo.constant.TrackCons.APP.CAUGHT;
import static io.whz.synapse.pojo.constant.TrackCons.Key.MSG;

/**
 * Reports caught exceptions. When tracking is enabled it logs a CAUGHT event to the
 * {@link Tracker}; otherwise it prints the stack trace to stderr.
 */
public class ExceptionHelper {
    private static final boolean sEnable = BuildConfig.TRACK_ENABLE;
    private final Tracker mTracker;

    private ExceptionHelper() {
        mTracker = Tracker.getInstance();
    }

    public void caught(@NonNull Throwable e) {
        if (sEnable) {
            mTracker.event(CAUGHT)
                    .put(MSG, toString(e))
                    .log();
        } else {
            e.printStackTrace();
        }
    }

    /**
     * Renders {@code e} as its {@code toString()} (type and message) on the first line,
     * followed by one stack frame per line. The type and message used to be dropped, leaving
     * only the frames.
     */
    private String toString(@NonNull Throwable e) {
        final StringBuilder builder = new StringBuilder()
                .append(e.toString())
                .append("\n");

        // Throwable#getStackTrace never returns null; it may be empty.
        for (StackTraceElement element : e.getStackTrace()) {
            builder.append(element.toString())
                    .append("\n");
        }

        return builder.toString();
    }

    public static ExceptionHelper getInstance() {
        return Holder.sInstance;
    }

    // Holder idiom: the instance is created on first access to Holder.
    private interface Holder {
        ExceptionHelper sInstance = new ExceptionHelper();
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/track/FirebaseTrackHandler.java:
--------------------------------------------------------------------------------
package io.whz.synapse.track;

import android.content.Context;
import android.support.annotation.NonNull;

import com.google.firebase.analytics.FirebaseAnalytics;

import org.greenrobot.eventbus.Subscribe;
import org.greenrobot.eventbus.ThreadMode;

import io.whz.synapse.pojo.event.TrackEvent;

/** Forwards tracking events, with their bundle, to Firebase Analytics. */
class FirebaseTrackHandler extends AbsTrackHandler {

    private final FirebaseAnalytics mAnalytics;

    FirebaseTrackHandler(@NonNull Context context) {
        super(context);

        mAnalytics = FirebaseAnalytics.getInstance(context);
    }

    @Subscribe(threadMode = ThreadMode.BACKGROUND)
    @Override
    public void onTrackEvent(TrackEvent event) {
        mAnalytics.logEvent(event.id, event.bundle);
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/track/ITracker.java:
--------------------------------------------------------------------------------
package io.whz.synapse.track;

import android.support.annotation.NonNull;

import io.whz.synapse.pojo.event.TrackEvent;

/** Sink for tracking events. */
interface ITracker {
    void logEvent(@NonNull TrackEvent event);
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/track/TimeHelper.java:
--------------------------------------------------------------------------------
package io.whz.synapse.track;

import android.support.annotation.NonNull;
import android.support.v4.util.ArrayMap;
import android.text.TextUtils;

import java.util.Map;

/** Measures wall-clock durations between start(id) and stop(id). */
public class TimeHelper {
    // id -> start time from System.currentTimeMillis().
    private final Map<String, Long> mMap;

    private TimeHelper() {
        mMap = new ArrayMap<>();
    }

    /** Records the start time for {@code id}; empty or already-running ids are reported and ignored. */
    public void start(@NonNull String id) {
        if (TextUtils.isEmpty(id) || mMap.containsKey(id)) {
            new IllegalArgumentException("Illegal id")
                    .printStackTrace();

            return;
        }

        mMap.put(id, System.currentTimeMillis());
    }

    /**
     * Ends the timer for {@code id}.
     *
     * @return elapsed milliseconds, or 0 when {@code id} is empty or was never started
     */
    public long stop(@NonNull String id) {
        if (TextUtils.isEmpty(id) || !mMap.containsKey(id)) {
            new IllegalArgumentException("Illegal id")
                    .printStackTrace();

            return 0L;
        }

        final long startTime = mMap.remove(id);

        return System.currentTimeMillis() - startTime;
    }

    public static TimeHelper getInstance() {
        return Holder.sInstance;
    }

    private interface Holder {
        TimeHelper sInstance = new TimeHelper();
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/track/Tracker.java:
--------------------------------------------------------------------------------
package io.whz.synapse.track;

import android.app.Application;
import android.content.Context;
import android.os.Bundle;
import android.support.annotation.NonNull;

import org.greenrobot.eventbus.EventBus;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;

import io.whz.synapse.BuildConfig;
import io.whz.synapse.pojo.event.TrackEvent;
import io.whz.synapse.util.Precondition;

/**
 * Singleton entry point for analytics. Events are posted to an {@link EventBus}, where the
 * {@link AbsTrackHandler}s registered by {@link #initialize} consume them.
 */
public class Tracker implements ITracker {
    private final Set<AbsTrackHandler> mTracks = new HashSet<>();

    private EventBus mBus;

    /**
     * Creates the handlers for this build and registers them on {@code bus}. With
     * TRACK_ENABLE these are Firebase and Amplitude; otherwise a logcat handler is used.
     * When {@code context} is the Application, this also installs enter/leave activity
     * tracking.
     */
    public void initialize(@NonNull Context context, @NonNull EventBus bus) {
        mBus = Precondition.checkNotNull(bus);

        if (BuildConfig.TRACK_ENABLE) {
            mTracks.add(new FirebaseTrackHandler(context));
            mTracks.add(new AmplitudeTrackHandler(context));
        } else {
            mTracks.add(new DebugTrackHandler(context));
        }

        for (AbsTrackHandler track : mTracks) {
            track.register(bus);
        }

        if (context instanceof Application) {
            ((Application) context).registerActivityLifecycleCallbacks(new ActivityLifecycleTracker(this));
        }
    }

    /** Posts {@code event} to the bus; before initialize() the event is reported and dropped. */
    @Override
    public void logEvent(@NonNull TrackEvent event) {
        if (mBus == null) {
            new NullPointerException("EvenBus is null, please initialize first")
                    .printStackTrace();
            return;
        }

        mBus.post(event);
    }

    /** Logs an event that has no parameters. */
    public void logEvent(@NonNull String id) {
        Precondition.checkNotNull(id);

        logEvent(new TrackEvent(id, null));
    }

    /** Starts building an event with parameters; finish with {@link EventBuilder#log()}. */
    public EventBuilder event(@NonNull String id) {
        Precondition.checkNotNull(id);

        return new EventBuilder(id);
    }

    public static Tracker getInstance() {
        return Holder.sInstance;
    }

    private interface Holder {
        Tracker sInstance = new Tracker();
    }

    /** Fluent builder that collects event parameters in a Bundle. */
    public static class EventBuilder {
        private final Bundle bundle = new Bundle();
        private final String id;

        EventBuilder(@NonNull String id) {
            this.id = id;
        }

        public EventBuilder put(@NonNull String key, boolean value) {
            bundle.putBoolean(key, value);

            return this;
        }

        public EventBuilder put(@NonNull String key, int value) {
            bundle.putInt(key, value);

            return this;
        }

        public EventBuilder put(@NonNull String key, double value) {
            bundle.putDouble(key, value);
            return this;
        }

        public EventBuilder put(@NonNull String key, String value) {
            bundle.putString(key, value);

            return this;
        }

        public EventBuilder put(@NonNull String key, Serializable value) {
            bundle.putSerializable(key, value);

            return this;
        }

        /** Sends the built event through the singleton Tracker. */
        public void log() {
            Holder.sInstance.logEvent(new TrackEvent(id, bundle));
        }
    }
}
--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/transition/GravityArcMotion.java:
--------------------------------------------------------------------------------
/*
 * Copyright 2016 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.whz.synapse.transition;

import android.graphics.Path;
import android.transition.ArcMotion;

/**
 * Thanks https://github.com/nickbutcher/plaid
 * A tweak to {@link ArcMotion} which slightly alters the path calculation. In the real world
 * gravity slows upward motion and accelerates downward motion. This class emulates this behavior
 * to make motion paths appear more natural.
 * <p>
 * See https://www.google.com/design/spec/motion/movement.html#movement-movement-within-screen-bounds
 */
class GravityArcMotion extends ArcMotion {
    private static final float DEFAULT_MIN_ANGLE_DEGREES = 0;
    private static final float DEFAULT_MAX_ANGLE_DEGREES = 70;
    private static final float DEFAULT_MAX_TANGENT = (float)
            Math.tan(Math.toRadians(DEFAULT_MAX_ANGLE_DEGREES/2));

    private float mMinimumHorizontalAngle = DEFAULT_MIN_ANGLE_DEGREES;
    private float mMinimumVerticalAngle = DEFAULT_MIN_ANGLE_DEGREES;
    private float mMaximumAngle = DEFAULT_MAX_ANGLE_DEGREES;
    // tan(DEFAULT_MIN_ANGLE_DEGREES / 2) == tan(0) == 0
    private float mMinimumHorizontalTangent = 0;
    private float mMinimumVerticalTangent = 0;
    private float mMaximumTangent = DEFAULT_MAX_TANGENT;

    GravityArcMotion() {}

    /**
     * {@inheritDoc}
     */
    @Override
    public void setMinimumHorizontalAngle(float angleInDegrees) {
        mMinimumHorizontalAngle = angleInDegrees;
        mMinimumHorizontalTangent = toTangent(angleInDegrees);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public float getMinimumHorizontalAngle() {
        return mMinimumHorizontalAngle;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void setMinimumVerticalAngle(float angleInDegrees) {
        mMinimumVerticalAngle = angleInDegrees;
        mMinimumVerticalTangent = toTangent(angleInDegrees);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public float getMinimumVerticalAngle() {
        return mMinimumVerticalAngle;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void setMaximumAngle(float angleInDegrees) {
        mMaximumAngle = angleInDegrees;
        mMaximumTangent = toTangent(angleInDegrees);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public float getMaximumAngle() {
        return mMaximumAngle;
    }

    /**
     * Returns tan(arcInDegrees / 2).
     *
     * @throws IllegalArgumentException if the arc is outside [0, 90] degrees
     */
    private static float toTangent(float arcInDegrees) {
        if (arcInDegrees < 0 || arcInDegrees > 90) {
            throw new IllegalArgumentException("Arc must be between 0 and 90 degrees");
        }
        return (float) Math.tan(Math.toRadians(arcInDegrees / 2));
    }

    /**
     * Builds a cubic Bezier path from (startX, startY) to (endX, endY). The curve bends
     * through a point e that is clamped by the configured minimum and maximum angles.
     */
    @Override
    public Path getPath(float startX, float startY, float endX, float endY) {
        // Here's a little ascii art to show how this is calculated:
        // c---------- b
        //  \        / |
        //    \     d  |
        //      \  /   e
        //        a----f
        // This diagram assumes that the horizontal distance is less than the vertical
        // distance between the start point (a) and end point (b).
        // d is the midpoint between a and b.
        // This path is formed by assuming that start and end points are in
        // an arc on a circle. The end point is centered in the circle vertically
        // and start is a point on the circle.

        // Triangles bfa and bde form similar right triangles. The control points
        // for the cubic Bezier arc path are the midpoints between a and e and e and b.

        Path path = new Path();
        path.moveTo(startX, startY);

        float ex;
        float ey;
        if (startY == endY) {
            ex = (startX + endX) / 2;
            ey = startY + mMinimumHorizontalTangent * Math.abs(endX - startX) / 2;
        } else if (startX == endX) {
            ex = startX + mMinimumVerticalTangent * Math.abs(endY - startY) / 2;
            ey = (startY + endY) / 2;
        } else {
            float deltaX = endX - startX;

            // This is the only change to ArcMotion: deltaY is always a positive distance.
            // So in the vertical-dominant branch below, eDistY > 0 and e always lies at a
            // larger y than b, whichever direction the motion goes.
            float deltaY = Math.abs(endY - startY);

            // hypotenuse squared.
            float h2 = deltaX * deltaX + deltaY * deltaY;

            // Midpoint between start and end
            float dx = (startX + endX) / 2;
            float dy = (startY + endY) / 2;

            // Distance squared between end point and mid point is (1/2 hypotenuse)^2
            float midDist2 = h2 * 0.25f;

            float minimumArcDist2 = 0;

            if (Math.abs(deltaX) < Math.abs(deltaY)) {
                // Similar triangles bfa and bde mean that (ab/fb = eb/bd)
                // Therefore, eb = ab * bd / fb
                // ab = hypotenuse
                // bd = hypotenuse/2
                // fb = deltaY
                float eDistY = h2 / (2 * deltaY);
                ey = endY + eDistY;
                ex = endX;

                minimumArcDist2 = midDist2 * mMinimumVerticalTangent
                        * mMinimumVerticalTangent;
            } else {
                // Same as above, but flip X & Y
                float eDistX = h2 / (2 * deltaX);
                ex = endX + eDistX;
                ey = endY;

                minimumArcDist2 = midDist2 * mMinimumHorizontalTangent
                        * mMinimumHorizontalTangent;
            }
            float arcDistX = dx - ex;
            float arcDistY = dy - ey;
            float arcDist2 = arcDistX * arcDistX + arcDistY * arcDistY;

            float maximumArcDist2 = midDist2 * mMaximumTangent * mMaximumTangent;

            // If |de|^2 falls outside [minimumArcDist2, maximumArcDist2], scale e along
            // the line from d so that it lands on the violated bound.
            float newArcDistance2 = 0;
            if (arcDist2 < minimumArcDist2) {
                newArcDistance2 = minimumArcDist2;
            } else if (arcDist2 > maximumArcDist2) {
                newArcDistance2 = maximumArcDist2;
            }
            if (newArcDistance2 != 0) {
                float ratio2 = newArcDistance2 / arcDist2;
                float ratio = (float) Math.sqrt(ratio2);
                ex = dx + (ratio * (ex - dx));
                ey = dy + (ratio * (ey - dy));
            }
        }
        float controlX1 = (startX + ex) / 2;
        float controlY1 = (startY + ey) / 2;
        float controlX2 = (ex + endX) / 2;
        float controlY2 = (ey + endY) / 2;
        path.cubicTo(controlX1, controlY1, controlX2, controlY2, endX, endY);
        return path;
    }

}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/util/DbHelper.java:
--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/util/DbHelper.java:
--------------------------------------------------------------------------------
package io.whz.synapse.util;

import android.support.annotation.NonNull;
import android.support.annotation.Nullable;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.IntBuffer;

import io.whz.synapse.matrix.Matrix;
import io.whz.synapse.pojo.dao.DBModel;
import io.whz.synapse.pojo.neural.Model;

/**
 * Converts between the in-memory {@link Model} and its persisted {@link DBModel} form.
 * Array-valued fields are flattened to byte arrays with {@link ByteBuffer}, whose
 * default byte order is big-endian (no explicit order is set here).
 */
public class DbHelper {
    /**
     * Builds a new {@link Model} from a persisted {@link DBModel}.
     *
     * @param dbModel the stored record; must not be null
     * @return a new model whose fields are copied/decoded from {@code dbModel}
     */
    public static Model dbModel2Model(@NonNull DBModel dbModel) {
        Precondition.checkNotNull(dbModel);

        final Model model = new Model();

        model.setId(dbModel.getId());
        model.setName(dbModel.getName());
        model.setCreatedTime(dbModel.getCreatedTime());
        model.setLearningRate(dbModel.getLearningRate());
        model.setEpochs(dbModel.getEpochs());
        // model2DBModel does not persist a step epoch, so the stored epoch count is reused
        // (presumably a persisted model is one that completed all epochs — confirm).
        model.setStepEpoch(dbModel.getEpochs());
        model.setDataSize(dbModel.getDataSize());
        model.setTimeUsed(dbModel.getTimeUsed());
        model.setEvaluate(dbModel.getEvaluate());
        model.setHiddenSizes(byteArray2IntArray(dbModel.getHiddenSizeBytes()));
        model.setAccuracies(byteArray2DoubleArray(dbModel.getAccuracyBytes()));
        model.setBiases(byteArray2MatrixArray(dbModel.getBiasBytes()));
        model.setWeights(byteArray2MatrixArray(dbModel.getWeightBytes()));

        return model;
    }

    /**
     * Builds a new {@link DBModel} ready to be persisted from an in-memory {@link Model}.
     * The model's step epoch is not stored.
     *
     * @param model the model to convert; must not be null
     * @return a new record whose array fields are encoded as byte arrays
     */
    public static DBModel model2DBModel(@NonNull Model model) {
        Precondition.checkNotNull(model);

        final DBModel dbModel = new DBModel();

        dbModel.setId(model.getId());
        dbModel.setName(model.getName());
        dbModel.setCreatedTime(model.getCreatedTime());
        dbModel.setLearningRate(model.getLearningRate());
        dbModel.setEpochs(model.getEpochs());
        dbModel.setDataSize(model.getDataSize());
        dbModel.setTimeUsed(model.getTimeUsed());
        dbModel.setEvaluate(model.getEvaluate());
        dbModel.setHiddenSizeBytes(convert2ByteArray(model.getHiddenSizes()));
        dbModel.setAccuracyBytes(convert2ByteArray(model.getAccuracies()));
        dbModel.setBiasBytes(convert2ByteArray(model.getBiases()));
        dbModel.setWeightBytes(convert2ByteArray(model.getWeights()));

        return dbModel;
    }

    /**
     * Encodes ints as consecutive 4-byte values.
     *
     * @return the encoded bytes, or null if {@code array} is null or encoding fails
     */
    @Nullable
    private static byte[] convert2ByteArray(int... array) {
        if (array == null) {
            return null;
        }

        ByteBuffer buffer = null;

        try {
            buffer = ByteBuffer.allocate(array.length << 2);

            for (int i : array) {
                buffer.putInt(i);
            }
        } catch (Exception e) {
            e.printStackTrace();
            buffer = null;
        }

        return buffer == null ? null : buffer.array();
    }

    /**
     * Encodes doubles as consecutive 8-byte values.
     *
     * @return the encoded bytes, or null if {@code array} is null or encoding fails
     */
    @Nullable
    private static byte[] convert2ByteArray(double... array) {
        if (array == null) {
            return null;
        }

        ByteBuffer buffer = null;

        try {
            buffer = ByteBuffer.allocate(array.length << 3);

            for (double i : array) {
                buffer.putDouble(i);
            }
        } catch (Exception e) {
            e.printStackTrace();

            buffer = null;
        }

        return buffer == null ? null : buffer.array();
    }

    /**
     * Encodes matrices as: int count, then for each matrix int row, int col and every
     * value of {@code getArray()} as an 8-byte double.
     *
     * @return the encoded bytes, or null if {@code matrices} is null or encoding fails
     *         (including a null element)
     */
    @Nullable
    private static byte[] convert2ByteArray(Matrix... matrices) {
        if (matrices == null) {
            return null;
        }

        try {
            // 4 bytes for the matrix count, then each matrix's header + payload.
            int size = 4;

            for (Matrix matrix : matrices) {
                size += calMatrixLen(matrix);
            }

            final ByteBuffer buffer = ByteBuffer.allocate(size);

            buffer.putInt(matrices.length);

            for (Matrix matrix : matrices) {
                buffer.putInt(matrix.getRow());
                buffer.putInt(matrix.getCol());

                for (double d : matrix.getArray()) {
                    buffer.putDouble(d);
                }
            }

            return buffer.array();
        } catch (Exception e) {
            // Match the int/double encoders: report failure as null instead of handing
            // back a partially written (and therefore undecodable) blob.
            e.printStackTrace();

            return null;
        }
    }

    /** Bytes needed to encode one matrix: two int dimensions plus 8 bytes per value. */
    private static int calMatrixLen(@NonNull Matrix matrix) {
        return 8 + (matrix.getArray().length << 3);
    }

    /**
     * Decodes the layout written by {@link #convert2ByteArray(Matrix...)}.
     * NOTE(review): assumes each matrix's getArray().length equals row * col — confirm.
     *
     * @return the decoded matrices, or null if {@code array} is null or malformed
     */
    @Nullable
    private static Matrix[] byteArray2MatrixArray(byte... array) {
        if (array == null) {
            return null;
        }

        Matrix[] res = null;

        try {
            final ByteBuffer buffer = ByteBuffer.wrap(array);

            final int len = buffer.getInt();
            res = new Matrix[len];

            for (int i = 0; i < len; ++i) {
                final int row = buffer.getInt();
                final int col = buffer.getInt();
                final double[] doubles = new double[row * col];

                for (int j = 0, jLen = doubles.length; j < jLen; ++j) {
                    doubles[j] = buffer.getDouble();
                }

                res[i] = Matrix.array(doubles, row);
            }
        } catch (Exception e) {
            e.printStackTrace();

            res = null;
        }

        return res;
    }

    /**
     * Decodes consecutive 4-byte ints; trailing bytes that do not fill a whole int are ignored.
     *
     * @return the decoded ints, or null if {@code array} is null
     */
    @Nullable
    private static int[] byteArray2IntArray(byte... array) {
        if (array == null) {
            return null;
        }

        final int[] res = new int[array.length >> 2];

        try {
            final IntBuffer buffer = ByteBuffer.wrap(array).asIntBuffer();

            for (int i = 0, iLen = res.length; i < iLen; ++i) {
                res[i] = buffer.get();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        return res;
    }

    /**
     * Decodes consecutive 8-byte doubles; trailing bytes that do not fill a whole double
     * are ignored.
     *
     * @return the decoded doubles, or null if {@code array} is null
     */
    @Nullable
    private static double[] byteArray2DoubleArray(byte... array) {
        if (array == null) {
            return null;
        }

        final double[] res = new double[array.length >> 3];

        try {
            final DoubleBuffer buffer = ByteBuffer.wrap(array).asDoubleBuffer();

            for (int i = 0, iLen = res.length; i < iLen; ++i) {
                res[i] = buffer.get();
            }
        } catch (Exception e) {
            e.printStackTrace();
        }

        return res;
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/util/FileUtil.java:
--------------------------------------------------------------------------------
package io.whz.synapse.util;

import android.support.annotation.NonNull;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class FileUtil {

    /**
     * Recursively deletes each given file or directory that currently exists, using one
     * {@code rm -rf} child process and waiting for it to finish. Failures are logged,
     * not thrown.
     */
    public static void clear(@NonNull File... files) {
        final List<String> command = new ArrayList<>();

        command.add("rm");
        command.add("-rf");

        for (File file : files) {
            if (file.exists()) {
                command.add(file.getAbsolutePath());
            }
        }

        if (command.size() == 2) {
            return; // nothing exists, nothing to delete
        }

        // Pass each path as its own argv entry. Runtime.exec(String) splits its
        // argument on whitespace, so a path containing a space would be broken into
        // separate (possibly unrelated) targets.
        try {
            final Process process = Runtime.getRuntime()
                    .exec(command.toArray(new String[command.size()]));
            process.waitFor();
        } catch (IOException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            e.printStackTrace();
            Thread.currentThread().interrupt();
        }
    }

    /** Creates each given directory (and missing parents) if it does not exist yet. */
    public static void makeDirs(@NonNull File... dirs) {
        for (File dir : dirs) {
            if (!dir.exists()) {
                dir.mkdirs();
            }
        }
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/util/MatrixUtil.java:
--------------------------------------------------------------------------------
package io.whz.synapse.util;

import android.support.annotation.NonNull;
import android.support.v4.util.Pair;

import io.whz.synapse.matrix.Matrix;

/** Helpers operating on {@link Matrix} values and arrays of them. */
public class MatrixUtil {
    /**
     * Returns a new array holding {@code Matrix.zeroLike(m)} for each input matrix,
     * in the same order.
     */
    public static Matrix[] zerosLike(@NonNull Matrix[] matrices) {
        Precondition.checkNotNull(matrices);

        final int len = matrices.length;
        final Matrix[] res = new Matrix[len];

        for (int i = 0; i < len; ++i) {
            res[i] = Matrix.zeroLike(matrices[i]);
        }

        return res;
    }

    /**
     * Returns {@code Matrix.randn(rows[i], cols[i])} for every index i.
     * {@code rows} and {@code cols} must have equal length (checked via
     * Precondition.checkArgument).
     */
    public static Matrix[] randns(int[] rows, int[] cols) {
        final int len;
        Precondition.checkArgument((len = rows.length) == cols.length);

        final Matrix[] matrices = new Matrix[len];

        for (int i = 0; i < len; ++i) {
            matrices[i] = Matrix.randn(rows[i], cols[i]);
        }

        return matrices;
    }

    /**
     * Returns the index of the first value exactly equal to 1 (e.g. in a one-hot label).
     *
     * @throws IllegalStateException if no value equals 1
     */
    public static int index(Matrix matrix) {
        final double[] doubles = matrix.getArray();

        for (int i = 0, len = doubles.length; i < len; ++i) {
            if (doubles[i] == 1) {
                return i;
            }
        }

        throw new IllegalStateException("Can not find 1 in matrix");
    }

    /** Returns the index of the largest value; ties resolve to the first. Empty → 0. */
    public static int argmax(Matrix matrix) {
        final double[] doubles = matrix.getArray();

        int index = 0;

        for (int i = 1, len = doubles.length; i < len; ++i) {
            if (doubles[i] > doubles[index]) {
                index = i;
            }
        }

        return index;
    }

    /** Returns (argmax index, value at that index). */
    public static Pair<Integer, Double> findMax(Matrix matrix) {
        final int index = argmax(matrix);

        return Pair.create(index, matrix.getArray()[index]);
    }
}
--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/util/ProcessUtil.java:
--------------------------------------------------------------------------------
package io.whz.synapse.util;

import java.io.Closeable;
import java.io.IOException;

public class ProcessUtil {
    /**
     * Runs {@code command} through {@link Runtime#exec(String)} and blocks until it exits.
     * The command string is split on whitespace; no shell quoting is applied.
     *
     * @return true once the process was started and has exited (its exit code is not
     *         inspected); false if it could not be started or the wait was interrupted
     */
    public static boolean execCommand(String command) {
        Process process = null;

        try {
            process = Runtime.getRuntime().exec(command);
            process.waitFor();

            return true;
        } catch (IOException e) {
            e.printStackTrace();

            return false;
        } catch (InterruptedException e) {
            e.printStackTrace();
            // Keep the interrupt visible to callers further up the stack.
            Thread.currentThread().interrupt();

            return false;
        } finally {
            if (process != null) {
                // Release the pipe file descriptors held by the child's standard streams.
                closeQuietly(process.getOutputStream());
                closeQuietly(process.getInputStream());
                closeQuietly(process.getErrorStream());
            }
        }
    }

    /** Closes {@code closeable}, logging instead of propagating any failure. */
    private static void closeQuietly(Closeable closeable) {
        try {
            closeable.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/util/StringFormatUtil.java:
--------------------------------------------------------------------------------
package io.whz.synapse.util;

import android.support.annotation.NonNull;

import java.util.Locale;

import io.whz.synapse.neural.NeuralNetwork;

/** Display formatting helpers (uses the default locale). */
public class StringFormatUtil {
    /** Formats a duration in milliseconds as HH:MM:SS; hours are not wrapped at 24. */
    public static String formatTimeUsed(long timeUsed) {
        return String.format(Locale.getDefault(), "%02d:%02d:%02d",
                timeUsed / 3600000, timeUsed / 60000 % 60, timeUsed / 1000 % 60);
    }

    /** Formats a fraction as a percentage with two decimals, e.g. 0.1234 → "12.34%". */
    public static String formatPercent(double value) {
        return String.format(Locale.getDefault(),
                "%.2f%%", value * 100);
    }

    /**
     * Renders the full layer layout as "input × hidden... × output", using the
     * network's input/output layer sizes around the given hidden sizes.
     */
    public static String formatLayerSizes(@NonNull int[] hiddenSizes) {
        final StringBuilder builder = new StringBuilder();
        final String separator = " × ";

        builder.append(NeuralNetwork.INPUT_LAYER_NUMBER)
                .append(separator);

        for (int size : hiddenSizes) {
            builder.append(size)
                    .append(separator);
        }

        builder.append(NeuralNetwork.OUTPUT_LAYER_NUMBER);

        return builder.toString();
    }
}
--------------------------------------------------------------------------------
/app/src/main/java/io/whz/synapse/util/Versatile.java:
--------------------------------------------------------------------------------
package io.whz.synapse.util;

import android.app.Activity;
import android.content.Context;
import android.content.res.AssetManager;
import android.support.annotation.NonNull;
import android.support.annotation.Nullable;
import android.transition.Transition;
import android.transition.TransitionManager;
import android.util.ArrayMap;
import android.view.View;
import android.view.ViewGroup;

import java.io.Closeable;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.ref.WeakReference;
import java.lang.reflect.Field;
import java.util.ArrayList;

import io.whz.synapse.element.Global;
import io.whz.synapse.pojo.neural.Model;

public class Versatile {
    private static final String DEMO_MODEL_FILE = "demo.model";

    /**
     * Solve TransitionManager leak problem: removes this activity's decor view from
     * TransitionManager's private {@code sRunningTransitions} map for the calling thread
     * (the field is a ThreadLocal). Reflection failures are logged and ignored.
     * NOTE(review): newer Android versions may restrict reflective access to this
     * hidden field — in that case this is a no-op.
     */
    public static void removeActivityFromTransitionManager(Activity activity) {
        try {
            final Field runningTransitionsField =
                    TransitionManager.class.getDeclaredField("sRunningTransitions");

            runningTransitionsField.setAccessible(true);

            // Static field: the receiver argument of Field.get is ignored.
            @SuppressWarnings("unchecked")
            final ThreadLocal<WeakReference<ArrayMap<ViewGroup, ArrayList<Transition>>>> runningTransitions =
                    (ThreadLocal<WeakReference<ArrayMap<ViewGroup, ArrayList<Transition>>>>)
                            runningTransitionsField.get(null);

            if (runningTransitions == null) {
                return;
            }

            final WeakReference<ArrayMap<ViewGroup, ArrayList<Transition>>> reference =
                    runningTransitions.get();

            if (reference == null) {
                return;
            }

            // Read the weak reference once so it cannot be cleared between checks.
            final ArrayMap<ViewGroup, ArrayList<Transition>> map = reference.get();

            if (map == null) {
                return;
            }

            final View decorView = activity.getWindow().getDecorView();

            // remove() is a no-op when the key is absent, so no containsKey check is needed.
            map.remove(decorView);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Serializes {@code model} with Java serialization to {@code demo.model} under the
     * app's root dir (from Global), replacing any existing file. Errors are logged.
     */
    @SuppressWarnings("unused")
    public static void writeModel2File(@NonNull Model model) {
        FileOutputStream outputStream = null;
        ObjectOutputStream objectOutputStream = null;

        try {
            final File file = new File(Global.getInstance().getDirs().root, DEMO_MODEL_FILE);

            if (file.exists()) {
                file.delete();
            }

            outputStream = new FileOutputStream(file);
            objectOutputStream = new ObjectOutputStream(outputStream);

            objectOutputStream.writeObject(model);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Closing the object stream also closes the file stream; close the file
            // stream directly only when the object stream was never created.
            closeQuietly(objectOutputStream != null ? objectOutputStream : outputStream);
        }
    }

    /**
     * Deserializes the bundled {@code demo.model} asset.
     * Java deserialization is acceptable here only because the asset ships with the APK;
     * do not reuse this for untrusted input.
     *
     * @return the model, or null if the context has no AssetManager
     * @throws IOException            if the asset cannot be opened or read
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    @Nullable
    public static Model readModelFromAssert(@NonNull Context context) throws IOException, ClassNotFoundException {
        Precondition.checkNotNull(context);

        final AssetManager manager = context.getAssets();

        if (manager == null) {
            return null;
        }

        InputStream inputStream = null;
        ObjectInputStream objectInputStream = null;

        try {
            inputStream = manager.open(DEMO_MODEL_FILE);
            objectInputStream = new ObjectInputStream(inputStream);

            return (Model) objectInputStream.readObject();
        } finally {
            // If the ObjectInputStream constructor failed, the raw asset stream still
            // needs closing.
            closeQuietly(objectInputStream != null ? objectInputStream : inputStream);
        }
    }

    /** Closes {@code closeable} if non-null, logging instead of propagating failures. */
    private static void closeQuietly(@Nullable Closeable closeable) {
        if (closeable == null) {
            return;
        }

        try {
            closeable.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

--------------------------------------------------------------------------------
/app/src/main/res/anim/item_animation_from_bottom.xml:
-------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 10 | 11 | 16 | 17 | -------------------------------------------------------------------------------- /app/src/main/res/anim/layout_animation_from_bottom.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /app/src/main/res/drawable-nodpi/marker.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huazhouwang/Synapse/cc536d98a284a7c5113e90abecdd1ea67a7531a2/app/src/main/res/drawable-nodpi/marker.webp -------------------------------------------------------------------------------- /app/src/main/res/drawable-xxxhdpi/blue_ripple.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huazhouwang/Synapse/cc536d98a284a7c5113e90abecdd1ea67a7531a2/app/src/main/res/drawable-xxxhdpi/blue_ripple.webp -------------------------------------------------------------------------------- /app/src/main/res/drawable-xxxhdpi/notify_icon.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huazhouwang/Synapse/cc536d98a284a7c5113e90abecdd1ea67a7531a2/app/src/main/res/drawable-xxxhdpi/notify_icon.webp -------------------------------------------------------------------------------- /app/src/main/res/drawable-xxxhdpi/red_paper.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huazhouwang/Synapse/cc536d98a284a7c5113e90abecdd1ea67a7531a2/app/src/main/res/drawable-xxxhdpi/red_paper.webp -------------------------------------------------------------------------------- /app/src/main/res/drawable-xxxhdpi/red_sun.webp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huazhouwang/Synapse/cc536d98a284a7c5113e90abecdd1ea67a7531a2/app/src/main/res/drawable-xxxhdpi/red_sun.webp -------------------------------------------------------------------------------- /app/src/main/res/drawable-xxxhdpi/stack_rectangle.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huazhouwang/Synapse/cc536d98a284a7c5113e90abecdd1ea67a7531a2/app/src/main/res/drawable-xxxhdpi/stack_rectangle.webp -------------------------------------------------------------------------------- /app/src/main/res/drawable/bg_splash.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/bg_white_fillet.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/dialog_background.xml: -------------------------------------------------------------------------------- 1 | 2 | 17 | 18 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_add_white_24dp.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_arrow_forward_24dp.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_block_24dp.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | 
-------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_change_24dp.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_close_24dp.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_favorite_24dp.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_github_code_24.xml: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_play_24dp.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_refresh_24dp.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /app/src/main/res/drawable/ic_share_24dp.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /app/src/main/res/layout/ac_detail_mark_view.xml: -------------------------------------------------------------------------------- 1 | 2 | 11 | 12 | 23 | 24 | -------------------------------------------------------------------------------- /app/src/main/res/layout/activity_main.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 10 | 11 | 23 | 24 | 35 | 36 | -------------------------------------------------------------------------------- /app/src/main/res/layout/activity_play.xml: -------------------------------------------------------------------------------- 1 | 2 | 11 | 12 | 16 | 21 | 27 | 28 | 29 | 34 | 35 | 36 | 40 | 49 | 50 | 51 | 58 | 59 | 67 | 72 | 73 | 74 | 87 | 98 | 99 | 106 | 107 | 117 | 118 | 125 | 126 | 134 | 135 | 136 | 137 | 143 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /app/src/main/res/layout/dialog_about.xml: -------------------------------------------------------------------------------- 1 | 2 | 12 | 13 | 20 | 21 | 28 | 36 | 37 | 46 | 47 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /app/src/main/res/layout/dialog_model_list.xml: -------------------------------------------------------------------------------- 1 | 2 | 13 | 14 | 23 | 24 | 35 | 36 | -------------------------------------------------------------------------------- /app/src/main/res/layout/hidden_layer_input.xml: -------------------------------------------------------------------------------- 1 | 2 | 10 | 14 | 15 | 27 | 28 | 29 | 40 | -------------------------------------------------------------------------------- /app/src/main/res/layout/item_paly.xml: -------------------------------------------------------------------------------- 1 | 2 | 12 | 13 | 18 | 19 | 24 | 25 | 30 | 31 | 37 | 38 | 39 | 40 | 53 | 54 | -------------------------------------------------------------------------------- /app/src/main/res/layout/item_trained.xml: -------------------------------------------------------------------------------- 1 | 2 | 10 | 11 | 17 | 18 | 26 | 27 | 35 | 36 | 45 | 46 | 54 | 55 | 63 | 64 | 65 | 73 | 74 | 82 | 83 | 93 | 94 | -------------------------------------------------------------------------------- 
/app/src/main/res/layout/item_training.xml: -------------------------------------------------------------------------------- 1 | 2 | 10 | 11 | 16 | 17 | 26 | 27 | 35 | 36 | 45 | 46 | 54 | 55 | 56 | 64 | 65 | 73 | 74 | 82 | 83 | 94 | 95 | 96 | 109 | -------------------------------------------------------------------------------- /app/src/main/res/layout/item_welcome.xml: -------------------------------------------------------------------------------- 1 | 2 | 11 | 12 | 17 | 18 | 23 | 24 | 29 | 30 | 36 | 37 | 38 | 39 |