├── .gitignore ├── .jvmopts ├── .scalafmt.conf ├── LICENSE ├── NOTICE ├── README.md ├── build.sbt ├── data └── iris.csv ├── dl4j └── src │ └── main │ ├── resources │ └── logback.xml │ └── scala │ └── io │ └── brunk │ └── examples │ ├── ImageReader.scala │ ├── IrisReader.scala │ ├── dl4j │ ├── IrisMLP.scala │ ├── MnistMLP.scala │ └── SimpleCNN.scala │ └── scalnet │ ├── IrisMLP.scala │ ├── MnistMLP.scala │ └── SimpleCNN.scala ├── mxnet ├── build.sbt ├── project │ ├── build.properties │ └── plugins.sbt └── src │ └── main │ ├── resources │ └── logback.xml │ └── scala │ └── io │ └── brunk │ └── examples │ ├── IrisMLP.scala │ └── MnistMLP.scala ├── project ├── build.properties └── plugins.sbt └── tensorflow ├── example_image.jpg └── src └── main ├── protobuf └── string_int_label_map.proto ├── resources ├── logback.xml └── mscoco_label_map.pbtxt └── scala └── io └── brunk ├── DatasetSplitter.scala └── examples ├── FashionMnistCNN.scala ├── FashionMnistMLP.scala ├── IrisMLP.scala ├── MnistMLP.scala ├── ObjectDetector.scala ├── SimpleCNN.scala └── SimpleCNNModels.scala /.gitignore: -------------------------------------------------------------------------------- 1 | /temp 2 | 3 | # sbt 4 | lib_managed 5 | project/project 6 | target 7 | 8 | # Worksheets (Eclipse or IntelliJ) 9 | *.sc 10 | 11 | # Eclipse 12 | .cache* 13 | .classpath 14 | .project 15 | .scala_dependencies 16 | .settings 17 | .target 18 | .worksheet 19 | 20 | # IntelliJ 21 | .idea 22 | 23 | # ENSIME 24 | .ensime 25 | .ensime_lucene 26 | .ensime_cache 27 | 28 | # Mac 29 | .DS_Store 30 | 31 | # Akka 32 | ddata* 33 | journal 34 | snapshots 35 | 36 | # Log files 37 | *.log 38 | -------------------------------------------------------------------------------- /.jvmopts: -------------------------------------------------------------------------------- 1 | -Dfile.encoding=UTF8 2 | -Xms1G 3 | -Xmx6G 4 | -Xms6G 5 | -Xss2M 6 | -XX:ReservedCodeCacheSize=256m 7 | -XX:MaxMetaspaceSize=512m 8 | -XX:+TieredCompilation 9 | -XX:-UseGCOverheadLimit 10 | -XX:+CMSClassUnloadingEnabled 11 | -XX:+UseConcMarkSweepGC 12 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | style = defaultWithAlign 2 | 3 | danglingParentheses = true 4 | indentOperator = spray 5 | maxColumn = 100 6 | project.excludeFilters = [".*\\.sbt"] 7 | rewrite.rules = [AsciiSortImports, RedundantBraces, RedundantParens] 8 | spaces.inImportCurlyBraces = true 9 | unindentTopLevelOperators = true 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2017 Sören Brunk 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scala-deeplearn-examples # 2 | 3 | Welcome to scala-deeplearn-examples! 4 | 5 | This repository contains a list of examples used in my blog series on deep learning with Scala https://brunk.io 6 | 7 | You can clone the repository and run the examples using SBT. 8 | 9 | ## Contribution policy ## 10 | 11 | Contributions via GitHub pull requests are gladly accepted from their original author. Along with 12 | any pull requests, please state that the contribution is your original work and that you license 13 | the work to the project under the project's open source license. Whether or not you state this 14 | explicitly, by submitting any copyrighted material via pull request, email, or other means you 15 | agree to license the material under the project's open source license and warrant that you have the 16 | legal authority to do so. 17 | 18 | ## License ## 19 | 20 | This code is open source software licensed under the 21 | [Apache-2.0](http://www.apache.org/licenses/LICENSE-2.0) license. 22 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | // ***************************************************************************** 2 | // Projects 3 | // ***************************************************************************** 4 | 5 | // The MXNet example has been moved into its own sbt project for now because we have to build mxnet manually, 6 | // and we don't want to break dependency resolution for the other projects. 
7 | // lazy val mxnet = project 8 | 9 | lazy val dl4j = 10 | project 11 | .in(file("dl4j")) 12 | .enablePlugins(AutomateHeaderPlugin) 13 | .settings(settings) 14 | .settings( 15 | scalaVersion := "2.11.12", // ScalNet and ND4S are only available for Scala 2.11 16 | libraryDependencies ++= Seq( 17 | library.dl4j, 18 | library.dl4jCuda, 19 | library.dl4jUi, 20 | library.logbackClassic, 21 | library.nd4jNativePlatform, 22 | library.scalNet 23 | ) 24 | ) 25 | 26 | lazy val tensorFlow = 27 | project 28 | .in(file("tensorflow")) 29 | .enablePlugins(AutomateHeaderPlugin) 30 | .settings(settings) 31 | .settings( 32 | PB.targets in Compile := Seq( 33 | scalapb.gen() -> (sourceManaged in Compile).value 34 | ), 35 | javaCppPresetLibs ++= Seq( 36 | "ffmpeg" -> "3.4.1" 37 | ), 38 | libraryDependencies ++= Seq( 39 | library.betterFiles, 40 | library.janino, 41 | library.logbackClassic, 42 | library.tensorFlow, 43 | library.tensorFlowData 44 | ), 45 | fork := true // prevent classloader issues caused by sbt and opencv 46 | ) 47 | 48 | // ***************************************************************************** 49 | // Library dependencies 50 | // ***************************************************************************** 51 | 52 | lazy val library = 53 | new { 54 | object Version { 55 | val betterFiles = "3.4.0" 56 | val dl4j = "1.0.0-alpha" 57 | val janino = "2.6.1" 58 | val logbackClassic = "1.2.3" 59 | val scalaCheck = "1.13.5" 60 | val scalaTest = "3.0.4" 61 | val tensorFlow = "0.2.4" 62 | 63 | } 64 | val betterFiles = "com.github.pathikrit" %% "better-files" % Version.betterFiles 65 | val dl4j = "org.deeplearning4j" % "deeplearning4j-core" % Version.dl4j 66 | val dl4jUi = "org.deeplearning4j" %% "deeplearning4j-ui" % Version.dl4j 67 | val janino = "org.codehaus.janino" % "janino" % Version.janino 68 | val logbackClassic = "ch.qos.logback" % "logback-classic" % Version.logbackClassic 69 | val nd4jNativePlatform = "org.nd4j" % "nd4j-cuda-9.0-platform" % Version.dl4j 70 | val dl4jCuda = "org.deeplearning4j" % "deeplearning4j-cuda-9.0" % Version.dl4j 71 | val scalaCheck = "org.scalacheck" %% "scalacheck" % Version.scalaCheck 72 | val scalaTest = "org.scalatest" %% "scalatest" % Version.scalaTest 73 | val scalNet = "org.deeplearning4j" %% "scalnet" % Version.dl4j 74 | // change the classifier to "linux-cpu-x86_64" or "linux-gpu-x86_64" if you're on a linux/linux with nvidia system 75 | val tensorFlow = "org.platanios" %% "tensorflow" % Version.tensorFlow classifier "darwin-cpu-x86_64" 76 | val tensorFlowData = "org.platanios" %% "tensorflow-data" % Version.tensorFlow 77 | } 78 | 79 | // ***************************************************************************** 80 | // Settings 81 | // ***************************************************************************** 82 | 83 | lazy val settings = 84 | Seq( 85 | scalaVersion := "2.12.6", 86 | organization := "io.brunk", 87 | organizationName := "Sören Brunk", 88 | startYear := Some(2017), 89 | licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0")), 90 | scalacOptions ++= Seq( 91 | "-unchecked", 92 | "-deprecation", 93 | "-language:_", 94 | "-target:jvm-1.8", 95 | "-encoding", "UTF-8" 96 | ), 97 | unmanagedSourceDirectories.in(Compile) := Seq(scalaSource.in(Compile).value), 98 | unmanagedSourceDirectories.in(Test) := Seq(scalaSource.in(Test).value), 99 | resolvers ++= Seq( 100 | Resolver.sonatypeRepo("snapshots") 101 | ) 102 | ) -------------------------------------------------------------------------------- /data/iris.csv: 
-------------------------------------------------------------------------------- 1 | 150,4,setosa,versicolor,virginica 2 | 5.1,3.5,1.4,0.2,0 3 | 4.9,3.0,1.4,0.2,0 4 | 4.7,3.2,1.3,0.2,0 5 | 4.6,3.1,1.5,0.2,0 6 | 5.0,3.6,1.4,0.2,0 7 | 5.4,3.9,1.7,0.4,0 8 | 4.6,3.4,1.4,0.3,0 9 | 5.0,3.4,1.5,0.2,0 10 | 4.4,2.9,1.4,0.2,0 11 | 4.9,3.1,1.5,0.1,0 12 | 5.4,3.7,1.5,0.2,0 13 | 4.8,3.4,1.6,0.2,0 14 | 4.8,3.0,1.4,0.1,0 15 | 4.3,3.0,1.1,0.1,0 16 | 5.8,4.0,1.2,0.2,0 17 | 5.7,4.4,1.5,0.4,0 18 | 5.4,3.9,1.3,0.4,0 19 | 5.1,3.5,1.4,0.3,0 20 | 5.7,3.8,1.7,0.3,0 21 | 5.1,3.8,1.5,0.3,0 22 | 5.4,3.4,1.7,0.2,0 23 | 5.1,3.7,1.5,0.4,0 24 | 4.6,3.6,1.0,0.2,0 25 | 5.1,3.3,1.7,0.5,0 26 | 4.8,3.4,1.9,0.2,0 27 | 5.0,3.0,1.6,0.2,0 28 | 5.0,3.4,1.6,0.4,0 29 | 5.2,3.5,1.5,0.2,0 30 | 5.2,3.4,1.4,0.2,0 31 | 4.7,3.2,1.6,0.2,0 32 | 4.8,3.1,1.6,0.2,0 33 | 5.4,3.4,1.5,0.4,0 34 | 5.2,4.1,1.5,0.1,0 35 | 5.5,4.2,1.4,0.2,0 36 | 4.9,3.1,1.5,0.1,0 37 | 5.0,3.2,1.2,0.2,0 38 | 5.5,3.5,1.3,0.2,0 39 | 4.9,3.1,1.5,0.1,0 40 | 4.4,3.0,1.3,0.2,0 41 | 5.1,3.4,1.5,0.2,0 42 | 5.0,3.5,1.3,0.3,0 43 | 4.5,2.3,1.3,0.3,0 44 | 4.4,3.2,1.3,0.2,0 45 | 5.0,3.5,1.6,0.6,0 46 | 5.1,3.8,1.9,0.4,0 47 | 4.8,3.0,1.4,0.3,0 48 | 5.1,3.8,1.6,0.2,0 49 | 4.6,3.2,1.4,0.2,0 50 | 5.3,3.7,1.5,0.2,0 51 | 5.0,3.3,1.4,0.2,0 52 | 7.0,3.2,4.7,1.4,1 53 | 6.4,3.2,4.5,1.5,1 54 | 6.9,3.1,4.9,1.5,1 55 | 5.5,2.3,4.0,1.3,1 56 | 6.5,2.8,4.6,1.5,1 57 | 5.7,2.8,4.5,1.3,1 58 | 6.3,3.3,4.7,1.6,1 59 | 4.9,2.4,3.3,1.0,1 60 | 6.6,2.9,4.6,1.3,1 61 | 5.2,2.7,3.9,1.4,1 62 | 5.0,2.0,3.5,1.0,1 63 | 5.9,3.0,4.2,1.5,1 64 | 6.0,2.2,4.0,1.0,1 65 | 6.1,2.9,4.7,1.4,1 66 | 5.6,2.9,3.6,1.3,1 67 | 6.7,3.1,4.4,1.4,1 68 | 5.6,3.0,4.5,1.5,1 69 | 5.8,2.7,4.1,1.0,1 70 | 6.2,2.2,4.5,1.5,1 71 | 5.6,2.5,3.9,1.1,1 72 | 5.9,3.2,4.8,1.8,1 73 | 6.1,2.8,4.0,1.3,1 74 | 6.3,2.5,4.9,1.5,1 75 | 6.1,2.8,4.7,1.2,1 76 | 6.4,2.9,4.3,1.3,1 77 | 6.6,3.0,4.4,1.4,1 78 | 6.8,2.8,4.8,1.4,1 79 | 6.7,3.0,5.0,1.7,1 80 | 6.0,2.9,4.5,1.5,1 81 | 5.7,2.6,3.5,1.0,1 82 | 5.5,2.4,3.8,1.1,1 83 | 5.5,2.4,3.7,1.0,1 84 | 5.8,2.7,3.9,1.2,1 85 | 6.0,2.7,5.1,1.6,1 86 | 5.4,3.0,4.5,1.5,1 87 | 6.0,3.4,4.5,1.6,1 88 | 6.7,3.1,4.7,1.5,1 89 | 6.3,2.3,4.4,1.3,1 90 | 5.6,3.0,4.1,1.3,1 91 | 5.5,2.5,4.0,1.3,1 92 | 5.5,2.6,4.4,1.2,1 93 | 6.1,3.0,4.6,1.4,1 94 | 5.8,2.6,4.0,1.2,1 95 | 5.0,2.3,3.3,1.0,1 96 | 5.6,2.7,4.2,1.3,1 97 | 5.7,3.0,4.2,1.2,1 98 | 5.7,2.9,4.2,1.3,1 99 | 6.2,2.9,4.3,1.3,1 100 | 5.1,2.5,3.0,1.1,1 101 | 5.7,2.8,4.1,1.3,1 102 | 6.3,3.3,6.0,2.5,2 103 | 5.8,2.7,5.1,1.9,2 104 | 7.1,3.0,5.9,2.1,2 105 | 6.3,2.9,5.6,1.8,2 106 | 6.5,3.0,5.8,2.2,2 107 | 7.6,3.0,6.6,2.1,2 108 | 4.9,2.5,4.5,1.7,2 109 | 7.3,2.9,6.3,1.8,2 110 | 6.7,2.5,5.8,1.8,2 111 | 7.2,3.6,6.1,2.5,2 112 | 6.5,3.2,5.1,2.0,2 113 | 6.4,2.7,5.3,1.9,2 114 | 6.8,3.0,5.5,2.1,2 115 | 5.7,2.5,5.0,2.0,2 116 | 5.8,2.8,5.1,2.4,2 117 | 6.4,3.2,5.3,2.3,2 118 | 6.5,3.0,5.5,1.8,2 119 | 7.7,3.8,6.7,2.2,2 120 | 7.7,2.6,6.9,2.3,2 121 | 6.0,2.2,5.0,1.5,2 122 | 6.9,3.2,5.7,2.3,2 123 | 5.6,2.8,4.9,2.0,2 124 | 7.7,2.8,6.7,2.0,2 125 | 6.3,2.7,4.9,1.8,2 126 | 6.7,3.3,5.7,2.1,2 127 | 7.2,3.2,6.0,1.8,2 128 | 6.2,2.8,4.8,1.8,2 129 | 6.1,3.0,4.9,1.8,2 130 | 6.4,2.8,5.6,2.1,2 131 | 7.2,3.0,5.8,1.6,2 132 | 7.4,2.8,6.1,1.9,2 133 | 7.9,3.8,6.4,2.0,2 134 | 6.4,2.8,5.6,2.2,2 135 | 6.3,2.8,5.1,1.5,2 136 | 6.1,2.6,5.6,1.4,2 137 | 7.7,3.0,6.1,2.3,2 138 | 6.3,3.4,5.6,2.4,2 139 | 6.4,3.1,5.5,1.8,2 140 | 6.0,3.0,4.8,1.8,2 141 | 6.9,3.1,5.4,2.1,2 142 | 6.7,3.1,5.6,2.4,2 143 | 6.9,3.1,5.1,2.3,2 144 | 5.8,2.7,5.1,1.9,2 145 | 6.8,3.2,5.9,2.3,2 146 | 6.7,3.3,5.7,2.5,2 147 | 6.7,3.0,5.2,2.3,2 148 | 6.3,2.5,5.0,1.9,2 149 | 
6.5,3.0,5.2,2.0,2 150 | 6.2,3.4,5.4,2.3,2 151 | 5.9,3.0,5.1,1.8,2 152 | -------------------------------------------------------------------------------- /dl4j/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /dl4j/src/main/scala/io/brunk/examples/ImageReader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import java.io.{File, FileFilter} 20 | import java.lang.Math.toIntExact 21 | 22 | import org.datavec.api.io.filters.BalancedPathFilter 23 | import org.datavec.api.io.labels.ParentPathLabelGenerator 24 | import org.datavec.api.split.{FileSplit, InputSplit} 25 | import org.datavec.image.loader.BaseImageLoader 26 | import org.datavec.image.recordreader.ImageRecordReader 27 | import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator 28 | import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator 29 | import org.deeplearning4j.eval.Evaluation 30 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator 31 | import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler 32 | 33 | import scala.collection.JavaConverters._ 34 | 35 | 36 | object ImageReader { 37 | 38 | val channels = 3 39 | val height = 150 40 | val width = 150 41 | 42 | val batchSize = 50 43 | val numClasses = 2 44 | val epochs = 100 45 | val splitTrainTest = 0.8 46 | 47 | val random = new java.util.Random() 48 | 49 | def createImageIterator(path: String): (MultipleEpochsIterator, DataSetIterator) = { 50 | val baseDir = new File(path) 51 | val labelGenerator = new ParentPathLabelGenerator 52 | val fileSplit = new FileSplit(baseDir, BaseImageLoader.ALLOWED_FORMATS, random) 53 | 54 | val numExamples = toIntExact(fileSplit.length) 55 | val numLabels = fileSplit.getRootDir.listFiles(new FileFilter { 56 | override def accept(pathname: File): Boolean = pathname.isDirectory 57 | }).length 58 | 59 | val pathFilter = new BalancedPathFilter(random, labelGenerator, numExamples, numLabels, batchSize) 60 | 61 | //val inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest) 62 | val inputSplit = fileSplit.sample(pathFilter, 70, 30) 63 | 64 | val trainData = inputSplit(0) 65 | val validationData = inputSplit(1) 66 | 67 | val recordReader = new ImageRecordReader(height, width, channels, labelGenerator) 68 | val scaler = new ImagePreProcessingScaler(0, 1) 69 | 70 | recordReader.initialize(trainData, null) 71 | val dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numClasses) 72 | scaler.fit(dataIter) 73 | dataIter.setPreProcessor(scaler) 74 | val trainIter = new MultipleEpochsIterator(epochs, dataIter) 
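    // Set up the validation pipeline in the same way: a separate record reader over the held-out split,
    // wrapped in a RecordReaderDataSetIterator and scaled to [0, 1] pixel values (no multi-epoch wrapper here).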
75 | 76 | val valRecordReader = new ImageRecordReader(height, width, channels, labelGenerator) 77 | valRecordReader.initialize(validationData, null) 78 | val validationIter = new RecordReaderDataSetIterator(valRecordReader, batchSize, 1, numClasses) 79 | scaler.fit(validationIter) 80 | validationIter.setPreProcessor(scaler) 81 | 82 | (trainIter, validationIter) 83 | } 84 | 85 | } 86 | -------------------------------------------------------------------------------- /dl4j/src/main/scala/io/brunk/examples/IrisReader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import java.io.File 20 | 21 | import org.datavec.api.records.reader.impl.csv.CSVRecordReader 22 | import org.datavec.api.split.FileSplit 23 | import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator 24 | import org.nd4j.linalg.dataset.SplitTestAndTrain 25 | import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize 26 | 27 | object IrisReader { 28 | val numLinesToSkip = 1 29 | 30 | val batchSize = 150 31 | val labelIndex = 4 32 | val numLabels = 3 33 | 34 | val seed = 1 35 | 36 | def readData(): SplitTestAndTrain = { 37 | val recordReader = new CSVRecordReader(numLinesToSkip, ',') 38 | recordReader.initialize(new FileSplit(new File("data/iris.csv"))) 39 | val iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numLabels) 40 | val dataSet = iterator.next() // read all data in a single batch 41 | dataSet.shuffle(seed) 42 | val testAndTrain = dataSet.splitTestAndTrain(0.67) 43 | val train = testAndTrain.getTrain 44 | val test = testAndTrain.getTest 45 | 46 | // val normalizer = new NormalizerStandardize 47 | // normalizer.fit(train) 48 | // normalizer.transform(train) // normalize training data 49 | // normalizer.transform(test) // normalize test data 50 | testAndTrain 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /dl4j/src/main/scala/io/brunk/examples/dl4j/IrisMLP.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.brunk.examples.dl4j 18 | 19 | import io.brunk.examples.IrisReader 20 | import org.deeplearning4j.eval.Evaluation 21 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration 22 | import org.deeplearning4j.nn.conf.layers.{ DenseLayer, OutputLayer } 23 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork 24 | import org.deeplearning4j.nn.weights.WeightInit 25 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener 26 | import org.nd4j.linalg.activations.Activation 27 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction 28 | import org.slf4j.{ Logger, LoggerFactory } 29 | 30 | /** 31 | * A simple feed forward network for classifying the IRIS dataset in dl4j with a single hidden layer 32 | * 33 | * Based on 34 | * https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/CSVExample.java 35 | * 36 | * @author Sören Brunk 37 | */ 38 | object IrisMLP { 39 | private val log: Logger = LoggerFactory.getLogger(IrisMLP.getClass) 40 | 41 | def main(args: Array[String]): Unit = { 42 | 43 | val seed = 1 // for reproducibility 44 | val numInputs = 4 45 | val numHidden = 10 46 | val numOutputs = 3 47 | val learningRate = 0.1 48 | val numEpoch = 30 49 | 50 | val testAndTrain = IrisReader.readData() 51 | 52 | val conf = new NeuralNetConfiguration.Builder() 53 | .seed(seed) 54 | .activation(Activation.RELU) 55 | .weightInit(WeightInit.XAVIER) 56 | .list() 57 | .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHidden).build()) 58 | .layer(1, 59 | new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) 60 | .activation(Activation.SOFTMAX) 61 | .nIn(numHidden) 62 | .nOut(numOutputs) 63 | .build()) 64 | .backprop(true) 65 | .pretrain(false) 66 | .build() 67 | 68 | val model = new MultiLayerNetwork(conf) 69 | model.init() 70 | model.setListeners(new ScoreIterationListener(100)) // print out scores every 100 iterations 71 | 72 | log.info("Running training") 73 | for(_ <- 0 until numEpoch) 74 | model.fit(testAndTrain.getTrain) 75 | 76 | log.info("Training finished") 77 | 78 | log.info(s"Evaluating model on ${testAndTrain.getTest.getLabels.rows()} examples") 79 | val evaluator = new Evaluation(numOutputs) 80 | val output = model.output(testAndTrain.getTest.getFeatureMatrix) 81 | evaluator.eval(testAndTrain.getTest.getLabels, output) 82 | println(evaluator.stats) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /dl4j/src/main/scala/io/brunk/examples/dl4j/MnistMLP.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.brunk.examples.dl4j 18 | 19 | import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator 20 | import org.deeplearning4j.eval.Evaluation 21 | import org.deeplearning4j.nn.api.OptimizationAlgorithm 22 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration 23 | import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer} 24 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork 25 | import org.deeplearning4j.nn.weights.WeightInit 26 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener 27 | import org.nd4j.linalg.activations.Activation 28 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator 29 | import org.nd4j.linalg.learning.config.Sgd 30 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction 31 | import org.slf4j.LoggerFactory 32 | 33 | import scala.collection.JavaConverters.asScalaIteratorConverter 34 | 35 | /** Simple multilayer perceptron for classifying handwritten digits from the MNIST dataset. 36 | * 37 | * Implemented using DL4J based on the Java example from 38 | * https://github.com/deeplearning4j/dl4j-examples/blob/dfcf71d75fff956db53a93b09b560d53e3da4638/dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/mnist/MLPMnistSingleLayerExample.java 39 | * 40 | * @author Sören Brunk 41 | */ 42 | object MnistMLP { 43 | private val log = LoggerFactory.getLogger(MnistMLP.getClass) 44 | 45 | def main(args: Array[String]): Unit = { 46 | 47 | val seed = 1 // for reproducibility 48 | val numInputs = 28 * 28 49 | val numHidden = 512 // size (number of neurons) of our hidden layer 50 | val numOutputs = 10 // digits from 0 to 9 51 | val learningRate = 0.01 52 | val batchSize = 128 53 | val numEpochs = 10 54 | 55 | // download and load the MNIST images as tensors 56 | val mnistTrain = new MnistDataSetIterator(batchSize, true, seed) 57 | val mnistTest = new MnistDataSetIterator(batchSize, false, seed) 58 | 59 | // define the neural network architecture 60 | val conf = new NeuralNetConfiguration.Builder() 61 | .seed(seed) 62 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 63 | .updater(new Sgd(learningRate)) 64 | .weightInit(WeightInit.XAVIER) // random initialization of our weights 65 | .list // builder for creating stacked layers 66 | .layer(0, new DenseLayer.Builder() // define the hidden layer 67 | .nIn(numInputs) 68 | .nOut(numHidden) 69 | .activation(Activation.RELU) 70 | .build()) 71 | .layer(1, new OutputLayer.Builder(LossFunction.MCXENT) // define loss and output layer 72 | .nIn(numHidden) 73 | .nOut(numOutputs) 74 | .activation(Activation.SOFTMAX) 75 | .build()) 76 | .build() 77 | 78 | val model = new MultiLayerNetwork(conf) 79 | model.init() 80 | model.setListeners(new ScoreIterationListener(100)) // print the score every 100th iteration 81 | 82 | // train the model 83 | for (_ <- 0 until numEpochs) 84 | model.fit(mnistTrain) 85 | 86 | // evaluate model performance 87 | def accuracy(dataSet: DataSetIterator): Double = { 88 | val evaluator = new Evaluation(numOutputs) 89 | dataSet.reset() 90 | for (dataSet <- dataSet.asScala) { 91 | val output = model.output(dataSet.getFeatureMatrix) 92 | evaluator.eval(dataSet.getLabels, output) 93 | } 94 | evaluator.accuracy() 95 | } 96 | 97 | log.info(s"Train accuracy = ${accuracy(mnistTrain)}") 98 | log.info(s"Test accuracy = ${accuracy(mnistTest)}") 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /dl4j/src/main/scala/io/brunk/examples/dl4j/SimpleCNN.scala: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.brunk.examples.dl4j 18 | 19 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration 20 | import org.deeplearning4j.nn.conf.layers.{ConvolutionLayer, DenseLayer, OutputLayer, SubsamplingLayer} 21 | import org.nd4j.linalg.learning.config.Adam 22 | import io.brunk.examples.ImageReader._ 23 | import org.deeplearning4j.nn.conf.dropout.Dropout 24 | import org.deeplearning4j.nn.conf.inputs.InputType 25 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork 26 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener 27 | import org.deeplearning4j.ui.api.UIServer 28 | import org.deeplearning4j.ui.stats.StatsListener 29 | import org.deeplearning4j.ui.storage.InMemoryStatsStorage 30 | import org.nd4j.linalg.activations.Activation.{RELU, SOFTMAX} 31 | import org.nd4j.linalg.lossfunctions.LossFunctions 32 | import org.slf4j.LoggerFactory 33 | 34 | 35 | object SimpleCNN { 36 | 37 | private val log = LoggerFactory.getLogger(getClass) 38 | val seed = 1 39 | 40 | def main(args: Array[String]): Unit = { 41 | 42 | val dataDir = args.head 43 | 44 | val conf = new NeuralNetConfiguration.Builder() 45 | .seed(seed) 46 | .updater(new Adam) 47 | .list() 48 | .layer(0, new ConvolutionLayer.Builder(3, 3) 49 | .nIn(channels) 50 | .nOut(32) 51 | .activation(RELU) 52 | .build()) 53 | .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) 54 | .kernelSize(2, 2) 55 | .build()) 56 | .layer(2, new ConvolutionLayer.Builder(3, 3) 57 | .nOut(64) 58 | .activation(RELU) 59 | .build()) 60 | .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) 61 | .kernelSize(2, 2) 62 | .build()) 63 | .layer(4, new ConvolutionLayer.Builder(3, 3) 64 | .nOut(128) 65 | .activation(RELU) 66 | .build()) 67 | .layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) 68 | .kernelSize(2, 2) 69 | .build()) 70 | .layer(6, new ConvolutionLayer.Builder(3, 3) 71 | .nOut(128) 72 | .activation(RELU) 73 | .build()) 74 | .layer(7, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) 75 | .kernelSize(2, 2) 76 | .build()) 77 | .layer(8, new DenseLayer.Builder() 78 | .nOut(512) 79 | .activation(RELU) 80 | .dropOut(new Dropout(0.5)) 81 | .build()) 82 | .layer(9, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) 83 | .nOut(2) 84 | .activation(SOFTMAX) 85 | .build()) 86 | .setInputType(InputType.convolutional(150, 150, 3)) 87 | .backprop(true).pretrain(false).build() 88 | 89 | val model = new MultiLayerNetwork(conf) 90 | model.init() 91 | model.setListeners(new ScoreIterationListener(10)) 92 | log.debug("Total num of params: {}", model.numParams) 93 | 94 | val uiServer = UIServer.getInstance 95 | val statsStorage = new InMemoryStatsStorage 96 | uiServer.attach(statsStorage) 97 | model.setListeners(new StatsListener(statsStorage)) 
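    // The StatsListener streams training metrics into the in-memory storage attached to the DL4J UI server,
    // which serves a live training dashboard (by default on http://localhost:9000) while fit() runs below.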
98 | 99 | val (trainIter, testIter) = createImageIterator(dataDir) 100 | 101 | model.fit(trainIter) 102 | val eval = model.evaluate(testIter) 103 | log.info(eval.stats) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /dl4j/src/main/scala/io/brunk/examples/scalnet/IrisMLP.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.brunk.examples.scalnet 18 | 19 | import io.brunk.examples.IrisReader 20 | import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator 21 | import org.deeplearning4j.eval.Evaluation 22 | import org.deeplearning4j.nn.conf.Updater 23 | import org.deeplearning4j.nn.weights.WeightInit 24 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener 25 | import org.deeplearning4j.scalnet.layers.core.Dense 26 | import org.deeplearning4j.scalnet.models.Sequential 27 | import org.deeplearning4j.scalnet.regularizers.L2 28 | import org.nd4j.linalg.activations.Activation 29 | import org.nd4j.linalg.api.ndarray.INDArray 30 | import org.nd4j.linalg.learning.config.Sgd 31 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction 32 | import org.slf4j.{Logger, LoggerFactory} 33 | 34 | /** 35 | * A simple feed forward network (one hidden layer) for classifying the IRIS dataset 36 | * implemented using ScalNet. 
37 | * 38 | * @author Sören Brunk 39 | */ 40 | object IrisMLP { 41 | 42 | private val log: Logger = LoggerFactory.getLogger(IrisMLP.getClass) 43 | 44 | def main(args: Array[String]): Unit = { 45 | 46 | val seed = 1 47 | val numInputs = 4 48 | val numHidden = 10 49 | val numOutputs = 3 50 | val learningRate = 0.1 51 | val iterations = 1000 52 | 53 | val testAndTrain = IrisReader.readData() 54 | val trainList = testAndTrain.getTrain.asList() 55 | val trainIterator = new ListDataSetIterator(trainList, trainList.size) 56 | 57 | val model = Sequential(rngSeed = seed) 58 | model.add(Dense(numHidden, nIn = numInputs, weightInit = WeightInit.XAVIER, activation = Activation.RELU)) 59 | model.add(Dense(numOutputs, weightInit = WeightInit.XAVIER, activation = Activation.SOFTMAX)) 60 | 61 | model.compile(lossFunction = LossFunction.NEGATIVELOGLIKELIHOOD, updater = Updater.SGD) 62 | 63 | log.info("Running training") 64 | model.fit(iter = trainIterator, 65 | nbEpoch = iterations, 66 | listeners = List(new ScoreIterationListener(100))) 67 | log.info("Training finished") 68 | 69 | log.info(s"Evaluating model on ${testAndTrain.getTest.getLabels.rows()} examples") 70 | val evaluator = new Evaluation(numOutputs) 71 | val output: INDArray = model.predict(testAndTrain.getTest.getFeatureMatrix) 72 | evaluator.eval(testAndTrain.getTest.getLabels, output) 73 | log.info(evaluator.stats()) 74 | 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /dl4j/src/main/scala/io/brunk/examples/scalnet/MnistMLP.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.brunk.examples.scalnet 18 | 19 | import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator 20 | import org.deeplearning4j.eval.Evaluation 21 | import org.deeplearning4j.nn.conf.Updater 22 | import org.deeplearning4j.nn.weights.WeightInit 23 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener 24 | import org.deeplearning4j.scalnet.layers.core.Dense 25 | import org.deeplearning4j.scalnet.models.Sequential 26 | import org.nd4j.linalg.activations.Activation 27 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator 28 | import org.nd4j.linalg.learning.config.Sgd 29 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction 30 | import org.slf4j.{Logger, LoggerFactory} 31 | 32 | import scala.collection.JavaConverters.asScalaIteratorConverter 33 | 34 | 35 | /** Simple multilayer perceptron for classifying handwritten digits from the MNIST dataset. 36 | * 37 | * Implemented using ScalNet. 
38 | * 39 | * @author Sören Brunk 40 | */ 41 | object MnistMLP { 42 | private val log: Logger = LoggerFactory.getLogger(MnistMLP.getClass) 43 | 44 | def main(args: Array[String]): Unit = { 45 | 46 | val seed = 1 // for reproducibility 47 | val numInputs = 28 * 28 48 | val numHidden = 512 // size (number of neurons) in our hidden layer 49 | val numOutputs = 10 // digits from 0 to 9 50 | val learningRate = 0.01 51 | val batchSize = 128 52 | val numEpochs = 10 53 | 54 | // download and load the MNIST images as tensors 55 | val mnistTrain: DataSetIterator = new MnistDataSetIterator(batchSize, true, seed) 56 | val mnistTest: DataSetIterator = new MnistDataSetIterator(batchSize, false, seed) 57 | 58 | // define the neural network architecture 59 | val model: Sequential = Sequential(rngSeed = seed) 60 | model.add(Dense(nOut = numHidden, nIn = numInputs, weightInit = WeightInit.XAVIER, activation = Activation.RELU)) 61 | model.add(Dense(nOut = numOutputs, weightInit = WeightInit.XAVIER, activation = Activation.RELU)) 62 | model.compile(lossFunction = LossFunction.MCXENT, updater = Updater.SGD) // TODO how do we set the learning rate? 63 | 64 | // train the model 65 | model.fit(mnistTrain, nbEpoch = numEpochs, List(new ScoreIterationListener(100))) 66 | 67 | // evaluate model performance 68 | def accuracy(dataSet: DataSetIterator): Double = { 69 | val evaluator = new Evaluation(numOutputs) 70 | dataSet.reset() 71 | for (dataSet <- dataSet.asScala) { 72 | val output = model.predict(dataSet) 73 | evaluator.eval(dataSet.getLabels, output) 74 | } 75 | evaluator.accuracy() 76 | } 77 | 78 | log.info(s"Train accuracy = ${accuracy(mnistTrain)}") 79 | log.info(s"Test accuracy = ${accuracy(mnistTest)}") 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /dl4j/src/main/scala/io/brunk/examples/scalnet/SimpleCNN.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.brunk.examples.scalnet 18 | 19 | import io.brunk.examples.ImageReader 20 | import org.deeplearning4j.nn.conf.inputs.InputType 21 | import org.deeplearning4j.scalnet.models.NeuralNet 22 | import io.brunk.examples.ImageReader._ 23 | import org.deeplearning4j.nn.conf.Updater 24 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener 25 | import org.deeplearning4j.scalnet.layers.convolutional.Convolution2D 26 | import org.deeplearning4j.scalnet.layers.core.Dense 27 | import org.deeplearning4j.scalnet.layers.pooling.MaxPooling2D 28 | import org.nd4j.linalg.activations.Activation._ 29 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction 30 | 31 | object SimpleCNN { 32 | 33 | 34 | def main(args: Array[String]): Unit = { 35 | 36 | val dataDir = args.head 37 | 38 | val seed = 1 39 | 40 | val model = NeuralNet(inputType = InputType.convolutional(height, width, channels), rngSeed = seed) 41 | 42 | model.add(Convolution2D(32, List(3, 3), channels, activation = RELU)) 43 | model.add(MaxPooling2D(List(2, 2))) 44 | 45 | model.add(Convolution2D(64, List(3, 3), activation = RELU)) 46 | model.add(MaxPooling2D(List(2, 2))) 47 | 48 | model.add(Convolution2D(128, List(3, 3), activation = RELU)) 49 | model.add(MaxPooling2D(List(2, 2))) 50 | 51 | model.add(Convolution2D(128, List(3, 3), activation = RELU)) 52 | model.add(MaxPooling2D(List(2, 2))) 53 | 54 | model.add(Dense(512, activation = RELU, dropOut = 0.5)) 55 | model.add(Dense(2, activation = SOFTMAX)) 56 | 57 | model.compile(lossFunction = LossFunction.NEGATIVELOGLIKELIHOOD, updater = Updater.ADAM) 58 | 59 | val (trainIter, testIter) = createImageIterator(dataDir) 60 | 61 | model.fit(trainIter, 30, List(new ScoreIterationListener(10))) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /mxnet/build.sbt: -------------------------------------------------------------------------------- 1 | // ***************************************************************************** 2 | // Projects 3 | // ***************************************************************************** 4 | 5 | lazy val mxnet = 6 | project 7 | .in(file(".")) 8 | .enablePlugins(AutomateHeaderPlugin) 9 | .settings(settings) 10 | .settings( 11 | scalaVersion := "2.11.12", // MXNet is only available for Scala 2.11 12 | resolvers += Resolver.mavenLocal, 13 | libraryDependencies ++= Seq( 14 | library.logbackClassic, 15 | library.mxnetFull 16 | ) 17 | ) 18 | 19 | // ***************************************************************************** 20 | // Library dependencies 21 | // ***************************************************************************** 22 | 23 | lazy val library = 24 | new { 25 | object Version { 26 | val logbackClassic = "1.2.3" 27 | val mxnet = "1.0.0-SNAPSHOT" 28 | } 29 | val logbackClassic = "ch.qos.logback" % "logback-classic" % Version.logbackClassic 30 | // change to "mxnet-full_2.10-linux-x86_64-cpu" or "mxnet-full_2.10-linux-x86_64-gpu" depending on your os/gpu 31 | val mxnetFull = "ml.dmlc.mxnet" % "mxnet-full_2.11-osx-x86_64-cpu" % Version.mxnet 32 | } 33 | 34 | // ***************************************************************************** 35 | // Settings 36 | // ***************************************************************************** 37 | 38 | lazy val settings = 39 | Seq( 40 | scalaVersion := "2.12.4", 41 | organization := "io.brunk", 42 | organizationName := "Sören Brunk", 43 | startYear := Some(2017), 44 | licenses += ("Apache-2.0", 
url("http://www.apache.org/licenses/LICENSE-2.0")), 45 | scalacOptions ++= Seq( 46 | "-unchecked", 47 | "-deprecation", 48 | "-language:_", 49 | "-target:jvm-1.8", 50 | "-encoding", "UTF-8" 51 | ), 52 | unmanagedSourceDirectories.in(Compile) := Seq(scalaSource.in(Compile).value), 53 | unmanagedSourceDirectories.in(Test) := Seq(scalaSource.in(Test).value) 54 | ) 55 | -------------------------------------------------------------------------------- /mxnet/project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.0.3 2 | -------------------------------------------------------------------------------- /mxnet/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("de.heikoseeberger" % "sbt-header" % "4.0.0") 2 | addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.0-RC13") -------------------------------------------------------------------------------- /mxnet/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /mxnet/src/main/scala/io/brunk/examples/IrisMLP.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import ml.dmlc.mxnet._ 20 | import ml.dmlc.mxnet.io.NDArrayIter 21 | import ml.dmlc.mxnet.optimizer.SGD 22 | 23 | object IrisMLP { 24 | 25 | def main(args: Array[String]): Unit = { 26 | 27 | val numInputs = 4 28 | val numHidden = 10 29 | val numOutputs = 3 30 | val learningRate = 0.1f 31 | val iterations = 1000 32 | val trainSize = 100 33 | val testSize = 50 34 | 35 | val batchSize = 50 36 | val epochs = (iterations / (batchSize.toFloat / trainSize)).toInt 37 | 38 | // The mxnet Scala IO API does not support shuffling so we just read the csv using plain Scala 39 | val source = scala.io.Source.fromFile("data/iris.csv") 40 | val rows = source.getLines().drop(1).map { l => 41 | val columns = l.split(",").map(_.toFloat) 42 | new { 43 | val features = columns.take(4) 44 | val labels = columns(4) 45 | } 46 | }.toBuffer 47 | val shuffled = scala.util.Random.shuffle(rows).toArray 48 | val trainData = shuffled.take(trainSize) 49 | val testData = shuffled.drop(trainSize) 50 | val trainFeatures = NDArray.array(trainData.flatMap(_.features), Shape(trainSize, numInputs)) 51 | val trainLabels = NDArray.array(trainData.map(_.labels), Shape(trainSize)) 52 | val testFeatures = NDArray.array(testData.flatMap(_.features), Shape(testSize, numInputs)) 53 | val testLabels = NDArray.array(testData.map(_.labels), Shape(testSize)) 54 | 55 | 56 | val trainDataIter = new NDArrayIter(data = IndexedSeq(trainFeatures), label = IndexedSeq(trainLabels), dataBatchSize = 50) 57 | val testDataIter = new NDArrayIter(data = IndexedSeq(testFeatures), label = IndexedSeq(testLabels), dataBatchSize = 50) 58 | 59 | // Define the network architecture 60 | val data = Symbol.Variable("data") 61 | val label = Symbol.Variable("label") 62 | val l1 = Symbol.FullyConnected(name = "l1")()(Map("data" -> data, "num_hidden" -> numHidden)) 63 | val a1 = Symbol.Activation(name = "a1")()(Map("data" -> l1, "act_type" -> "relu")) 64 | val l2 = Symbol.FullyConnected(name = "l2")()(Map("data" -> a1, "num_hidden" -> numOutputs)) 65 | val out = Symbol.SoftmaxOutput(name = "sm")()(Map("data" -> l2, "label" -> label)) 66 | 67 | // Create and train a model 68 | val model = FeedForward.newBuilder(out) 69 | .setContext(Context.cpu()) // change to gpu if available 70 | .setNumEpoch(epochs) 71 | .setOptimizer(new SGD(learningRate = learningRate)) 72 | .setTrainData(trainDataIter) 73 | .setEvalData(testDataIter) 74 | .build() 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /mxnet/src/main/scala/io/brunk/examples/MnistMLP.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import ml.dmlc.mxnet._ 20 | import ml.dmlc.mxnet.optimizer.SGD 21 | 22 | /** Simple multilayer perceptron for classifying handwritten digits from the MNIST dataset. 
23 | * 24 | * Implemented using MXNet. 25 | * Based on https://mxnet.incubator.apache.org/tutorials/scala/mnist.html 26 | * 27 | * @author Sören Brunk 28 | */ 29 | object MnistMLP { 30 | 31 | def main(args: Array[String]): Unit = { 32 | 33 | val numHidden = 512 // size (number of neurons) of our hidden layer 34 | val numOutputs = 10 // digits from 0 to 9 35 | val learningRate = 0.01f 36 | val batchSize = 128 37 | val numEpochs = 10 38 | 39 | // load the MNIST images as tensors 40 | val trainDataIter = IO.MNISTIter(Map( 41 | "image" -> "mnist/train-images-idx3-ubyte", 42 | "label" -> "mnist/train-labels-idx1-ubyte", 43 | "data_shape" -> "(1, 28, 28)", 44 | "label_name" -> "sm_label", 45 | "batch_size" -> batchSize.toString, 46 | "shuffle" -> "1", 47 | "flat" -> "0", 48 | "silent" -> "0")) 49 | 50 | val testDataIter = IO.MNISTIter(Map( 51 | "image" -> "mnist/t10k-images-idx3-ubyte", 52 | "label" -> "mnist/t10k-labels-idx1-ubyte", 53 | "data_shape" -> "(1, 28, 28)", 54 | "label_name" -> "sm_label", 55 | "batch_size" -> batchSize.toString, 56 | "shuffle" -> "1", 57 | "flat" -> "0", 58 | "silent" -> "0")) 59 | 60 | // define the neural network architecture 61 | val data = Symbol.Variable("data") 62 | val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> numHidden)) 63 | val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu")) 64 | val fc2 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act1, "num_hidden" -> numOutputs)) 65 | val mlp = Symbol.SoftmaxOutput(name = "sm")()(Map("data" -> fc2)) 66 | 67 | // create and train the model 68 | val model = FeedForward.newBuilder(mlp) 69 | .setContext(Context.cpu()) // change to gpu if available 70 | .setTrainData(trainDataIter) 71 | .setEvalData(testDataIter) 72 | .setNumEpoch(numEpochs) 73 | .setOptimizer(new SGD(learningRate = learningRate)) 74 | .setInitializer(new Xavier()) // random weight initialization 75 | .build() 76 | 77 | // evaluate model performance 78 | def accuracy(dataset: DataIter): Float = { 79 | dataset.reset() 80 | val predictions = model.predict(dataset).head 81 | // get predicted labels 82 | val predictedY = NDArray.argmax_channel(predictions) 83 | 84 | // get real labels 85 | dataset.reset() 86 | val labels = dataset.map(_.label(0).copy()).toVector 87 | val y = NDArray.concatenate(labels) 88 | require(y.shape == predictedY.shape) 89 | 90 | // calculate accuracy 91 | val numCorrect = (y.toArray zip predictedY.toArray).count { 92 | case (labelElem, predElem) => labelElem == predElem 93 | } 94 | numCorrect.toFloat / y.size 95 | } 96 | 97 | println(s"Train accuracy = ${accuracy(trainDataIter)}") 98 | println(s"Test accuracy = ${accuracy(testDataIter)}") 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.2.1 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("de.heikoseeberger" % "sbt-header" % "4.0.0") 2 | addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.1") 3 | addSbtPlugin("org.bytedeco" % "sbt-javacv" % "1.16") 4 | addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.12") 5 | 6 | libraryDependencies += "com.trueaccord.scalapb" %% "compilerplugin" % "0.6.6" -------------------------------------------------------------------------------- 
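Note: the mxnet IrisMLP above only defines and trains the model; unlike the mxnet MnistMLP it never evaluates it. Below is a minimal sketch of how the trained model could be scored on the held-out data, mirroring the accuracy helper from the mxnet MnistMLP example. It assumes the model, testDataIter and testLabels values defined in IrisMLP are in scope (NDArray comes from the ml.dmlc.mxnet._ import already present in that file).

    // predict class probabilities for the test set and compare the most likely class with the true label
    testDataIter.reset()
    val predictions = model.predict(testDataIter).head      // softmax output, shape (testSize, numOutputs)
    val predictedY  = NDArray.argmax_channel(predictions)   // index of the highest probability = predicted class
    val numCorrect  = (testLabels.toArray zip predictedY.toArray).count { case (label, pred) => label == pred }
    println(s"Test accuracy = ${numCorrect.toFloat / testLabels.size}")
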
/tensorflow/example_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbrunk/scala-deeplearn-examples/23edfff79c6a590ba5a5fd896080fb8ac116579a/tensorflow/example_image.jpg -------------------------------------------------------------------------------- /tensorflow/src/main/protobuf/string_int_label_map.proto: -------------------------------------------------------------------------------- 1 | // Message to store the mapping from class label strings to class id. Datasets 2 | // use string labels to represent classes while the object detection framework 3 | // works with class ids. This message maps them so they can be converted back 4 | // and forth as needed. 5 | syntax = "proto2"; 6 | 7 | package object_detection.protos; 8 | 9 | message StringIntLabelMapItem { 10 | // String name. The most common practice is to set this to a MID or synsets 11 | // id. 12 | optional string name = 1; 13 | 14 | // Integer id that maps to the string name above. Label ids should start from 15 | // 1. 16 | optional int32 id = 2; 17 | 18 | // Human readable string label. 19 | optional string display_name = 3; 20 | }; 21 | 22 | message StringIntLabelMap { 23 | repeated StringIntLabelMapItem item = 1; 24 | }; 25 | -------------------------------------------------------------------------------- /tensorflow/src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | return message.contains("TF GPU device with id 0 was not registered"); 13 | 14 | NEUTRAL 15 | DENY 16 | 17 | 18 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n 19 | 20 | 21 | 22 | 23 | 25 | log-${bySecond}.txt 26 | 27 | %d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /tensorflow/src/main/resources/mscoco_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "/m/01g317" 3 | id: 1 4 | display_name: "person" 5 | } 6 | item { 7 | name: "/m/0199g" 8 | id: 2 9 | display_name: "bicycle" 10 | } 11 | item { 12 | name: "/m/0k4j" 13 | id: 3 14 | display_name: "car" 15 | } 16 | item { 17 | name: "/m/04_sv" 18 | id: 4 19 | display_name: "motorcycle" 20 | } 21 | item { 22 | name: "/m/05czz6l" 23 | id: 5 24 | display_name: "airplane" 25 | } 26 | item { 27 | name: "/m/01bjv" 28 | id: 6 29 | display_name: "bus" 30 | } 31 | item { 32 | name: "/m/07jdr" 33 | id: 7 34 | display_name: "train" 35 | } 36 | item { 37 | name: "/m/07r04" 38 | id: 8 39 | display_name: "truck" 40 | } 41 | item { 42 | name: "/m/019jd" 43 | id: 9 44 | display_name: "boat" 45 | } 46 | item { 47 | name: "/m/015qff" 48 | id: 10 49 | display_name: "traffic light" 50 | } 51 | item { 52 | name: "/m/01pns0" 53 | id: 11 54 | display_name: "fire hydrant" 55 | } 56 | item { 57 | name: "/m/02pv19" 58 | id: 13 59 | display_name: "stop sign" 60 | } 61 | item { 62 | name: "/m/015qbp" 63 | id: 14 64 | display_name: "parking meter" 65 | } 66 | item { 67 | name: "/m/0cvnqh" 68 | id: 15 69 | display_name: "bench" 70 | } 71 | item { 72 | name: "/m/015p6" 73 | id: 16 74 | display_name: "bird" 75 | } 76 | item { 77 | name: "/m/01yrx" 78 | id: 17 79 | display_name: "cat" 80 | } 81 | item { 82 | name: "/m/0bt9lr" 83 | id: 18 84 | display_name: "dog" 85 | } 86 | item { 87 | name: "/m/03k3r" 88 | id: 19 89 | display_name: 
"horse" 90 | } 91 | item { 92 | name: "/m/07bgp" 93 | id: 20 94 | display_name: "sheep" 95 | } 96 | item { 97 | name: "/m/01xq0k1" 98 | id: 21 99 | display_name: "cow" 100 | } 101 | item { 102 | name: "/m/0bwd_0j" 103 | id: 22 104 | display_name: "elephant" 105 | } 106 | item { 107 | name: "/m/01dws" 108 | id: 23 109 | display_name: "bear" 110 | } 111 | item { 112 | name: "/m/0898b" 113 | id: 24 114 | display_name: "zebra" 115 | } 116 | item { 117 | name: "/m/03bk1" 118 | id: 25 119 | display_name: "giraffe" 120 | } 121 | item { 122 | name: "/m/01940j" 123 | id: 27 124 | display_name: "backpack" 125 | } 126 | item { 127 | name: "/m/0hnnb" 128 | id: 28 129 | display_name: "umbrella" 130 | } 131 | item { 132 | name: "/m/080hkjn" 133 | id: 31 134 | display_name: "handbag" 135 | } 136 | item { 137 | name: "/m/01rkbr" 138 | id: 32 139 | display_name: "tie" 140 | } 141 | item { 142 | name: "/m/01s55n" 143 | id: 33 144 | display_name: "suitcase" 145 | } 146 | item { 147 | name: "/m/02wmf" 148 | id: 34 149 | display_name: "frisbee" 150 | } 151 | item { 152 | name: "/m/071p9" 153 | id: 35 154 | display_name: "skis" 155 | } 156 | item { 157 | name: "/m/06__v" 158 | id: 36 159 | display_name: "snowboard" 160 | } 161 | item { 162 | name: "/m/018xm" 163 | id: 37 164 | display_name: "sports ball" 165 | } 166 | item { 167 | name: "/m/02zt3" 168 | id: 38 169 | display_name: "kite" 170 | } 171 | item { 172 | name: "/m/03g8mr" 173 | id: 39 174 | display_name: "baseball bat" 175 | } 176 | item { 177 | name: "/m/03grzl" 178 | id: 40 179 | display_name: "baseball glove" 180 | } 181 | item { 182 | name: "/m/06_fw" 183 | id: 41 184 | display_name: "skateboard" 185 | } 186 | item { 187 | name: "/m/019w40" 188 | id: 42 189 | display_name: "surfboard" 190 | } 191 | item { 192 | name: "/m/0dv9c" 193 | id: 43 194 | display_name: "tennis racket" 195 | } 196 | item { 197 | name: "/m/04dr76w" 198 | id: 44 199 | display_name: "bottle" 200 | } 201 | item { 202 | name: "/m/09tvcd" 203 | id: 46 204 | display_name: "wine glass" 205 | } 206 | item { 207 | name: "/m/08gqpm" 208 | id: 47 209 | display_name: "cup" 210 | } 211 | item { 212 | name: "/m/0dt3t" 213 | id: 48 214 | display_name: "fork" 215 | } 216 | item { 217 | name: "/m/04ctx" 218 | id: 49 219 | display_name: "knife" 220 | } 221 | item { 222 | name: "/m/0cmx8" 223 | id: 50 224 | display_name: "spoon" 225 | } 226 | item { 227 | name: "/m/04kkgm" 228 | id: 51 229 | display_name: "bowl" 230 | } 231 | item { 232 | name: "/m/09qck" 233 | id: 52 234 | display_name: "banana" 235 | } 236 | item { 237 | name: "/m/014j1m" 238 | id: 53 239 | display_name: "apple" 240 | } 241 | item { 242 | name: "/m/0l515" 243 | id: 54 244 | display_name: "sandwich" 245 | } 246 | item { 247 | name: "/m/0cyhj_" 248 | id: 55 249 | display_name: "orange" 250 | } 251 | item { 252 | name: "/m/0hkxq" 253 | id: 56 254 | display_name: "broccoli" 255 | } 256 | item { 257 | name: "/m/0fj52s" 258 | id: 57 259 | display_name: "carrot" 260 | } 261 | item { 262 | name: "/m/01b9xk" 263 | id: 58 264 | display_name: "hot dog" 265 | } 266 | item { 267 | name: "/m/0663v" 268 | id: 59 269 | display_name: "pizza" 270 | } 271 | item { 272 | name: "/m/0jy4k" 273 | id: 60 274 | display_name: "donut" 275 | } 276 | item { 277 | name: "/m/0fszt" 278 | id: 61 279 | display_name: "cake" 280 | } 281 | item { 282 | name: "/m/01mzpv" 283 | id: 62 284 | display_name: "chair" 285 | } 286 | item { 287 | name: "/m/02crq1" 288 | id: 63 289 | display_name: "couch" 290 | } 291 | item { 292 | name: "/m/03fp41" 293 | id: 64 294 | 
display_name: "potted plant" 295 | } 296 | item { 297 | name: "/m/03ssj5" 298 | id: 65 299 | display_name: "bed" 300 | } 301 | item { 302 | name: "/m/04bcr3" 303 | id: 67 304 | display_name: "dining table" 305 | } 306 | item { 307 | name: "/m/09g1w" 308 | id: 70 309 | display_name: "toilet" 310 | } 311 | item { 312 | name: "/m/07c52" 313 | id: 72 314 | display_name: "tv" 315 | } 316 | item { 317 | name: "/m/01c648" 318 | id: 73 319 | display_name: "laptop" 320 | } 321 | item { 322 | name: "/m/020lf" 323 | id: 74 324 | display_name: "mouse" 325 | } 326 | item { 327 | name: "/m/0qjjc" 328 | id: 75 329 | display_name: "remote" 330 | } 331 | item { 332 | name: "/m/01m2v" 333 | id: 76 334 | display_name: "keyboard" 335 | } 336 | item { 337 | name: "/m/050k8" 338 | id: 77 339 | display_name: "cell phone" 340 | } 341 | item { 342 | name: "/m/0fx9l" 343 | id: 78 344 | display_name: "microwave" 345 | } 346 | item { 347 | name: "/m/029bxz" 348 | id: 79 349 | display_name: "oven" 350 | } 351 | item { 352 | name: "/m/01k6s3" 353 | id: 80 354 | display_name: "toaster" 355 | } 356 | item { 357 | name: "/m/0130jx" 358 | id: 81 359 | display_name: "sink" 360 | } 361 | item { 362 | name: "/m/040b_t" 363 | id: 82 364 | display_name: "refrigerator" 365 | } 366 | item { 367 | name: "/m/0bt_c3" 368 | id: 84 369 | display_name: "book" 370 | } 371 | item { 372 | name: "/m/01x3z" 373 | id: 85 374 | display_name: "clock" 375 | } 376 | item { 377 | name: "/m/02s195" 378 | id: 86 379 | display_name: "vase" 380 | } 381 | item { 382 | name: "/m/01lsmm" 383 | id: 87 384 | display_name: "scissors" 385 | } 386 | item { 387 | name: "/m/0kmg4" 388 | id: 88 389 | display_name: "teddy bear" 390 | } 391 | item { 392 | name: "/m/03wvsk" 393 | id: 89 394 | display_name: "hair drier" 395 | } 396 | item { 397 | name: "/m/012xff" 398 | id: 90 399 | display_name: "toothbrush" 400 | } 401 | -------------------------------------------------------------------------------- /tensorflow/src/main/scala/io/brunk/DatasetSplitter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.brunk 18 | 19 | import better.files.File 20 | 21 | import scala.util.Random.shuffle 22 | 23 | /** Script that splits an image dataset into train/validation/test set 24 | * 25 | * Expects the following structure per class: <input_dir>/<class_name> 26 | * Outputs each subset into a subdir for training, validation and test set 27 | * 28 | * usage: DatasetSplitter <input_dir> <output_dir> <train_size> <validation_size> [<test_size>] 29 | * Sizes in % 30 | */ 31 | object DatasetSplitter { 32 | val datasetSplitNames = Seq("train", "validation", "test") 33 | 34 | def main(args: Array[String]): Unit = { 35 | val inputDir = File(args(0)) 36 | val outputDir = File(args(1)) 37 | val splitSizes = args.drop(2).map(_.toFloat).toSeq 38 | 39 | val imgClassDirs = inputDir.list.filter(_.isDirectory).toVector.sortBy(_.name) 40 | 41 | // in case of different numbers of samples per class, use the smallest one 42 | val numSamplesPerClass = imgClassDirs.map(_.glob("*.jpg").size).min 43 | println(s"Number of samples per class (balanced): $numSamplesPerClass") 44 | 45 | val samplesPerClass = { 46 | imgClassDirs.flatMap { imgClassDir => 47 | shuffle(imgClassDir.children.toVector) 48 | .take(numSamplesPerClass) // balance samples to have the same number for each class 49 | .map((imgClassDir.name, _)) 50 | }.groupBy(_._1).mapValues(_.map(_._2)) 51 | } 52 | 53 | val numSamples = samplesPerClass.map(_._2.size).sum 54 | println(s"Number of samples (balanced): $numSamples") 55 | val splitSizesAbsolute = splitSizes.map(_ / 100.0).map(_ * numSamplesPerClass).map(_.toInt) 56 | println(s"Number of samples per split: ${datasetSplitNames.zip(splitSizesAbsolute).mkString(" ")}") 57 | val splitIndices = splitSizesAbsolute 58 | .map(_ -1) 59 | .scanLeft(-1 to -1)((prev, current) => prev.last + 1 to (prev.last + current + 1)).tail // TODO cleaner solution 60 | println(splitIndices) 61 | 62 | val datasetNamesWithIndices = datasetSplitNames.zip(splitIndices) 63 | 64 | val datasetIndices = (for { 65 | (name, indices) <- datasetNamesWithIndices 66 | index <- indices 67 | } yield (index, name)) 68 | .sortBy(_._1) 69 | .map(_._2) 70 | 71 | // create directories 72 | for { 73 | dataset <- datasetSplitNames 74 | imgClassDir <- imgClassDirs 75 | imgClass = imgClassDir.name 76 | } { 77 | (outputDir/dataset/imgClass).createDirectories() 78 | } 79 | // write into train, validation and test folders 80 | for { 81 | (imgClass, samples) <- samplesPerClass 82 | (filename, dataset) <- samples.zip(datasetIndices) 83 | } { 84 | filename.copyTo(outputDir/dataset/imgClass/filename.name) 85 | } 86 | 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /tensorflow/src/main/scala/io/brunk/examples/FashionMnistCNN.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import java.nio.file.Paths 20 | 21 | import com.typesafe.scalalogging.Logger 22 | import org.platanios.tensorflow.api._ 23 | import org.platanios.tensorflow.api.ops.NN.ValidConvPadding 24 | import org.platanios.tensorflow.data.image.MNISTLoader 25 | import org.platanios.tensorflow.data.image.MNISTLoader.FASHION_MNIST 26 | import org.slf4j.LoggerFactory 27 | 28 | /** Simple CNN for classifying handwritten digits from the MNIST dataset 29 | * 30 | * Implemented using TensorFlow for Scala. 31 | * 32 | * @author Sören Brunk 33 | */ 34 | object FashionMnistCNN { 35 | private[this] val logger = Logger(LoggerFactory.getLogger(FashionMnistCNN.getClass)) 36 | 37 | def main(args: Array[String]): Unit = { 38 | 39 | val batchSize = 2048 40 | val numEpochs = 500 41 | 42 | // download and load the MNIST images as tensors 43 | val dataSet = MNISTLoader.load(Paths.get("datasets/Fashion-MNIST"), FASHION_MNIST) 44 | val trainImages = tf.data.TensorSlicesDataset(dataSet.trainImages.expandDims(-1)) 45 | val trainLabels = tf.data.TensorSlicesDataset(dataSet.trainLabels) 46 | val testImages = tf.data.TensorSlicesDataset(dataSet.testImages.expandDims(-1)) 47 | val testLabels = tf.data.TensorSlicesDataset(dataSet.testLabels) 48 | val trainData = 49 | trainImages.zip(trainLabels) 50 | .repeat() 51 | .shuffle(60000) 52 | .batch(batchSize) 53 | .prefetch(10) 54 | val evalTrainData = trainImages.zip(trainLabels).batch(1000).prefetch(10) 55 | val evalTestData = testImages.zip(testLabels).batch(1000).prefetch(10) 56 | 57 | // define the neural network architecture 58 | val input = tf.learn.Input(UINT8, Shape(-1, 28, 28, 1)) // type and shape of our input images 59 | val labelInput = tf.learn.Input(UINT8, Shape(-1)) // type and shape of our labels 60 | 61 | val layer = tf.learn.Cast("Input/Cast", FLOAT32) >> 62 | tf.learn.Conv2D("Layer_0/Conv2D", Shape(3, 3, 1, 32), stride1 = 1, stride2 = 1, ValidConvPadding) >> 63 | tf.learn.AddBias("Layer_0/Bias") >> 64 | tf.learn.ReLU("Layer_0/ReLU", 0.1f) >> 65 | tf.learn.MaxPool("Layer_1/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, ValidConvPadding) >> 66 | tf.learn.Conv2D("Layer_1/Conv2D", Shape(5, 5, 32, 64), stride1 = 2, stride2 = 2, ValidConvPadding) >> 67 | tf.learn.AddBias("Layer_1/Bias") >> 68 | tf.learn.ReLU("Layer_1/ReLU", 0.1f) >> 69 | tf.learn.MaxPool("Layer_1/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, ValidConvPadding) >> 70 | tf.learn.Flatten("Input/Flatten") >> 71 | tf.learn.Linear("Layer_2/Linear", units = 512) >> // hidden layer 72 | tf.learn.ReLU("Layer_2/ReLU", 0.1f) >> // hidden layer activation 73 | tf.learn.Dropout("Layer_2/Dropout", keepProbability = 0.8f) >> // dropout 74 | tf.learn.Linear("OutputLayer/Linear", units = 10) // output layer 75 | 76 | val trainInputLayer = tf.learn.Cast("TrainInput/Cast", INT64) // cast labels to long 77 | 78 | val loss = tf.learn.SparseSoftmaxCrossEntropy("Loss/CrossEntropy") >> // loss/error function 79 | tf.learn.Mean("Loss/Mean") >> tf.learn.ScalarSummary("Loss/Summary", "Loss") 80 | val optimizer = tf.train.Adam(learningRate = 0.001) // the optimizer updates our weights 81 | 82 | val model = tf.learn.Model.supervised(input, layer, labelInput, trainInputLayer, loss, optimizer) 83 | 84 | val summariesDir = Paths.get("temp/fashion-mnist-cnn") 85 | val accMetric = tf.metrics.MapMetric( 86 | (v: (Output, Output)) => (v._1.argmax(-1), v._2), tf.metrics.Accuracy()) 87 | val estimator = tf.learn.InMemoryEstimator( 88 | model, 89 | 
tf.learn.Configuration(Some(summariesDir)), 90 | tf.learn.StopCriteria(maxSteps = Some((60000/batchSize)*numEpochs)), // due to a bug, we can't use epochs directly 91 | Set( 92 | tf.learn.LossLogger(trigger = tf.learn.StepHookTrigger(100)), 93 | tf.learn.Evaluator( 94 | log = true, datasets = Seq(("Train", () => evalTrainData), ("Test", () => evalTestData)), 95 | metrics = Seq(accMetric), trigger = tf.learn.StepHookTrigger(1000), name = "Evaluator", summaryDir = summariesDir), 96 | tf.learn.StepRateLogger(log = false, summaryDir = summariesDir, trigger = tf.learn.StepHookTrigger(100)), 97 | tf.learn.SummarySaver(summariesDir, tf.learn.StepHookTrigger(100)), 98 | tf.learn.CheckpointSaver(summariesDir, tf.learn.StepHookTrigger(1000))), 99 | tensorBoardConfig = tf.learn.TensorBoardConfig(summariesDir, reloadInterval = 1)) 100 | 101 | // train the model 102 | estimator.train(() => trainData) 103 | 104 | 105 | def accuracy(images: Tensor, labels: Tensor): Float = { 106 | val predictions = estimator.infer(() => images) 107 | predictions.argmax(1).cast(UINT8).equal(labels).cast(FLOAT32).mean().scalar.asInstanceOf[Float] 108 | } 109 | 110 | // evaluate model performance 111 | logger.info(s"Train accuracy = ${accuracy(dataSet.trainImages.expandDims(-1), dataSet.trainLabels)}") 112 | logger.info(s"Test accuracy = ${accuracy(dataSet.testImages.expandDims(-1), dataSet.testLabels)}") 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /tensorflow/src/main/scala/io/brunk/examples/FashionMnistMLP.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import java.nio.file.Paths 20 | 21 | import com.typesafe.scalalogging.Logger 22 | import org.platanios.tensorflow.api._ 23 | import org.platanios.tensorflow.data.image.MNISTLoader 24 | import org.platanios.tensorflow.data.image.MNISTLoader.FASHION_MNIST 25 | import org.slf4j.LoggerFactory 26 | 27 | /** Simple multilayer perceptron for classifying handwritten digits from the Fashion MNIST dataset 28 | * 29 | * Implemented using TensorFlow for Scala based on the example from 30 | * https://github.com/eaplatanios/tensorflow_scala/blob/0b7ca14de53935a34deac29802d085729228c4fe/examples/src/main/scala/org/platanios/tensorflow/examples/MNIST.scala 31 | * 32 | * @author Sören Brunk 33 | */ 34 | object FashionMnistMLP { 35 | private[this] val logger = Logger(LoggerFactory.getLogger(FashionMnistMLP.getClass)) 36 | 37 | def main(args: Array[String]): Unit = { 38 | 39 | val batchSize = 2048 40 | val numEpochs = 500 41 | 42 | // download and load the MNIST images as tensors 43 | val dataSet = MNISTLoader.load(Paths.get("datasets/Fashion-MNIST"), FASHION_MNIST) 44 | val trainImages = tf.data.TensorSlicesDataset(dataSet.trainImages) 45 | val trainLabels = tf.data.TensorSlicesDataset(dataSet.trainLabels) 46 | val testImages = tf.data.TensorSlicesDataset(dataSet.testImages) 47 | val testLabels = tf.data.TensorSlicesDataset(dataSet.testLabels) 48 | val trainData = 49 | trainImages.zip(trainLabels) 50 | .repeat() 51 | .shuffle(60000) 52 | .batch(batchSize) 53 | .prefetch(10) 54 | val evalTrainData = trainImages.zip(trainLabels).batch(1000).prefetch(10) 55 | val evalTestData = testImages.zip(testLabels).batch(1000).prefetch(10) 56 | 57 | // define the neural network architecture 58 | val input = tf.learn.Input(UINT8, Shape(-1, 28, 28)) // type and shape of our input images 59 | val labelInput = tf.learn.Input(UINT8, Shape(-1)) // type and shape of our labels 60 | 61 | val layer = tf.learn.Flatten("Input/Flatten") >> // flatten the images into a single vector 62 | tf.learn.Cast("Input/Cast", FLOAT32) >> // cast input to float 63 | tf.learn.Linear("Layer_1/Linear", units = 512) >> // hidden layer 64 | tf.learn.ReLU("Layer_1/ReLU", 0.1f) >> // hidden layer activation 65 | tf.learn.Linear("OutputLayer/Linear", units = 10) // output layer 66 | 67 | val trainInputLayer = tf.learn.Cast("TrainInput/Cast", INT64) // cast labels to long 68 | 69 | val loss = tf.learn.SparseSoftmaxCrossEntropy("Loss/CrossEntropy") >> // loss/error function 70 | tf.learn.Mean("Loss/Mean") >> tf.learn.ScalarSummary("Loss/Summary", "Loss") 71 | val optimizer = tf.train.Adam(learningRate = 0.001) // the optimizer updates our weights 72 | 73 | val model = tf.learn.Model.supervised(input, layer, labelInput, trainInputLayer, loss, optimizer) 74 | 75 | val summariesDir = Paths.get("temp/fashion-mnist-mlp") 76 | val accMetric = tf.metrics.MapMetric( 77 | (v: (Output, Output)) => (v._1.argmax(-1), v._2), tf.metrics.Accuracy()) 78 | val estimator = tf.learn.InMemoryEstimator( 79 | model, 80 | tf.learn.Configuration(Some(summariesDir)), 81 | tf.learn.StopCriteria(maxSteps = Some((60000/batchSize)*numEpochs)), // due to a bug, we can't use epochs directly 82 | Set( 83 | tf.learn.LossLogger(trigger = tf.learn.StepHookTrigger(100)), 84 | tf.learn.Evaluator( 85 | log = true, datasets = Seq(("Train", () => evalTrainData), ("Test", () => evalTestData)), 86 | metrics = Seq(accMetric), trigger = tf.learn.StepHookTrigger(500), name = "Evaluator", summaryDir = summariesDir), 87 | 
tf.learn.StepRateLogger(log = false, summaryDir = summariesDir, trigger = tf.learn.StepHookTrigger(100)), 88 | tf.learn.SummarySaver(summariesDir, tf.learn.StepHookTrigger(100)), 89 | tf.learn.CheckpointSaver(summariesDir, tf.learn.StepHookTrigger(500))), 90 | tensorBoardConfig = tf.learn.TensorBoardConfig(summariesDir, reloadInterval = 1)) 91 | 92 | // train the model 93 | estimator.train(() => trainData) 94 | 95 | def accuracy(images: Tensor, labels: Tensor): Float = { 96 | val predictions = estimator.infer(() => images) 97 | predictions.argmax(1).cast(UINT8).equal(labels).cast(FLOAT32).mean().scalar.asInstanceOf[Float] 98 | } 99 | 100 | // evaluate model performance 101 | logger.info(s"Train accuracy = ${accuracy(dataSet.trainImages, dataSet.trainLabels)}") 102 | logger.info(s"Test accuracy = ${accuracy(dataSet.testImages, dataSet.testLabels)}") 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /tensorflow/src/main/scala/io/brunk/examples/IrisMLP.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import java.nio.file.Paths 20 | 21 | import com.typesafe.scalalogging.Logger 22 | import org.platanios.tensorflow.api.{tf, _} 23 | import org.platanios.tensorflow.api.tf._ 24 | import org.platanios.tensorflow.api.ops.io.data.Dataset 25 | import org.platanios.tensorflow.api.types.DataType 26 | import org.slf4j.LoggerFactory 27 | 28 | object IrisMLP { 29 | 30 | private[this] val logger = Logger(LoggerFactory.getLogger("Examples / Iris")) 31 | 32 | def main(args: Array[String]): Unit = { 33 | 34 | val seed = 1 // for reproducibility 35 | val numInputs = 4 36 | val numHidden = 10 37 | val numOutputs = 3 38 | val learningRate = 0.1 39 | val iterations = 1000 40 | val trainSize = 100 41 | val testSize = 50 42 | 43 | // Read CSV using plain Scala 44 | // val source = scala.io.Source.fromFile("iris.csv") 45 | // val rows = for (l <- source.getLines().drop(1)) yield { 46 | // val cols = l.split(",").map(_.trim).map(_.toFloat) 47 | // Tensor(cols).reshape(Shape(5)) 48 | // } 49 | // val data = Tensor(rows.toArray).reshape(Shape(150, 5)) 50 | // val features = data.slice(::, 0 :: 4) 51 | // val labels = data.slice(::, 4) 52 | // val dataset = tf.data.TensorSlicesDataset((features, labels)) 53 | // .shuffle(150, Some(42)) 54 | 55 | // Read CSV using TensorFlow operations 56 | val dataset: Dataset[(Tensor, Tensor), (Output, Output), (DataType, DataType), (Shape, Shape)] = 57 | tf.data.TextLinesDataset("data/iris.csv") 58 | .drop(1) 59 | .map { l => 60 | val csv = tf.decodeCSV(l, Seq.fill(5)(Tensor(FLOAT32)), Seq.fill(5)(FLOAT32)) 61 | (tf.stack(csv.take(4)), csv(4)) 62 | } 63 | .shuffle(150, Some(seed)) 64 | 65 | val trainDataset = dataset.take(trainSize) 66 | val testDataset = dataset.drop(trainSize) 67 | 68 | val trainData = 
trainDataset.repeat().batch(trainSize) 69 | val evalTrainData = trainDataset.batch(trainSize) 70 | val evalTestData = testDataset.batch(testSize) 71 | 72 | 73 | val input = tf.learn.Input(FLOAT32, Shape(-1, trainDataset.outputShapes._1(0))) 74 | val trainInput = tf.learn.Input(FLOAT32, Shape(-1)) 75 | val layer = 76 | tf.learn.Linear("Layer_0/Linear", numHidden) >> tf.learn.ReLU("Layer_0/ReLU") >> 77 | tf.learn.Linear("OutputLayer/Linear", numOutputs) 78 | val trainingInputLayer = tf.learn.Cast("TrainInput/Cast", INT64) 79 | val loss = tf.learn.SparseSoftmaxCrossEntropy("Loss/CrossEntropy") >> 80 | tf.learn.Mean("Loss/Mean") >> tf.learn.ScalarSummary("Loss/Summary", "Loss") 81 | val optimizer = tf.train.GradientDescent(learningRate) 82 | val model = tf.learn.Model.supervised(input, layer, trainInput, trainingInputLayer, loss, optimizer) 83 | 84 | val summariesDir = Paths.get("temp/iris-mlp") 85 | val accMetric = tf.metrics.MapMetric( 86 | (v: (Output, Output)) => (v._1.argmax(1), v._2), tf.metrics.Accuracy()) 87 | val estimator = tf.learn.InMemoryEstimator( 88 | model, 89 | tf.learn.Configuration(Some(summariesDir)), 90 | tf.learn.StopCriteria(maxSteps = Some(iterations)), 91 | Set( 92 | tf.learn.LossLogger(trigger = tf.learn.StepHookTrigger(100)), 93 | tf.learn.Evaluator( 94 | log = true, datasets = Seq(("Train", () => evalTrainData), ("Test", () => evalTestData)), 95 | metrics = Seq(accMetric), trigger = tf.learn.StepHookTrigger(1000), name = "Evaluator"), 96 | tf.learn.StepRateLogger(log = false, summaryDir = summariesDir, trigger = tf.learn.StepHookTrigger(100)), 97 | tf.learn.SummarySaver(summariesDir, tf.learn.StepHookTrigger(100)), 98 | tf.learn.CheckpointSaver(summariesDir, tf.learn.StepHookTrigger(100))), 99 | tensorBoardConfig = tf.learn.TensorBoardConfig(summariesDir, reloadInterval = 1)) 100 | 101 | estimator.train(() => trainData) 102 | } 103 | 104 | } 105 | -------------------------------------------------------------------------------- /tensorflow/src/main/scala/io/brunk/examples/MnistMLP.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import org.platanios.tensorflow.api._ 20 | import org.platanios.tensorflow.data.image.MNISTLoader 21 | import com.typesafe.scalalogging.Logger 22 | import org.slf4j.LoggerFactory 23 | import java.nio.file.Paths 24 | 25 | import org.platanios.tensorflow.api.ops.variables.GlorotUniformInitializer 26 | 27 | /** Simple multilayer perceptron for classifying handwritten digits from the MNIST dataset 28 | * 29 | * Implemented using TensorFlow for Scala based on the example from 30 | * https://github.com/eaplatanios/tensorflow_scala/blob/0b7ca14de53935a34deac29802d085729228c4fe/examples/src/main/scala/org/platanios/tensorflow/examples/MNIST.scala 31 | * 32 | * @author Sören Brunk 33 | */ 34 | object MnistMLP { 35 | private[this] val logger = Logger(LoggerFactory.getLogger(MnistMLP.getClass)) 36 | 37 | def main(args: Array[String]): Unit = { 38 | 39 | val numHidden = 512 // size (number of neurons) of our hidden layer 40 | val numOutputs = 10 // digits from 0 to 9 41 | val learningRate = 0.01 42 | val batchSize = 128 43 | val numEpochs = 10 44 | 45 | // download and load the MNIST images as tensors 46 | val dataSet = MNISTLoader.load(Paths.get("datasets/MNIST")) 47 | val trainImages = tf.data.TensorSlicesDataset(dataSet.trainImages) 48 | val trainLabels = tf.data.TensorSlicesDataset(dataSet.trainLabels) 49 | val testImages = tf.data.TensorSlicesDataset(dataSet.testImages) 50 | val testLabels = tf.data.TensorSlicesDataset(dataSet.testLabels) 51 | val trainData = 52 | trainImages.zip(trainLabels) 53 | .repeat() 54 | .shuffle(10000) 55 | .batch(batchSize) 56 | .prefetch(10) 57 | val evalTrainData = trainImages.zip(trainLabels).batch(1000).prefetch(10) 58 | val evalTestData = testImages.zip(testLabels).batch(1000).prefetch(10) 59 | 60 | // define the neural network architecture 61 | val input = tf.learn.Input(UINT8, Shape(-1, dataSet.trainImages.shape(1), dataSet.trainImages.shape(2))) // type and shape of images 62 | val trainInput = tf.learn.Input(UINT8, Shape(-1)) // type and shape of labels 63 | 64 | val layer = tf.learn.Flatten("Input/Flatten") >> // flatten the images into a single vector 65 | tf.learn.Cast("Input/Cast", FLOAT32) >> 66 | tf.learn.Linear("Layer_1/Linear", numHidden, weightsInitializer = GlorotUniformInitializer()) >> // hidden layer 67 | tf.learn.ReLU("Layer_1/ReLU") >> // hidden layer activation 68 | tf.learn.Linear("OutputLayer/Linear", numOutputs, weightsInitializer = GlorotUniformInitializer()) // output layer 69 | 70 | val trainingInputLayer = tf.learn.Cast("TrainInput/Cast", INT64) // cast labels to long 71 | 72 | val loss = tf.learn.SparseSoftmaxCrossEntropy("Loss/CrossEntropy") >> 73 | tf.learn.Mean("Loss/Mean") >> tf.learn.ScalarSummary("Loss/Summary", "Loss") 74 | val optimizer = tf.train.GradientDescent(learningRate) 75 | 76 | val model = tf.learn.Model.supervised(input, layer, trainInput, trainingInputLayer, loss, optimizer) 77 | 78 | val summariesDir = Paths.get("temp/mnist-mlp") 79 | val accMetric = tf.metrics.MapMetric( 80 | (v: (Output, Output)) => (v._1.argmax(-1), v._2), tf.metrics.Accuracy()) 81 | val estimator = tf.learn.InMemoryEstimator( 82 | model, 83 | tf.learn.Configuration(Some(summariesDir)), 84 | tf.learn.StopCriteria(maxSteps = Some((60000/batchSize)*numEpochs)), // due to a bug, we can't use epochs directly 85 | Set( 86 | tf.learn.LossLogger(trigger = tf.learn.StepHookTrigger(100)), 87 | tf.learn.Evaluator( 88 | log = true, datasets = Seq(("Train", () => evalTrainData), ("Test", () => 
evalTestData)), 89 | metrics = Seq(accMetric), trigger = tf.learn.StepHookTrigger(1000), name = "Evaluator"), 90 | tf.learn.StepRateLogger(log = false, summaryDir = summariesDir, trigger = tf.learn.StepHookTrigger(100)), 91 | tf.learn.SummarySaver(summariesDir, tf.learn.StepHookTrigger(100)), 92 | tf.learn.CheckpointSaver(summariesDir, tf.learn.StepHookTrigger(1000))), 93 | tensorBoardConfig = tf.learn.TensorBoardConfig(summariesDir, reloadInterval = 1)) 94 | 95 | // train the model 96 | estimator.train(() => trainData) 97 | 98 | def accuracy(images: Tensor, labels: Tensor): Float = { 99 | val predictions = estimator.infer(() => images) 100 | predictions.argmax(1).cast(UINT8).equal(labels).cast(FLOAT32).mean().scalar.asInstanceOf[Float] 101 | } 102 | 103 | // evaluate model performance 104 | logger.info(s"Train accuracy = ${accuracy(dataSet.trainImages, dataSet.trainLabels)}") 105 | logger.info(s"Test accuracy = ${accuracy(dataSet.testImages, dataSet.testLabels)}") 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /tensorflow/src/main/scala/io/brunk/examples/ObjectDetector.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import java.io.{BufferedInputStream, File, FileInputStream} 20 | import java.nio.ByteBuffer 21 | import javax.swing.JFrame 22 | 23 | import object_detection.protos.string_int_label_map.{StringIntLabelMap, StringIntLabelMapItem} 24 | import org.bytedeco.javacpp.opencv_core.{FONT_HERSHEY_PLAIN, LINE_AA, Mat, Point, Scalar} 25 | import org.bytedeco.javacpp.opencv_imgcodecs._ 26 | import org.bytedeco.javacpp.opencv_imgproc.{COLOR_BGR2RGB, cvtColor, putText, rectangle} 27 | import org.bytedeco.javacv.{CanvasFrame, FFmpegFrameGrabber, FrameGrabber, OpenCVFrameConverter, OpenCVFrameGrabber} 28 | import org.platanios.tensorflow.api.{Graph, Session, Shape, Tensor, UINT8} 29 | import org.tensorflow.framework.GraphDef 30 | 31 | import scala.collection.Iterator.continually 32 | import scala.io.Source 33 | 34 | case class DetectionOutput(boxes: Tensor, scores: Tensor, classes: Tensor, num: Tensor) 35 | 36 | /** 37 | * This example shows how to run a pretrained TensorFlow object detection model i.e. 
one from 38 | * https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 39 | * 40 | * You have to download and extract the model you want to run first, like so: 41 | * $ cd tensorflow 42 | * $ mkdir models && cd models 43 | * $ wget http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz 44 | * $ tar xzf ssd_inception_v2_coco_2017_11_17.tar.gz 45 | * 46 | * @author Sören Brunk 47 | */ 48 | object ObjectDetector { 49 | 50 | def main(args: Array[String]): Unit = { 51 | 52 | def printUsageAndExit(): Unit = { 53 | Console.err.println( 54 | """ 55 | |Usage: ObjectDetector image <filename>|video <filename>|camera <device number> [<model dir>] 56 | | <filename> path to an image/video file 57 | | <device number> camera device number (usually starts with 0) 58 | | <model dir> optional path to the object detection model to be used. Default: ssd_inception_v2_coco_2017_11_17 59 | |""".stripMargin.trim) 60 | sys.exit(2) 61 | } 62 | 63 | if (args.length < 2) printUsageAndExit() 64 | 65 | val modelDir = args.lift(2).getOrElse("ssd_inception_v2_coco_2017_11_17") 66 | // load a pretrained detection model as TensorFlow graph 67 | val graphDef = GraphDef.parseFrom( 68 | new BufferedInputStream(new FileInputStream(new File(new File("models", modelDir), "frozen_inference_graph.pb")))) 69 | val graph = Graph.fromGraphDef(graphDef) 70 | 71 | // create a session and add our pretrained graph to it 72 | val session = Session(graph) 73 | 74 | // load the protobuf label map containing the class number to string label mapping (from COCO) 75 | val labelMap: Map[Int, String] = { 76 | val pbText = Source.fromResource("mscoco_label_map.pbtxt").mkString 77 | val stringIntLabelMap = StringIntLabelMap.fromAscii(pbText) 78 | stringIntLabelMap.item.collect { 79 | case StringIntLabelMapItem(_, Some(id), Some(displayName)) => id -> displayName 80 | }.toMap 81 | } 82 | 83 | val inputType = args(0) 84 | inputType match { 85 | case "image" => 86 | val image = imread(args(1)) 87 | detectImage(image, graph, session, labelMap) 88 | case "video" => 89 | val grabber = new FFmpegFrameGrabber(args(1)) 90 | detectSequence(grabber, graph, session, labelMap) 91 | case "camera" => 92 | val cameraDevice = Integer.parseInt(args(1)) 93 | val grabber = new OpenCVFrameGrabber(cameraDevice) 94 | detectSequence(grabber, graph, session, labelMap) 95 | case _ => printUsageAndExit() 96 | } 97 | } 98 | 99 | // convert OpenCV tensor to TensorFlow tensor 100 | def matToTensor(image: Mat): Tensor = { 101 | val imageRGB = new Mat 102 | cvtColor(image, imageRGB, COLOR_BGR2RGB) // convert channels from OpenCV GBR to RGB 103 | val imgBuffer = imageRGB.createBuffer[ByteBuffer] 104 | val shape = Shape(1, image.size.height, image.size.width(), image.channels) 105 | Tensor.fromBuffer(UINT8, shape, imgBuffer.capacity, imgBuffer) 106 | } 107 | 108 | // run detector on a single image 109 | def detectImage(image: Mat, graph: Graph, session: Session, labelMap: Map[Int, String]): Unit = { 110 | val canvasFrame = new CanvasFrame("Object Detection") 111 | canvasFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE) // exit when the canvas frame is closed 112 | canvasFrame.setCanvasSize(image.size.width, image.size.height) 113 | val detectionOutput = detect(matToTensor(image), graph, session) 114 | drawBoundingBoxes(image, labelMap, detectionOutput) 115 | canvasFrame.showImage(new OpenCVFrameConverter.ToMat().convert(image)) 116 | canvasFrame.waitKey(0) 117 | canvasFrame.dispose() 118 | } 119 | 120 | // run detector on an image sequence 121 | def detectSequence(grabber:
FrameGrabber, graph: Graph, session: Session, labelMap: Map[Int, String]): Unit = { 122 | val canvasFrame = new CanvasFrame("Object Detection", CanvasFrame.getDefaultGamma / grabber.getGamma) 123 | canvasFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE) // exit when the canvas frame is closed 124 | val converter = new OpenCVFrameConverter.ToMat() 125 | grabber.start() 126 | for (frame <- continually(grabber.grab()).takeWhile(_ != null 127 | && (grabber.getLengthInFrames == 0 || grabber.getFrameNumber < grabber.getLengthInFrames))) { 128 | val image = converter.convert(frame) 129 | if (image != null) { // sometimes the first few frames are empty so we ignore them 130 | val detectionOutput = detect(matToTensor(image), graph, session) // run our model 131 | drawBoundingBoxes(image, labelMap, detectionOutput) 132 | if (canvasFrame.isVisible) { // show our frame in the preview 133 | canvasFrame.showImage(frame) 134 | } 135 | } 136 | } 137 | canvasFrame.dispose() 138 | grabber.stop() 139 | } 140 | 141 | // run the object detection model on an image 142 | def detect(image: Tensor, graph: Graph, session: Session): DetectionOutput = { 143 | 144 | // retrieve the output placeholders 145 | val imagePlaceholder = graph.getOutputByName("image_tensor:0") 146 | val detectionBoxes = graph.getOutputByName("detection_boxes:0") 147 | val detectionScores = graph.getOutputByName("detection_scores:0") 148 | val detectionClasses = graph.getOutputByName("detection_classes:0") 149 | val numDetections = graph.getOutputByName("num_detections:0") 150 | 151 | // set image as input parameter 152 | val feeds = Map(imagePlaceholder -> image) 153 | 154 | // Run the detection model 155 | val Seq(boxes, scores, classes, num) = 156 | session.run(fetches = Seq(detectionBoxes, detectionScores, detectionClasses, numDetections), feeds = feeds) 157 | DetectionOutput(boxes, scores, classes, num) 158 | } 159 | 160 | // draw boxes with class and score around detected objects 161 | def drawBoundingBoxes(image: Mat, labelMap: Map[Int, String], detectionOutput: DetectionOutput): Unit = { 162 | for (i <- 0 until detectionOutput.boxes.shape.size(1)) { 163 | val score = detectionOutput.scores(0, i).scalar.asInstanceOf[Float] 164 | 165 | if (score > 0.5) { 166 | val box = detectionOutput.boxes(0, i).entriesIterator.map(_.asInstanceOf[Float]).toSeq 167 | // we have to scale the box coordinates to the image size 168 | val ymin = (box(0) * image.size().height()).toInt 169 | val xmin = (box(1) * image.size().width()).toInt 170 | val ymax = (box(2) * image.size().height()).toInt 171 | val xmax = (box(3) * image.size().width()).toInt 172 | val label = labelMap.getOrElse(detectionOutput.classes(0, i).scalar.asInstanceOf[Float].toInt, "unknown") 173 | 174 | // draw score value 175 | putText(image, 176 | f"$label%s ($score%1.2f)", // text 177 | new Point(xmin + 6, ymin + 38), // text position 178 | FONT_HERSHEY_PLAIN, // font type 179 | 2.6, // font scale 180 | new Scalar(0, 0, 0, 0), // text color 181 | 4, // text thickness 182 | LINE_AA, // line type 183 | false) // origin is at the top-left corner 184 | putText(image, 185 | f"$label%s ($score%1.2f)", // text 186 | new Point(xmin + 4, ymin + 36), // text position 187 | FONT_HERSHEY_PLAIN, // font type 188 | 2.6, // font scale 189 | new Scalar(0, 230, 255, 0), // text color 190 | 4, // text thickness 191 | LINE_AA, // line type 192 | false) // origin is at the top-left corner 193 | // draw bounding box 194 | rectangle(image, 195 | new Point(xmin + 1, ymin + 1), // upper left corner 196 | new 
Point(xmax + 1, ymax + 1), // lower right corner 197 | new Scalar(0, 0, 0, 0), // color 198 | 2, // thickness 199 | 0, // lineType 200 | 0) // shift 201 | rectangle(image, 202 | new Point(xmin, ymin), // upper left corner 203 | new Point(xmax, ymax), // lower right corner 204 | new Scalar(0, 230, 255, 0), // color 205 | 2, // thickness 206 | 0, // lineType 207 | 0) // shift 208 | } 209 | } 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /tensorflow/src/main/scala/io/brunk/examples/SimpleCNN.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import java.nio.ByteBuffer 20 | import java.nio.file.Paths 21 | 22 | import com.typesafe.scalalogging.Logger 23 | import org.platanios.tensorflow.api.{DataType, _} 24 | import org.platanios.tensorflow.api.ops.NN.ValidConvPadding 25 | import org.slf4j.LoggerFactory 26 | import better.files._ 27 | import javax.swing.JFrame 28 | import org.bytedeco.javacpp.opencv_core.{Mat, Point, Scalar} 29 | import org.bytedeco.javacpp.opencv_imgproc.{COLOR_BGR2RGB, cvtColor, resize} 30 | import org.bytedeco.javacv._ 31 | 32 | import scala.util.Random.shuffle 33 | import org.platanios.tensorflow.api.ops.io.data.{Dataset, TensorSlicesDataset} 34 | import org.platanios.tensorflow.api.ops.variables.GlorotUniformInitializer 35 | import org.bytedeco.javacpp.opencv_imgproc.{COLOR_BGR2RGB, cvtColor, putText, rectangle} 36 | import org.bytedeco.javacpp.opencv_core.{repeat => _, _} 37 | import org.bytedeco.javacpp.opencv_imgcodecs.imread 38 | 39 | import scala.collection.Iterator.continually 40 | 41 | /** 42 | * CNN for image classification example 43 | * 44 | * @author Sören Brunk 45 | */ 46 | object SimpleCNN { 47 | 48 | private[this] val logger = Logger(LoggerFactory.getLogger(getClass)) 49 | 50 | def main(args: Array[String]): Unit = { 51 | 52 | val seed = 42 53 | val batchSize = 64 54 | 55 | val dataDir = File(args(0)) 56 | val mode = args(1) // train or infer 57 | 58 | // define the neural network architecture 59 | val input = tf.learn.Input(UINT8, Shape(-1, 250, 250, 3)) // type and shape of images 60 | val trainInput = tf.learn.Input(UINT8, Shape(-1)) // type and shape of labels 61 | 62 | val modelIndex = args(2).toInt 63 | val layers = SimpleCNNModels.models(modelIndex) 64 | 65 | val labelMap = Seq("not_scala", "scala") 66 | 67 | val trainInputLayer = tf.learn.Cast("TrainInput/Cast", INT64) // cast labels to long 68 | 69 | val loss = tf.learn.SparseSoftmaxCrossEntropy("Loss/CrossEntropy") >> 70 | tf.learn.Mean("Loss/Mean") >> tf.learn.ScalarSummary("Loss/Summary", "Loss") 71 | val optimizer = tf.train.Adam(0.001f) 72 | 73 | val model = tf.learn.Model.supervised(input, layers, trainInput, trainInputLayer, loss, optimizer) 74 | 75 | val summariesDir = Paths.get(s"temp/logo-classifier-v$modelIndex") 76 | val 
accMetric = tf.metrics.MapMetric( 77 | (v: (Output, Output)) => (v._1.argmax(-1), v._2), tf.metrics.Accuracy()) 78 | 79 | mode match { 80 | case "train" => train() 81 | case "infer" => infer() 82 | } 83 | 84 | // train the model 85 | def train(): Unit = { 86 | val trainDir = dataDir / "train" 87 | val testDir = dataDir / "validation" 88 | val imgClassDirs = trainDir.list.filter(_.isDirectory).toVector.sortBy(_.name) 89 | val numClasses = imgClassDirs.size 90 | logger.info("Number of classes {}", numClasses) 91 | 92 | val numericLabelForClass = imgClassDirs.map(_.name).zipWithIndex.toMap 93 | logger.info("classes {}", numericLabelForClass) 94 | 95 | def filenamesWithLabels(dir: File): (Tensor, Tensor) = { 96 | val (filenames, labels) = (for { 97 | dir <- dir.children.filter(_.isDirectory) 98 | filename <- dir.glob("*.jpg").map(_.pathAsString) 99 | } yield (filename, numericLabelForClass(dir.name))).toVector.unzip 100 | (Tensor(filenames).squeeze(Seq(0)), Tensor(UINT8, labels).squeeze(Seq(0)) 101 | ) 102 | } 103 | 104 | def readImage(filename: Output): Output = { 105 | val rawImage = tf.data.readFile(filename) 106 | val image = tf.image.decodeJpeg(rawImage, numChannels = 3) 107 | tf.image.resizeBilinear(image.expandDims(axis = 0), Seq(250, 250)).squeeze(Seq(0)).cast(UINT8) 108 | } 109 | 110 | val trainData: Dataset[(Tensor, Tensor), (Output, Output), (DataType, DataType), (Shape, Shape)] = 111 | tf.data.TensorSlicesDataset(filenamesWithLabels(trainDir)) 112 | .shuffle(bufferSize = 30000, Some(seed)) 113 | .map({ case (filename, label) => (readImage(filename), label)}, numParallelCalls = 16) 114 | .cache("") 115 | .repeat() 116 | .batch(batchSize) 117 | .prefetch(100) 118 | 119 | val evalTrainData: Dataset[(Tensor, Tensor), (Output, Output), (DataType, DataType), (Shape, Shape)] = 120 | tf.data.TensorSlicesDataset(filenamesWithLabels(trainDir)) 121 | .shuffle(bufferSize = 30000, Some(seed)) 122 | .map({ case (filename, label) => (readImage(filename), label)}, numParallelCalls = 16) 123 | .take(2000) 124 | .cache("") 125 | .batch(128) 126 | .prefetch(100) 127 | 128 | val evalTestData: Dataset[(Tensor, Tensor), (Output, Output), (DataType, DataType), (Shape, Shape)] = 129 | tf.data.TensorSlicesDataset(filenamesWithLabels(testDir)) 130 | .shuffle(bufferSize = 2000, Some(seed)) 131 | .map({ case (filename, label) => (readImage(filename), label)}, numParallelCalls = 16) 132 | .cache("") 133 | .batch(128) 134 | .prefetch(100) 135 | 136 | val estimator = tf.learn.InMemoryEstimator( 137 | model, 138 | tf.learn.Configuration(Some(summariesDir)), 139 | tf.learn.StopCriteria(maxSteps = Some(10000)), 140 | Set( 141 | tf.learn.LossLogger(trigger = tf.learn.StepHookTrigger(100)), 142 | tf.learn.Evaluator( 143 | log = true, datasets = Seq(("Train", () => evalTrainData), ("Test", () => evalTestData)), 144 | metrics = Seq(accMetric), trigger = tf.learn.StepHookTrigger(100), name = "Evaluator", 145 | summaryDir = summariesDir), 146 | tf.learn.StepRateLogger(log = false, summaryDir = summariesDir, trigger = tf.learn.StepHookTrigger(100)), 147 | tf.learn.SummarySaver(summariesDir, tf.learn.StepHookTrigger(100)), 148 | tf.learn.CheckpointSaver(summariesDir, tf.learn.StepHookTrigger(100))), 149 | tensorBoardConfig = tf.learn.TensorBoardConfig(summariesDir, reloadInterval = 1)) 150 | 151 | estimator.train(() => trainData) 152 | } 153 | 154 | def infer(): Unit = { 155 | 156 | val estimator = tf.learn.InMemoryEstimator(model, tf.learn.Configuration(Some(summariesDir))) 157 | 158 | val inputType = args(3) 159 | val 
input = args(4) 160 | inputType match { 161 | case "image" => 162 | val image = imread(input) 163 | detectImage(image) 164 | case "video" => 165 | val grabber = new FFmpegFrameGrabber(input) 166 | detectSequence(grabber) 167 | case "camera" => 168 | val cameraDevice = Integer.parseInt(input) 169 | val grabber = new OpenCVFrameGrabber(cameraDevice) 170 | detectSequence(grabber) 171 | case _ => sys.exit(1) 172 | } 173 | 174 | // convert OpenCV tensor to TensorFlow tensor 175 | def matToTensor(image: Mat): Tensor = { 176 | val imageRGB = new Mat 177 | cvtColor(image, imageRGB, COLOR_BGR2RGB) // convert channels from OpenCV GBR to RGB 178 | val imgBuffer = imageRGB.createBuffer[ByteBuffer] 179 | val shape = Shape(1, image.size.height, image.size.width(), image.channels) 180 | Tensor.fromBuffer(UINT8, shape, imgBuffer.capacity, imgBuffer) 181 | } 182 | 183 | def drawLabel(image: Mat, label: String): Unit = 184 | putText(image, 185 | label, // text 186 | new Point(50, 50), // text position 187 | FONT_HERSHEY_PLAIN, // font type 188 | 2.6, // font scale 189 | new Scalar(0, 0, 100, 0), // text color 190 | 4, // text thickness 191 | LINE_AA, // line type 192 | false) // origin is at the top-left corner 193 | 194 | // run detector on a single image 195 | def detectImage(image: Mat): Unit = { 196 | val canvasFrame = new CanvasFrame("Logo Classifier") 197 | canvasFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE) // exit when the canvas frame is closed 198 | canvasFrame.setCanvasSize(image.size.width, image.size.height) 199 | 200 | val imageTensor = matToTensor(image) 201 | val s = Session() 202 | val resized = s.run(fetches = tf.image.resizeBilinear(imageTensor, Seq(250, 250)).cast(UINT8)) 203 | 204 | val result: Tensor = estimator.infer(() => resized) 205 | 206 | logger.info("Result {}", result.summarize(flattened = true)) 207 | val probabilities = result.softmax().entriesIterator.map(_.asInstanceOf[Float]).toVector 208 | logger.info("Probabilities {}", probabilities) 209 | val label = result.argmax(-1).scalar.asInstanceOf[Long].toInt 210 | logger.info("Label {}", label.summarize(flattened = true)) 211 | 212 | drawLabel(image, 213 | s"Class: $label (${labelMap(label)}) " + 214 | s"Probability(${probabilities(label)})") 215 | 216 | canvasFrame.showImage(new OpenCVFrameConverter.ToMat().convert(image)) 217 | canvasFrame.waitKey(0) 218 | canvasFrame.dispose() 219 | } 220 | 221 | // run detector on an image sequence 222 | def detectSequence(grabber: FrameGrabber): Unit = { 223 | val canvasFrame = new CanvasFrame("Logo Classifier", CanvasFrame.getDefaultGamma / grabber.getGamma) 224 | canvasFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE) // exit when the canvas frame is closed 225 | val converter = new OpenCVFrameConverter.ToMat() 226 | grabber.start() 227 | for (frame <- continually(grabber.grab()).takeWhile(_ != null 228 | && (grabber.getLengthInFrames == 0 || grabber.getFrameNumber < grabber.getLengthInFrames))) { 229 | val image = converter.convert(frame) 230 | if (image != null) { // sometimes the first few frames are empty so we ignore them 231 | 232 | val imageTensor = matToTensor(image) 233 | val s = Session() 234 | val resized = s.run(fetches = tf.image.resizeBilinear(imageTensor, Seq(250, 250)).cast(UINT8)) 235 | 236 | val result: Tensor = estimator.infer(() => resized) 237 | 238 | logger.info("Result {}", result.summarize(flattened = true)) 239 | val probabilities = result.softmax().entriesIterator.map(_.asInstanceOf[Float]).toVector 240 | logger.info("Probabilities {}", probabilities) 
241 | val label = result.argmax(-1).scalar.asInstanceOf[Long].toInt 242 | logger.info("Label {}", label.summarize(flattened = true)) 243 | 244 | drawLabel(image, 245 | s"Class: $label (${labelMap(label)}) " + 246 | s"Probability: ${probabilities(label)}") 247 | 248 | if (canvasFrame.isVisible) { // show our frame in the preview 249 | canvasFrame.showImage(frame) 250 | } 251 | } 252 | } 253 | canvasFrame.dispose() 254 | grabber.stop() 255 | } 256 | } 257 | 258 | 259 | //def accuracy(images: Tensor, labels: Tensor): Float = { 260 | // val predictions = estimator.infer(() => images) 261 | // predictions.argmax(1).cast(UINT8).equal(labels).cast(FLOAT32).mean().scalar.asInstanceOf[Float] 262 | //} 263 | 264 | // evaluate model performance 265 | //logger.info(s"Train accuracy = ${accuracy(dataSet.trainImages, dataSet.trainLabels)}") 266 | //logger.info(s"Test accuracy = ${accuracy(dataSet.testImages, dataSet.testLabels)}") 267 | 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /tensorflow/src/main/scala/io/brunk/examples/SimpleCNNModels.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Sören Brunk 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package io.brunk.examples 18 | 19 | import java.nio.ByteBuffer 20 | import java.nio.file.Paths 21 | 22 | import better.files._ 23 | import com.typesafe.scalalogging.Logger 24 | import javax.swing.JFrame 25 | import org.bytedeco.javacpp.opencv_core.{Mat, Point, Scalar, repeat => _, _} 26 | import org.bytedeco.javacpp.opencv_imgcodecs.imread 27 | import org.bytedeco.javacpp.opencv_imgproc.{COLOR_BGR2RGB, cvtColor, putText, resize} 28 | import org.bytedeco.javacv._ 29 | import org.platanios.tensorflow.api.ops.NN.ValidConvPadding 30 | import org.platanios.tensorflow.api.ops.io.data.Dataset 31 | import org.platanios.tensorflow.api.{DataType, _} 32 | import org.platanios.tensorflow.api.tf.learn._ 33 | 34 | import org.slf4j.LoggerFactory 35 | 36 | import scala.collection.Iterator.continually 37 | 38 | /** 39 | * CNN for image classification example 40 | * 41 | * @author Sören Brunk 42 | */ 43 | object SimpleCNNModels { 44 | 45 | lazy val models = Seq(v0, v1, v2, v3, v4, v5, v6, v7) 46 | 47 | val v0 = 48 | tf.learn.Cast("Input/Cast", FLOAT32) >> 49 | tf.learn.Flatten("Layer_1/Flatten") >> 50 | tf.learn.Linear("Layer_1/Linear", units = 64) >> 51 | tf.learn.ReLU("Layer_1/ReLU", 0.01f) >> 52 | tf.learn.Linear("OutputLayer/Linear", 2) 53 | 54 | val v1 = 55 | tf.learn.Cast("Input/Cast", FLOAT32) >> 56 | tf.learn.Flatten("Layer_1/Flatten") >> 57 | tf.learn.Linear("Layer_1/Linear", units = 128) >> 58 | tf.learn.ReLU("Layer_1/ReLU", 0.01f) >> 59 | tf.learn.Linear("OutputLayer/Linear", 2) 60 | 61 | val v2 = 62 | tf.learn.Cast("Input/Cast", FLOAT32) >> 63 | tf.learn.Flatten("Layer_1/Flatten") >> 64 | tf.learn.Linear("Layer_1/Linear", units = 512) >> 65 | tf.learn.ReLU("Layer_1/ReLU") >> 66 | tf.learn.Linear("OutputLayer/Linear", 2) 67 | 68 | val v3 = 69 | tf.learn.Cast("Input/Cast", FLOAT32) >> 70 | tf.learn.Flatten("Layer_1/Flatten") >> 71 | tf.learn.Linear("Layer_1/Linear", units = 512) >> 72 | tf.learn.ReLU("Layer_1/ReLU", 0.01f) >> 73 | tf.learn.Dropout("Layer_1/Dropout", keepProbability = 0.5f) >> 74 | tf.learn.Linear("OutputLayer/Linear", 2) 75 | 76 | val v4 = 77 | tf.learn.Cast("Input/Cast", FLOAT32) >> 78 | tf.learn.Conv2D("Layer_1/Conv2D", Shape(3, 3, 3, 32), stride1 = 1, stride2 = 1, padding = ValidConvPadding) >> 79 | tf.learn.AddBias("Layer_1/Bias") >> 80 | tf.learn.ReLU("Layer_1/ReLU", alpha = 0.1f) >> 81 | tf.learn.MaxPool("Layer_1/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 82 | tf.learn.Flatten("Layer_2/Flatten") >> 83 | tf.learn.Linear("Layer_2/Linear", units = 512) >> 84 | tf.learn.ReLU("Layer_2/ReLU", 0.01f) >> 85 | tf.learn.Dropout("Layer_2/Dropout", keepProbability = 0.5f) >> 86 | tf.learn.Linear("OutputLayer/Linear", 2) 87 | 88 | val v5 = 89 | tf.learn.Cast("Input/Cast", FLOAT32) >> 90 | tf.learn.Conv2D("Layer_1/Conv2D", Shape(3, 3, 3, 32), stride1 = 1, stride2 = 1, padding = ValidConvPadding) >> 91 | tf.learn.AddBias("Layer_1/Bias") >> 92 | tf.learn.ReLU("Layer_1/ReLU", alpha = 0.1f) >> 93 | tf.learn.MaxPool("Layer_1/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 94 | tf.learn.Conv2D("Layer_2/Conv2D", Shape(3, 3, 32, 64), stride1 = 1, stride2 = 1, padding = ValidConvPadding) >> 95 | tf.learn.AddBias("Layer_2/Bias") >> 96 | tf.learn.ReLU("Layer_2/ReLU", alpha = 0.1f) >> 97 | tf.learn.MaxPool("Layer_2/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 98 | tf.learn.Conv2D("Layer_3/Conv2D", Shape(3, 3, 64, 128), stride1 = 
1, stride2 = 1, padding = ValidConvPadding) >> 99 | tf.learn.AddBias("Layer_3/Bias") >> 100 | tf.learn.ReLU("Layer_3/ReLU", alpha = 0.1f) >> 101 | tf.learn.MaxPool("Layer_3/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 102 | tf.learn.Conv2D("Layer_4/Conv2D", Shape(3, 3, 128, 128), stride1 = 1, stride2 = 1, padding = ValidConvPadding) >> 103 | tf.learn.AddBias("Layer_4/Bias") >> 104 | tf.learn.ReLU("Layer_4/ReLU", alpha = 0.1f) >> 105 | tf.learn.MaxPool("Layer_4/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 106 | tf.learn.Flatten("Layer_5/Flatten") >> 107 | tf.learn.Linear("Layer_5/Linear", units = 512) >> 108 | tf.learn.ReLU("Layer_5/ReLU", 0.01f) >> 109 | tf.learn.Dropout("Layer_5/Dropout", keepProbability = 0.5f) >> 110 | tf.learn.Linear("OutputLayer/Linear", 2) 111 | 112 | val v6 = 113 | tf.learn.Cast("Input/Cast", FLOAT32) >> 114 | tf.learn.Conv2D("Layer_1/Conv2D", Shape(3, 3, 3, 32), stride1 = 1, stride2 = 1, padding = ValidConvPadding) >> 115 | tf.learn.AddBias("Layer_1/Bias") >> 116 | tf.learn.ReLU("Layer_1/ReLU", alpha = 0.1f) >> 117 | tf.learn.MaxPool("Layer_1/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 118 | tf.learn.Dropout("Layer_1/Dropout", keepProbability = 0.8f) >> 119 | tf.learn.Conv2D("Layer_2/Conv2D", Shape(3, 3, 32, 64), stride1 = 1, stride2 = 1, padding = ValidConvPadding) >> 120 | tf.learn.AddBias("Layer_2/Bias") >> 121 | tf.learn.ReLU("Layer_2/ReLU", alpha = 0.1f) >> 122 | tf.learn.MaxPool("Layer_2/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 123 | tf.learn.Dropout("Layer_2/Dropout", keepProbability = 0.8f) >> 124 | tf.learn.Conv2D("Layer_3/Conv2D", Shape(3, 3, 64, 128), stride1 = 1, stride2 = 1, padding = ValidConvPadding) >> 125 | tf.learn.AddBias("Layer_3/Bias") >> 126 | tf.learn.ReLU("Layer_3/ReLU", alpha = 0.1f) >> 127 | tf.learn.MaxPool("Layer_3/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 128 | tf.learn.Dropout("Layer_3/Dropout", keepProbability = 0.8f) >> 129 | tf.learn.Conv2D("Layer_4/Conv2D", Shape(3, 3, 128, 128), stride1 = 1, stride2 = 1, padding = ValidConvPadding) >> 130 | tf.learn.AddBias("Layer_4/Bias") >> 131 | tf.learn.ReLU("Layer_4/ReLU", alpha = 0.1f) >> 132 | tf.learn.MaxPool("Layer_4/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 133 | tf.learn.Dropout("Layer_4/Dropout", keepProbability = 0.8f) >> 134 | tf.learn.Flatten("Layer_5/Flatten") >> 135 | tf.learn.Linear("Layer_5/Linear", units = 512) >> 136 | tf.learn.ReLU("Layer_5/ReLU", 0.01f) >> 137 | tf.learn.Dropout("Layer_5/Dropout", keepProbability = 0.8f) >> 138 | tf.learn.Linear("OutputLayer/Linear", 2) 139 | 140 | val v7 = 141 | tf.learn.Cast("Input/Cast", FLOAT32) >> 142 | tf.learn.Conv2D("Layer_1/Conv2D", Shape(3, 3, 3, 32), stride1 = 1, stride2 = 1, padding = ValidConvPadding) >> 143 | tf.learn.AddBias("Layer_1/Bias") >> 144 | tf.learn.ReLU("Layer_1/ReLU", alpha = 0.1f) >> 145 | tf.learn.MaxPool("Layer_1/MaxPool", windowSize = Seq(1, 2, 2, 1), stride1 = 2, stride2 = 2, padding = ValidConvPadding) >> 146 | tf.learn.Flatten("Layer_2/Flatten") >> 147 | tf.learn.Linear("Layer_2/Linear", units = 512) >> 148 | tf.learn.ReLU("Layer_2/ReLU", 0.01f) >> 149 | tf.learn.Linear("OutputLayer/Linear", 2) 150 | 151 | } 152 | --------------------------------------------------------------------------------
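A note on the layer wiring above: the number of inputs to the Linear layer that follows each Flatten is determined by the spatial size that survives the stacked VALID convolutions and 2x2 max pools. The following minimal sketch is not a file from this repository; it assumes the 250x250 input produced by the resizeBilinear call in the inference code and traces that arithmetic for the four conv/pool blocks of v5 and v6 in plain Scala:

object FeatureMapSize {
  // output extent of a VALID-padded sliding window: floor((n - k) / s) + 1
  def valid(n: Int, k: Int, s: Int): Int = (n - k) / s + 1

  // one block as used in v5/v6: 3x3 convolution with stride 1, then 2x2 max pool with stride 2
  def convThenPool(n: Int): Int = valid(valid(n, k = 3, s = 1), k = 2, s = 2)

  def main(args: Array[String]): Unit = {
    val sizes = Iterator.iterate(250)(convThenPool).take(5).toList
    println(sizes)                              // List(250, 124, 61, 29, 13)
    println(s"flatten size = ${13 * 13 * 128}") // 21632 values feeding Layer_5/Linear
  }
}

Under that assumption the final feature map of v5/v6 is 13x13x128, so Layer_5/Linear maps 21632 inputs to 512 units, while the shallower v4 and v7 flatten a much larger 124x124x32 feature map (492032 values) after their single conv/pool block.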