├── .gitignore
├── .jvmopts
├── .scalafmt.conf
├── LICENSE
├── NOTICE
├── README.md
├── build.sbt
├── data
│   └── iris.csv
├── dl4j
│   └── src
│       └── main
│           ├── resources
│           │   └── logback.xml
│           └── scala
│               └── io
│                   └── brunk
│                       └── examples
│                           ├── ImageReader.scala
│                           ├── IrisReader.scala
│                           ├── dl4j
│                           │   ├── IrisMLP.scala
│                           │   ├── MnistMLP.scala
│                           │   └── SimpleCNN.scala
│                           └── scalnet
│                               ├── IrisMLP.scala
│                               ├── MnistMLP.scala
│                               └── SimpleCNN.scala
├── mxnet
│   ├── build.sbt
│   ├── project
│   │   ├── build.properties
│   │   └── plugins.sbt
│   └── src
│       └── main
│           ├── resources
│           │   └── logback.xml
│           └── scala
│               └── io
│                   └── brunk
│                       └── examples
│                           ├── IrisMLP.scala
│                           └── MnistMLP.scala
├── project
│   ├── build.properties
│   └── plugins.sbt
└── tensorflow
    ├── example_image.jpg
    └── src
        └── main
            ├── protobuf
            │   └── string_int_label_map.proto
            ├── resources
            │   ├── logback.xml
            │   └── mscoco_label_map.pbtxt
            └── scala
                └── io
                    └── brunk
                        ├── DatasetSplitter.scala
                        └── examples
                            ├── FashionMnistCNN.scala
                            ├── FashionMnistMLP.scala
                            ├── IrisMLP.scala
                            ├── MnistMLP.scala
                            ├── ObjectDetector.scala
                            ├── SimpleCNN.scala
                            └── SimpleCNNModels.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | /temp
2 |
3 | # sbt
4 | lib_managed
5 | project/project
6 | target
7 |
8 | # Worksheets (Eclipse or IntelliJ)
9 | *.sc
10 |
11 | # Eclipse
12 | .cache*
13 | .classpath
14 | .project
15 | .scala_dependencies
16 | .settings
17 | .target
18 | .worksheet
19 |
20 | # IntelliJ
21 | .idea
22 |
23 | # ENSIME
24 | .ensime
25 | .ensime_lucene
26 | .ensime_cache
27 |
28 | # Mac
29 | .DS_Store
30 |
31 | # Akka
32 | ddata*
33 | journal
34 | snapshots
35 |
36 | # Log files
37 | *.log
38 |
--------------------------------------------------------------------------------
/.jvmopts:
--------------------------------------------------------------------------------
1 | -Dfile.encoding=UTF8
2 | -Xms1G
3 | -Xmx6G
4 | -Xms6G
5 | -Xss2M
6 | -XX:ReservedCodeCacheSize=256m
7 | -XX:MaxMetaspaceSize=512m
8 | -XX:+TieredCompilation
9 | -XX:-UseGCOverheadLimit
10 | -XX:+CMSClassUnloadingEnabled
11 | -XX:+UseConcMarkSweepGC
12 |
--------------------------------------------------------------------------------
/.scalafmt.conf:
--------------------------------------------------------------------------------
1 | style = defaultWithAlign
2 |
3 | danglingParentheses = true
4 | indentOperator = spray
5 | maxColumn = 100
6 | project.excludeFilters = [".*\\.sbt"]
7 | rewrite.rules = [AsciiSortImports, RedundantBraces, RedundantParens]
8 | spaces.inImportCurlyBraces = true
9 | unindentTopLevelOperators = true
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2017 Sören Brunk
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # scala-deeplearn-examples #
2 |
3 | Welcome to scala-deeplearn-examples!
4 |
5 | This repository contains the examples used in my blog series on deep learning with Scala: https://brunk.io
6 |
7 | You can clone the repository and run the examples using sbt.
8 |
9 | ## Contribution policy ##
10 |
11 | Contributions via GitHub pull requests are gladly accepted from their original author. Along with
12 | any pull requests, please state that the contribution is your original work and that you license
13 | the work to the project under the project's open source license. Whether or not you state this
14 | explicitly, by submitting any copyrighted material via pull request, email, or other means you
15 | agree to license the material under the project's open source license and warrant that you have the
16 | legal authority to do so.
17 |
18 | ## License ##
19 |
20 | This code is open source software licensed under the
21 | [Apache-2.0](http://www.apache.org/licenses/LICENSE-2.0) license.
22 |
--------------------------------------------------------------------------------
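
For illustration, once cloned, an individual example can be started with sbt's runMain from the repository root. The project id (dl4j) comes from the root build.sbt below and the main classes from the source files that follow; the other modules work the same way, subject to the platform notes in their build files:

```
sbt "dl4j/runMain io.brunk.examples.dl4j.MnistMLP"    # DL4J multilayer perceptron on MNIST
sbt "dl4j/runMain io.brunk.examples.scalnet.IrisMLP"  # ScalNet MLP on data/iris.csv
```
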
/build.sbt:
--------------------------------------------------------------------------------
1 | // *****************************************************************************
2 | // Projects
3 | // *****************************************************************************
4 |
5 | // The MXNet example has been moved into its own sbt project for now because we have to build mxnet manually,
6 | // and we don't want to break dependency resolution for the other projects.
7 | // lazy val mxnet = project
8 |
9 | lazy val dl4j =
10 | project
11 | .in(file("dl4j"))
12 | .enablePlugins(AutomateHeaderPlugin)
13 | .settings(settings)
14 | .settings(
15 | scalaVersion := "2.11.12", // ScalNet and ND4S are only available for Scala 2.11
16 | libraryDependencies ++= Seq(
17 | library.dl4j,
18 | library.dl4jCuda,
19 | library.dl4jUi,
20 | library.logbackClassic,
21 | library.nd4jNativePlatform,
22 | library.scalNet
23 | )
24 | )
25 |
26 | lazy val tensorFlow =
27 | project
28 | .in(file("tensorflow"))
29 | .enablePlugins(AutomateHeaderPlugin)
30 | .settings(settings)
31 | .settings(
32 | PB.targets in Compile := Seq(
33 | scalapb.gen() -> (sourceManaged in Compile).value
34 | ),
35 | javaCppPresetLibs ++= Seq(
36 | "ffmpeg" -> "3.4.1"
37 | ),
38 | libraryDependencies ++= Seq(
39 | library.betterFiles,
40 | library.janino,
41 | library.logbackClassic,
42 | library.tensorFlow,
43 | library.tensorFlowData
44 | ),
45 | fork := true // prevent classloader issues caused by sbt and opencv
46 | )
47 |
48 | // *****************************************************************************
49 | // Library dependencies
50 | // *****************************************************************************
51 |
52 | lazy val library =
53 | new {
54 | object Version {
55 | val betterFiles = "3.4.0"
56 | val dl4j = "1.0.0-alpha"
57 | val janino = "2.6.1"
58 | val logbackClassic = "1.2.3"
59 | val scalaCheck = "1.13.5"
60 | val scalaTest = "3.0.4"
61 | val tensorFlow = "0.2.4"
62 |
63 | }
64 | val betterFiles = "com.github.pathikrit" %% "better-files" % Version.betterFiles
65 | val dl4j = "org.deeplearning4j" % "deeplearning4j-core" % Version.dl4j
66 | val dl4jUi = "org.deeplearning4j" %% "deeplearning4j-ui" % Version.dl4j
67 | val janino = "org.codehaus.janino" % "janino" % Version.janino
68 | val logbackClassic = "ch.qos.logback" % "logback-classic" % Version.logbackClassic
69 | val nd4jNativePlatform = "org.nd4j" % "nd4j-cuda-9.0-platform" % Version.dl4j
70 | val dl4jCuda = "org.deeplearning4j" % "deeplearning4j-cuda-9.0" % Version.dl4j
71 | val scalaCheck = "org.scalacheck" %% "scalacheck" % Version.scalaCheck
72 | val scalaTest = "org.scalatest" %% "scalatest" % Version.scalaTest
73 | val scalNet = "org.deeplearning4j" %% "scalnet" % Version.dl4j
74 | // change the classifier to "linux-cpu-x86_64" or "linux-gpu-x86_64" if you're on Linux (CPU-only or with an NVIDIA GPU)
75 | val tensorFlow = "org.platanios" %% "tensorflow" % Version.tensorFlow classifier "darwin-cpu-x86_64"
76 | val tensorFlowData = "org.platanios" %% "tensorflow-data" % Version.tensorFlow
77 | }
78 |
79 | // *****************************************************************************
80 | // Settings
81 | // *****************************************************************************
82 |
83 | lazy val settings =
84 | Seq(
85 | scalaVersion := "2.12.6",
86 | organization := "io.brunk",
87 | organizationName := "Sören Brunk",
88 | startYear := Some(2017),
89 | licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0")),
90 | scalacOptions ++= Seq(
91 | "-unchecked",
92 | "-deprecation",
93 | "-language:_",
94 | "-target:jvm-1.8",
95 | "-encoding", "UTF-8"
96 | ),
97 | unmanagedSourceDirectories.in(Compile) := Seq(scalaSource.in(Compile).value),
98 | unmanagedSourceDirectories.in(Test) := Seq(scalaSource.in(Test).value),
99 | resolvers ++= Seq(
100 | Resolver.sonatypeRepo("snapshots")
101 | )
102 | )
--------------------------------------------------------------------------------
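
The tensorFlow dependency above pins the native classifier to "darwin-cpu-x86_64", and the comment next to it tells Linux users to edit it by hand. A minimal sketch of how the classifier could instead be derived from the host OS (an assumption for illustration, not part of the original build; choosing the GPU variant remains a manual decision):

```scala
// sketch: choose the TensorFlow Scala native classifier from the JVM's os.name property
lazy val tfClassifier: String =
  if (sys.props("os.name").toLowerCase.contains("mac")) "darwin-cpu-x86_64"
  else "linux-cpu-x86_64" // or "linux-gpu-x86_64" on a Linux machine with an NVIDIA GPU

// which would then be used inside the library object as:
// val tensorFlow = "org.platanios" %% "tensorflow" % Version.tensorFlow classifier tfClassifier
```
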
/data/iris.csv:
--------------------------------------------------------------------------------
1 | 150,4,setosa,versicolor,virginica
2 | 5.1,3.5,1.4,0.2,0
3 | 4.9,3.0,1.4,0.2,0
4 | 4.7,3.2,1.3,0.2,0
5 | 4.6,3.1,1.5,0.2,0
6 | 5.0,3.6,1.4,0.2,0
7 | 5.4,3.9,1.7,0.4,0
8 | 4.6,3.4,1.4,0.3,0
9 | 5.0,3.4,1.5,0.2,0
10 | 4.4,2.9,1.4,0.2,0
11 | 4.9,3.1,1.5,0.1,0
12 | 5.4,3.7,1.5,0.2,0
13 | 4.8,3.4,1.6,0.2,0
14 | 4.8,3.0,1.4,0.1,0
15 | 4.3,3.0,1.1,0.1,0
16 | 5.8,4.0,1.2,0.2,0
17 | 5.7,4.4,1.5,0.4,0
18 | 5.4,3.9,1.3,0.4,0
19 | 5.1,3.5,1.4,0.3,0
20 | 5.7,3.8,1.7,0.3,0
21 | 5.1,3.8,1.5,0.3,0
22 | 5.4,3.4,1.7,0.2,0
23 | 5.1,3.7,1.5,0.4,0
24 | 4.6,3.6,1.0,0.2,0
25 | 5.1,3.3,1.7,0.5,0
26 | 4.8,3.4,1.9,0.2,0
27 | 5.0,3.0,1.6,0.2,0
28 | 5.0,3.4,1.6,0.4,0
29 | 5.2,3.5,1.5,0.2,0
30 | 5.2,3.4,1.4,0.2,0
31 | 4.7,3.2,1.6,0.2,0
32 | 4.8,3.1,1.6,0.2,0
33 | 5.4,3.4,1.5,0.4,0
34 | 5.2,4.1,1.5,0.1,0
35 | 5.5,4.2,1.4,0.2,0
36 | 4.9,3.1,1.5,0.1,0
37 | 5.0,3.2,1.2,0.2,0
38 | 5.5,3.5,1.3,0.2,0
39 | 4.9,3.1,1.5,0.1,0
40 | 4.4,3.0,1.3,0.2,0
41 | 5.1,3.4,1.5,0.2,0
42 | 5.0,3.5,1.3,0.3,0
43 | 4.5,2.3,1.3,0.3,0
44 | 4.4,3.2,1.3,0.2,0
45 | 5.0,3.5,1.6,0.6,0
46 | 5.1,3.8,1.9,0.4,0
47 | 4.8,3.0,1.4,0.3,0
48 | 5.1,3.8,1.6,0.2,0
49 | 4.6,3.2,1.4,0.2,0
50 | 5.3,3.7,1.5,0.2,0
51 | 5.0,3.3,1.4,0.2,0
52 | 7.0,3.2,4.7,1.4,1
53 | 6.4,3.2,4.5,1.5,1
54 | 6.9,3.1,4.9,1.5,1
55 | 5.5,2.3,4.0,1.3,1
56 | 6.5,2.8,4.6,1.5,1
57 | 5.7,2.8,4.5,1.3,1
58 | 6.3,3.3,4.7,1.6,1
59 | 4.9,2.4,3.3,1.0,1
60 | 6.6,2.9,4.6,1.3,1
61 | 5.2,2.7,3.9,1.4,1
62 | 5.0,2.0,3.5,1.0,1
63 | 5.9,3.0,4.2,1.5,1
64 | 6.0,2.2,4.0,1.0,1
65 | 6.1,2.9,4.7,1.4,1
66 | 5.6,2.9,3.6,1.3,1
67 | 6.7,3.1,4.4,1.4,1
68 | 5.6,3.0,4.5,1.5,1
69 | 5.8,2.7,4.1,1.0,1
70 | 6.2,2.2,4.5,1.5,1
71 | 5.6,2.5,3.9,1.1,1
72 | 5.9,3.2,4.8,1.8,1
73 | 6.1,2.8,4.0,1.3,1
74 | 6.3,2.5,4.9,1.5,1
75 | 6.1,2.8,4.7,1.2,1
76 | 6.4,2.9,4.3,1.3,1
77 | 6.6,3.0,4.4,1.4,1
78 | 6.8,2.8,4.8,1.4,1
79 | 6.7,3.0,5.0,1.7,1
80 | 6.0,2.9,4.5,1.5,1
81 | 5.7,2.6,3.5,1.0,1
82 | 5.5,2.4,3.8,1.1,1
83 | 5.5,2.4,3.7,1.0,1
84 | 5.8,2.7,3.9,1.2,1
85 | 6.0,2.7,5.1,1.6,1
86 | 5.4,3.0,4.5,1.5,1
87 | 6.0,3.4,4.5,1.6,1
88 | 6.7,3.1,4.7,1.5,1
89 | 6.3,2.3,4.4,1.3,1
90 | 5.6,3.0,4.1,1.3,1
91 | 5.5,2.5,4.0,1.3,1
92 | 5.5,2.6,4.4,1.2,1
93 | 6.1,3.0,4.6,1.4,1
94 | 5.8,2.6,4.0,1.2,1
95 | 5.0,2.3,3.3,1.0,1
96 | 5.6,2.7,4.2,1.3,1
97 | 5.7,3.0,4.2,1.2,1
98 | 5.7,2.9,4.2,1.3,1
99 | 6.2,2.9,4.3,1.3,1
100 | 5.1,2.5,3.0,1.1,1
101 | 5.7,2.8,4.1,1.3,1
102 | 6.3,3.3,6.0,2.5,2
103 | 5.8,2.7,5.1,1.9,2
104 | 7.1,3.0,5.9,2.1,2
105 | 6.3,2.9,5.6,1.8,2
106 | 6.5,3.0,5.8,2.2,2
107 | 7.6,3.0,6.6,2.1,2
108 | 4.9,2.5,4.5,1.7,2
109 | 7.3,2.9,6.3,1.8,2
110 | 6.7,2.5,5.8,1.8,2
111 | 7.2,3.6,6.1,2.5,2
112 | 6.5,3.2,5.1,2.0,2
113 | 6.4,2.7,5.3,1.9,2
114 | 6.8,3.0,5.5,2.1,2
115 | 5.7,2.5,5.0,2.0,2
116 | 5.8,2.8,5.1,2.4,2
117 | 6.4,3.2,5.3,2.3,2
118 | 6.5,3.0,5.5,1.8,2
119 | 7.7,3.8,6.7,2.2,2
120 | 7.7,2.6,6.9,2.3,2
121 | 6.0,2.2,5.0,1.5,2
122 | 6.9,3.2,5.7,2.3,2
123 | 5.6,2.8,4.9,2.0,2
124 | 7.7,2.8,6.7,2.0,2
125 | 6.3,2.7,4.9,1.8,2
126 | 6.7,3.3,5.7,2.1,2
127 | 7.2,3.2,6.0,1.8,2
128 | 6.2,2.8,4.8,1.8,2
129 | 6.1,3.0,4.9,1.8,2
130 | 6.4,2.8,5.6,2.1,2
131 | 7.2,3.0,5.8,1.6,2
132 | 7.4,2.8,6.1,1.9,2
133 | 7.9,3.8,6.4,2.0,2
134 | 6.4,2.8,5.6,2.2,2
135 | 6.3,2.8,5.1,1.5,2
136 | 6.1,2.6,5.6,1.4,2
137 | 7.7,3.0,6.1,2.3,2
138 | 6.3,3.4,5.6,2.4,2
139 | 6.4,3.1,5.5,1.8,2
140 | 6.0,3.0,4.8,1.8,2
141 | 6.9,3.1,5.4,2.1,2
142 | 6.7,3.1,5.6,2.4,2
143 | 6.9,3.1,5.1,2.3,2
144 | 5.8,2.7,5.1,1.9,2
145 | 6.8,3.2,5.9,2.3,2
146 | 6.7,3.3,5.7,2.5,2
147 | 6.7,3.0,5.2,2.3,2
148 | 6.3,2.5,5.0,1.9,2
149 | 6.5,3.0,5.2,2.0,2
150 | 6.2,3.4,5.4,2.3,2
151 | 5.9,3.0,5.1,1.8,2
152 |
--------------------------------------------------------------------------------
/dl4j/src/main/resources/logback.xml:
--------------------------------------------------------------------------------
1 | <configuration>
2 |     <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
3 |         <encoder>
4 |             <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
5 |         </encoder>
6 |     </appender>
7 |     <root level="INFO">
8 |         <appender-ref ref="STDOUT" />
9 |     </root>
10 | </configuration>
--------------------------------------------------------------------------------
/dl4j/src/main/scala/io/brunk/examples/ImageReader.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples
18 |
19 | import java.io.{File, FileFilter}
20 | import java.lang.Math.toIntExact
21 |
22 | import org.datavec.api.io.filters.BalancedPathFilter
23 | import org.datavec.api.io.labels.ParentPathLabelGenerator
24 | import org.datavec.api.split.{FileSplit, InputSplit}
25 | import org.datavec.image.loader.BaseImageLoader
26 | import org.datavec.image.recordreader.ImageRecordReader
27 | import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator
28 | import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator
29 | import org.deeplearning4j.eval.Evaluation
30 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
31 | import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler
32 |
33 | import scala.collection.JavaConverters._
34 |
35 |
36 | object ImageReader {
37 |
38 | val channels = 3
39 | val height = 150
40 | val width = 150
41 |
42 | val batchSize = 50
43 | val numClasses = 2
44 | val epochs = 100
45 | val splitTrainTest = 0.8
46 |
47 | val random = new java.util.Random()
48 |
49 | def createImageIterator(path: String): (MultipleEpochsIterator, DataSetIterator) = {
50 | val baseDir = new File(path)
51 | val labelGenerator = new ParentPathLabelGenerator
52 | val fileSplit = new FileSplit(baseDir, BaseImageLoader.ALLOWED_FORMATS, random)
53 |
54 | val numExamples = toIntExact(fileSplit.length)
55 | val numLabels = fileSplit.getRootDir.listFiles(new FileFilter {
56 | override def accept(pathname: File): Boolean = pathname.isDirectory
57 | }).length
58 |
59 | val pathFilter = new BalancedPathFilter(random, labelGenerator, numExamples, numLabels, batchSize)
60 |
61 | //val inputSplit = fileSplit.sample(pathFilter, splitTrainTest, 1 - splitTrainTest)
62 | val inputSplit = fileSplit.sample(pathFilter, 70, 30)
63 |
64 | val trainData = inputSplit(0)
65 | val validationData = inputSplit(1)
66 |
67 | val recordReader = new ImageRecordReader(height, width, channels, labelGenerator)
68 | val scaler = new ImagePreProcessingScaler(0, 1)
69 |
70 | recordReader.initialize(trainData, null)
71 | val dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, numClasses)
72 | scaler.fit(dataIter)
73 | dataIter.setPreProcessor(scaler)
74 | val trainIter = new MultipleEpochsIterator(epochs, dataIter)
75 |
76 | val valRecordReader = new ImageRecordReader(height, width, channels, labelGenerator)
77 | valRecordReader.initialize(validationData, null)
78 | val validationIter = new RecordReaderDataSetIterator(valRecordReader, batchSize, 1, numClasses)
79 | scaler.fit(validationIter)
80 | validationIter.setPreProcessor(scaler)
81 |
82 | (trainIter, validationIter)
83 | }
84 |
85 | }
86 |
--------------------------------------------------------------------------------
/dl4j/src/main/scala/io/brunk/examples/IrisReader.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples
18 |
19 | import java.io.File
20 |
21 | import org.datavec.api.records.reader.impl.csv.CSVRecordReader
22 | import org.datavec.api.split.FileSplit
23 | import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator
24 | import org.nd4j.linalg.dataset.SplitTestAndTrain
25 | import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize
26 |
27 | object IrisReader {
28 | val numLinesToSkip = 1
29 |
30 | val batchSize = 150
31 | val labelIndex = 4
32 | val numLabels = 3
33 |
34 | val seed = 1
35 |
36 | def readData(): SplitTestAndTrain = {
37 | val recordReader = new CSVRecordReader(numLinesToSkip, ',')
38 | recordReader.initialize(new FileSplit(new File("data/iris.csv")))
39 | val iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numLabels)
40 | val dataSet = iterator.next() // read all data in a single batch
41 | dataSet.shuffle(seed)
42 | val testAndTrain = dataSet.splitTestAndTrain(0.67)
43 | val train = testAndTrain.getTrain
44 | val test = testAndTrain.getTest
45 |
46 | // val normalizer = new NormalizerStandardize
47 | // normalizer.fit(train)
48 | // normalizer.transform(train) // normalize training data
49 | // normalizer.transform(test) // normalize test data
50 | testAndTrain
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/dl4j/src/main/scala/io/brunk/examples/dl4j/IrisMLP.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples.dl4j
18 |
19 | import io.brunk.examples.IrisReader
20 | import org.deeplearning4j.eval.Evaluation
21 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration
22 | import org.deeplearning4j.nn.conf.layers.{ DenseLayer, OutputLayer }
23 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
24 | import org.deeplearning4j.nn.weights.WeightInit
25 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener
26 | import org.nd4j.linalg.activations.Activation
27 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
28 | import org.slf4j.{ Logger, LoggerFactory }
29 |
30 | /**
31 | * A simple feed forward network for classifying the IRIS dataset in dl4j with a single hidden layer
32 | *
33 | * Based on
34 | * https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/dataexamples/CSVExample.java
35 | *
36 | * @author Sören Brunk
37 | */
38 | object IrisMLP {
39 | private val log: Logger = LoggerFactory.getLogger(IrisMLP.getClass)
40 |
41 | def main(args: Array[String]): Unit = {
42 |
43 | val seed = 1 // for reproducibility
44 | val numInputs = 4
45 | val numHidden = 10
46 | val numOutputs = 3
47 | val learningRate = 0.1
48 | val numEpoch = 30
49 |
50 | val testAndTrain = IrisReader.readData()
51 |
52 | val conf = new NeuralNetConfiguration.Builder()
53 | .seed(seed)
54 | .activation(Activation.RELU)
55 | .weightInit(WeightInit.XAVIER)
56 | .list()
57 | .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHidden).build())
58 | .layer(1,
59 | new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
60 | .activation(Activation.SOFTMAX)
61 | .nIn(numHidden)
62 | .nOut(numOutputs)
63 | .build())
64 | .backprop(true)
65 | .pretrain(false)
66 | .build()
67 |
68 | val model = new MultiLayerNetwork(conf)
69 | model.init()
70 | model.setListeners(new ScoreIterationListener(100)) // print out scores every 100 iterations
71 |
72 | log.info("Running training")
73 | for(_ <- 0 until numEpoch)
74 | model.fit(testAndTrain.getTrain)
75 |
76 | log.info("Training finished")
77 |
78 | log.info(s"Evaluating model on ${testAndTrain.getTest.getLabels.rows()} examples")
79 | val evaluator = new Evaluation(numOutputs)
80 | val output = model.output(testAndTrain.getTest.getFeatureMatrix)
81 | evaluator.eval(testAndTrain.getTest.getLabels, output)
82 | println(evaluator.stats)
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/dl4j/src/main/scala/io/brunk/examples/dl4j/MnistMLP.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples.dl4j
18 |
19 | import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
20 | import org.deeplearning4j.eval.Evaluation
21 | import org.deeplearning4j.nn.api.OptimizationAlgorithm
22 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration
23 | import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer}
24 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
25 | import org.deeplearning4j.nn.weights.WeightInit
26 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener
27 | import org.nd4j.linalg.activations.Activation
28 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
29 | import org.nd4j.linalg.learning.config.Sgd
30 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
31 | import org.slf4j.LoggerFactory
32 |
33 | import scala.collection.JavaConverters.asScalaIteratorConverter
34 |
35 | /** Simple multilayer perceptron for classifying handwritten digits from the MNIST dataset.
36 | *
37 | * Implemented using DL4J based on the Java example from
38 | * https://github.com/deeplearning4j/dl4j-examples/blob/dfcf71d75fff956db53a93b09b560d53e3da4638/dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/mnist/MLPMnistSingleLayerExample.java
39 | *
40 | * @author Sören Brunk
41 | */
42 | object MnistMLP {
43 | private val log = LoggerFactory.getLogger(MnistMLP.getClass)
44 |
45 | def main(args: Array[String]): Unit = {
46 |
47 | val seed = 1 // for reproducibility
48 | val numInputs = 28 * 28
49 | val numHidden = 512 // size (number of neurons) of our hidden layer
50 | val numOutputs = 10 // digits from 0 to 9
51 | val learningRate = 0.01
52 | val batchSize = 128
53 | val numEpochs = 10
54 |
55 | // download and load the MNIST images as tensors
56 | val mnistTrain = new MnistDataSetIterator(batchSize, true, seed)
57 | val mnistTest = new MnistDataSetIterator(batchSize, false, seed)
58 |
59 | // define the neural network architecture
60 | val conf = new NeuralNetConfiguration.Builder()
61 | .seed(seed)
62 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
63 | .updater(new Sgd(learningRate))
64 | .weightInit(WeightInit.XAVIER) // random initialization of our weights
65 | .list // builder for creating stacked layers
66 | .layer(0, new DenseLayer.Builder() // define the hidden layer
67 | .nIn(numInputs)
68 | .nOut(numHidden)
69 | .activation(Activation.RELU)
70 | .build())
71 | .layer(1, new OutputLayer.Builder(LossFunction.MCXENT) // define loss and output layer
72 | .nIn(numHidden)
73 | .nOut(numOutputs)
74 | .activation(Activation.SOFTMAX)
75 | .build())
76 | .build()
77 |
78 | val model = new MultiLayerNetwork(conf)
79 | model.init()
80 | model.setListeners(new ScoreIterationListener(100)) // print the score every 100th iteration
81 |
82 | // train the model
83 | for (_ <- 0 until numEpochs)
84 | model.fit(mnistTrain)
85 |
86 | // evaluate model performance
87 | def accuracy(dataSet: DataSetIterator): Double = {
88 | val evaluator = new Evaluation(numOutputs)
89 | dataSet.reset()
90 | for (dataSet <- dataSet.asScala) {
91 | val output = model.output(dataSet.getFeatureMatrix)
92 | evaluator.eval(dataSet.getLabels, output)
93 | }
94 | evaluator.accuracy()
95 | }
96 |
97 | log.info(s"Train accuracy = ${accuracy(mnistTrain)}")
98 | log.info(s"Test accuracy = ${accuracy(mnistTest)}")
99 | }
100 | }
101 |
--------------------------------------------------------------------------------
/dl4j/src/main/scala/io/brunk/examples/dl4j/SimpleCNN.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples.dl4j
18 |
19 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration
20 | import org.deeplearning4j.nn.conf.layers.{ConvolutionLayer, DenseLayer, OutputLayer, SubsamplingLayer}
21 | import org.nd4j.linalg.learning.config.Adam
22 | import io.brunk.examples.ImageReader._
23 | import org.deeplearning4j.nn.conf.dropout.Dropout
24 | import org.deeplearning4j.nn.conf.inputs.InputType
25 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
26 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener
27 | import org.deeplearning4j.ui.api.UIServer
28 | import org.deeplearning4j.ui.stats.StatsListener
29 | import org.deeplearning4j.ui.storage.InMemoryStatsStorage
30 | import org.nd4j.linalg.activations.Activation.{RELU, SOFTMAX}
31 | import org.nd4j.linalg.lossfunctions.LossFunctions
32 | import org.slf4j.LoggerFactory
33 |
34 |
35 | object SimpleCNN {
36 |
37 | private val log = LoggerFactory.getLogger(getClass)
38 | val seed = 1
39 |
40 | def main(args: Array[String]): Unit = {
41 |
42 | val dataDir = args.head
43 |
44 | val conf = new NeuralNetConfiguration.Builder()
45 | .seed(seed)
46 | .updater(new Adam)
47 | .list()
48 | .layer(0, new ConvolutionLayer.Builder(3, 3)
49 | .nIn(channels)
50 | .nOut(32)
51 | .activation(RELU)
52 | .build())
53 | .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
54 | .kernelSize(2, 2)
55 | .build())
56 | .layer(2, new ConvolutionLayer.Builder(3, 3)
57 | .nOut(64)
58 | .activation(RELU)
59 | .build())
60 | .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
61 | .kernelSize(2, 2)
62 | .build())
63 | .layer(4, new ConvolutionLayer.Builder(3, 3)
64 | .nOut(128)
65 | .activation(RELU)
66 | .build())
67 | .layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
68 | .kernelSize(2, 2)
69 | .build())
70 | .layer(6, new ConvolutionLayer.Builder(3, 3)
71 | .nOut(128)
72 | .activation(RELU)
73 | .build())
74 | .layer(7, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
75 | .kernelSize(2, 2)
76 | .build())
77 | .layer(8, new DenseLayer.Builder()
78 | .nOut(512)
79 | .activation(RELU)
80 | .dropOut(new Dropout(0.5))
81 | .build())
82 | .layer(9, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
83 | .nOut(2)
84 | .activation(SOFTMAX)
85 | .build())
86 | .setInputType(InputType.convolutional(150, 150, 3))
87 | .backprop(true).pretrain(false).build()
88 |
89 | val model = new MultiLayerNetwork(conf)
90 | model.init()
91 | model.setListeners(new ScoreIterationListener(10))
92 | log.debug("Total num of params: {}", model.numParams)
93 |
94 | val uiServer = UIServer.getInstance
95 | val statsStorage = new InMemoryStatsStorage
96 | uiServer.attach(statsStorage)
97 | model.setListeners(new StatsListener(statsStorage))
98 |
99 | val (trainIter, testIter) = createImageIterator(dataDir)
100 |
101 | model.fit(trainIter)
102 | val eval = model.evaluate(testIter)
103 | log.info(eval.stats)
104 | }
105 | }
106 |
--------------------------------------------------------------------------------
/dl4j/src/main/scala/io/brunk/examples/scalnet/IrisMLP.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples.scalnet
18 |
19 | import io.brunk.examples.IrisReader
20 | import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator
21 | import org.deeplearning4j.eval.Evaluation
22 | import org.deeplearning4j.nn.conf.Updater
23 | import org.deeplearning4j.nn.weights.WeightInit
24 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener
25 | import org.deeplearning4j.scalnet.layers.core.Dense
26 | import org.deeplearning4j.scalnet.models.Sequential
27 | import org.deeplearning4j.scalnet.regularizers.L2
28 | import org.nd4j.linalg.activations.Activation
29 | import org.nd4j.linalg.api.ndarray.INDArray
30 | import org.nd4j.linalg.learning.config.Sgd
31 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
32 | import org.slf4j.{Logger, LoggerFactory}
33 |
34 | /**
35 | * A simple feed forward network (one hidden layer) for classifying the IRIS dataset
36 | * implemented using ScalNet.
37 | *
38 | * @author Sören Brunk
39 | */
40 | object IrisMLP {
41 |
42 | private val log: Logger = LoggerFactory.getLogger(IrisMLP.getClass)
43 |
44 | def main(args: Array[String]): Unit = {
45 |
46 | val seed = 1
47 | val numInputs = 4
48 | val numHidden = 10
49 | val numOutputs = 3
50 | val learningRate = 0.1
51 | val iterations = 1000
52 |
53 | val testAndTrain = IrisReader.readData()
54 | val trainList = testAndTrain.getTrain.asList()
55 | val trainIterator = new ListDataSetIterator(trainList, trainList.size)
56 |
57 | val model = Sequential(rngSeed = seed)
58 | model.add(Dense(numHidden, nIn = numInputs, weightInit = WeightInit.XAVIER, activation = Activation.RELU))
59 | model.add(Dense(numOutputs, weightInit = WeightInit.XAVIER, activation = Activation.SOFTMAX))
60 |
61 | model.compile(lossFunction = LossFunction.NEGATIVELOGLIKELIHOOD, updater = Updater.SGD)
62 |
63 | log.info("Running training")
64 | model.fit(iter = trainIterator,
65 | nbEpoch = iterations,
66 | listeners = List(new ScoreIterationListener(100)))
67 | log.info("Training finished")
68 |
69 | log.info(s"Evaluating model on ${testAndTrain.getTest.getLabels.rows()} examples")
70 | val evaluator = new Evaluation(numOutputs)
71 | val output: INDArray = model.predict(testAndTrain.getTest.getFeatureMatrix)
72 | evaluator.eval(testAndTrain.getTest.getLabels, output)
73 | log.info(evaluator.stats())
74 |
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/dl4j/src/main/scala/io/brunk/examples/scalnet/MnistMLP.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples.scalnet
18 |
19 | import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
20 | import org.deeplearning4j.eval.Evaluation
21 | import org.deeplearning4j.nn.conf.Updater
22 | import org.deeplearning4j.nn.weights.WeightInit
23 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener
24 | import org.deeplearning4j.scalnet.layers.core.Dense
25 | import org.deeplearning4j.scalnet.models.Sequential
26 | import org.nd4j.linalg.activations.Activation
27 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
28 | import org.nd4j.linalg.learning.config.Sgd
29 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
30 | import org.slf4j.{Logger, LoggerFactory}
31 |
32 | import scala.collection.JavaConverters.asScalaIteratorConverter
33 |
34 |
35 | /** Simple multilayer perceptron for classifying handwritten digits from the MNIST dataset.
36 | *
37 | * Implemented using ScalNet.
38 | *
39 | * @author Sören Brunk
40 | */
41 | object MnistMLP {
42 | private val log: Logger = LoggerFactory.getLogger(MnistMLP.getClass)
43 |
44 | def main(args: Array[String]): Unit = {
45 |
46 | val seed = 1 // for reproducibility
47 | val numInputs = 28 * 28
48 | val numHidden = 512 // size (number of neurons) in our hidden layer
49 | val numOutputs = 10 // digits from 0 to 9
50 | val learningRate = 0.01
51 | val batchSize = 128
52 | val numEpochs = 10
53 |
54 | // download and load the MNIST images as tensors
55 | val mnistTrain: DataSetIterator = new MnistDataSetIterator(batchSize, true, seed)
56 | val mnistTest: DataSetIterator = new MnistDataSetIterator(batchSize, false, seed)
57 |
58 | // define the neural network architecture
59 | val model: Sequential = Sequential(rngSeed = seed)
60 | model.add(Dense(nOut = numHidden, nIn = numInputs, weightInit = WeightInit.XAVIER, activation = Activation.RELU))
61 | model.add(Dense(nOut = numOutputs, weightInit = WeightInit.XAVIER, activation = Activation.SOFTMAX)) // softmax output layer for classification
62 | model.compile(lossFunction = LossFunction.MCXENT, updater = Updater.SGD) // TODO how do we set the learning rate?
63 |
64 | // train the model
65 | model.fit(mnistTrain, nbEpoch = numEpochs, List(new ScoreIterationListener(100)))
66 |
67 | // evaluate model performance
68 | def accuracy(dataSet: DataSetIterator): Double = {
69 | val evaluator = new Evaluation(numOutputs)
70 | dataSet.reset()
71 | for (dataSet <- dataSet.asScala) {
72 | val output = model.predict(dataSet)
73 | evaluator.eval(dataSet.getLabels, output)
74 | }
75 | evaluator.accuracy()
76 | }
77 |
78 | log.info(s"Train accuracy = ${accuracy(mnistTrain)}")
79 | log.info(s"Test accuracy = ${accuracy(mnistTest)}")
80 | }
81 | }
82 |
--------------------------------------------------------------------------------
/dl4j/src/main/scala/io/brunk/examples/scalnet/SimpleCNN.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples.scalnet
18 |
19 | import io.brunk.examples.ImageReader
20 | import org.deeplearning4j.nn.conf.inputs.InputType
21 | import org.deeplearning4j.scalnet.models.NeuralNet
22 | import io.brunk.examples.ImageReader._
23 | import org.deeplearning4j.nn.conf.Updater
24 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener
25 | import org.deeplearning4j.scalnet.layers.convolutional.Convolution2D
26 | import org.deeplearning4j.scalnet.layers.core.Dense
27 | import org.deeplearning4j.scalnet.layers.pooling.MaxPooling2D
28 | import org.nd4j.linalg.activations.Activation._
29 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction
30 |
31 | object SimpleCNN {
32 |
33 |
34 | def main(args: Array[String]): Unit = {
35 |
36 | val dataDir = args.head
37 |
38 | val seed = 1
39 |
40 | val model = NeuralNet(inputType = InputType.convolutional(height, width, channels), rngSeed = seed)
41 |
42 | model.add(Convolution2D(32, List(3, 3), channels, activation = RELU))
43 | model.add(MaxPooling2D(List(2, 2)))
44 |
45 | model.add(Convolution2D(64, List(3, 3), activation = RELU))
46 | model.add(MaxPooling2D(List(2, 2)))
47 |
48 | model.add(Convolution2D(128, List(3, 3), activation = RELU))
49 | model.add(MaxPooling2D(List(2, 2)))
50 |
51 | model.add(Convolution2D(128, List(3, 3), activation = RELU))
52 | model.add(MaxPooling2D(List(2, 2)))
53 |
54 | model.add(Dense(512, activation = RELU, dropOut = 0.5))
55 | model.add(Dense(2, activation = SOFTMAX))
56 |
57 | model.compile(lossFunction = LossFunction.NEGATIVELOGLIKELIHOOD, updater = Updater.ADAM)
58 |
59 | val (trainIter, testIter) = createImageIterator(dataDir)
60 |
61 | model.fit(trainIter, 30, List(new ScoreIterationListener(10)))
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/mxnet/build.sbt:
--------------------------------------------------------------------------------
1 | // *****************************************************************************
2 | // Projects
3 | // *****************************************************************************
4 |
5 | lazy val mxnet =
6 | project
7 | .in(file("."))
8 | .enablePlugins(AutomateHeaderPlugin)
9 | .settings(settings)
10 | .settings(
11 | scalaVersion := "2.11.12", // MXNet is only available for Scala 2.11
12 | resolvers += Resolver.mavenLocal,
13 | libraryDependencies ++= Seq(
14 | library.logbackClassic,
15 | library.mxnetFull
16 | )
17 | )
18 |
19 | // *****************************************************************************
20 | // Library dependencies
21 | // *****************************************************************************
22 |
23 | lazy val library =
24 | new {
25 | object Version {
26 | val logbackClassic = "1.2.3"
27 | val mxnet = "1.0.0-SNAPSHOT"
28 | }
29 | val logbackClassic = "ch.qos.logback" % "logback-classic" % Version.logbackClassic
30 | // change to "mxnet-full_2.11-linux-x86_64-cpu" or "mxnet-full_2.11-linux-x86_64-gpu" depending on your OS/GPU
31 | val mxnetFull = "ml.dmlc.mxnet" % "mxnet-full_2.11-osx-x86_64-cpu" % Version.mxnet
32 | }
33 |
34 | // *****************************************************************************
35 | // Settings
36 | // *****************************************************************************
37 |
38 | lazy val settings =
39 | Seq(
40 | scalaVersion := "2.12.4",
41 | organization := "io.brunk",
42 | organizationName := "Sören Brunk",
43 | startYear := Some(2017),
44 | licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0")),
45 | scalacOptions ++= Seq(
46 | "-unchecked",
47 | "-deprecation",
48 | "-language:_",
49 | "-target:jvm-1.8",
50 | "-encoding", "UTF-8"
51 | ),
52 | unmanagedSourceDirectories.in(Compile) := Seq(scalaSource.in(Compile).value),
53 | unmanagedSourceDirectories.in(Test) := Seq(scalaSource.in(Test).value)
54 | )
55 |
--------------------------------------------------------------------------------
/mxnet/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 1.0.3
2 |
--------------------------------------------------------------------------------
/mxnet/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("de.heikoseeberger" % "sbt-header" % "4.0.0")
2 | addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.0-RC13")
--------------------------------------------------------------------------------
/mxnet/src/main/resources/logback.xml:
--------------------------------------------------------------------------------
1 | <configuration>
2 |     <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
3 |         <encoder>
4 |             <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
5 |         </encoder>
6 |     </appender>
7 |     <root level="INFO">
8 |         <appender-ref ref="STDOUT" />
9 |     </root>
10 | </configuration>
--------------------------------------------------------------------------------
/mxnet/src/main/scala/io/brunk/examples/IrisMLP.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples
18 |
19 | import ml.dmlc.mxnet._
20 | import ml.dmlc.mxnet.io.NDArrayIter
21 | import ml.dmlc.mxnet.optimizer.SGD
22 |
23 | object IrisMLP {
24 |
25 | def main(args: Array[String]): Unit = {
26 |
27 | val numInputs = 4
28 | val numHidden = 10
29 | val numOutputs = 3
30 | val learningRate = 0.1f
31 | val iterations = 1000
32 | val trainSize = 100
33 | val testSize = 50
34 |
35 | val batchSize = 50
36 | val epochs = (iterations / (trainSize.toFloat / batchSize)).toInt // epochs * (trainSize / batchSize) = iterations minibatch updates
37 |
38 | // The mxnet Scala IO API does not support shuffling so we just read the csv using plain Scala
39 | val source = scala.io.Source.fromFile("data/iris.csv")
40 | val rows = source.getLines().drop(1).map { l =>
41 | val columns = l.split(",").map(_.toFloat)
42 | new {
43 | val features = columns.take(4)
44 | val labels = columns(4)
45 | }
46 | }.toBuffer
47 | val shuffled = scala.util.Random.shuffle(rows).toArray
48 | val trainData = shuffled.take(trainSize)
49 | val testData = shuffled.drop(trainSize)
50 | val trainFeatures = NDArray.array(trainData.flatMap(_.features), Shape(trainSize, numInputs))
51 | val trainLabels = NDArray.array(trainData.map(_.labels), Shape(trainSize))
52 | val testFeatures = NDArray.array(testData.flatMap(_.features), Shape(testSize, numInputs))
53 | val testLabels = NDArray.array(testData.map(_.labels), Shape(testSize))
54 |
55 |
56 | val trainDataIter = new NDArrayIter(data = IndexedSeq(trainFeatures), label = IndexedSeq(trainLabels), dataBatchSize = 50)
57 | val testDataIter = new NDArrayIter(data = IndexedSeq(testFeatures), label = IndexedSeq(testLabels), dataBatchSize = 50)
58 |
59 | // Define the network architecture
60 | val data = Symbol.Variable("data")
61 | val label = Symbol.Variable("label")
62 | val l1 = Symbol.FullyConnected(name = "l1")()(Map("data" -> data, "num_hidden" -> numHidden))
63 | val a1 = Symbol.Activation(name = "a1")()(Map("data" -> l1, "act_type" -> "relu"))
64 | val l2 = Symbol.FullyConnected(name = "l2")()(Map("data" -> a1, "num_hidden" -> numOutputs))
65 | val out = Symbol.SoftmaxOutput(name = "sm")()(Map("data" -> l2, "label" -> label))
66 |
67 | // Create and train a model
68 | val model = FeedForward.newBuilder(out)
69 | .setContext(Context.cpu()) // change to gpu if available
70 | .setNumEpoch(epochs)
71 | .setOptimizer(new SGD(learningRate = learningRate))
72 | .setTrainData(trainDataIter)
73 | .setEvalData(testDataIter)
74 | .build()
75 | }
76 |
77 | }
78 |
--------------------------------------------------------------------------------
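
The Iris example above trains the network (fitting happens inside build()) but never reports how well it does. A minimal sketch of how the test accuracy could be computed with the same calls used by the accuracy helper in MnistMLP.scala below, assuming the model and testDataIter values defined in IrisMLP.scala:

```scala
// sketch: evaluate the trained FeedForward model on the held-out Iris data
testDataIter.reset()
val predictedY = NDArray.argmax_channel(model.predict(testDataIter).head) // predicted class per example

testDataIter.reset()
val y = NDArray.concatenate(testDataIter.map(_.label(0).copy()).toVector) // true labels

val accuracy = (y.toArray zip predictedY.toArray).count { case (label, pred) => label == pred }.toFloat / y.size
println(s"Test accuracy = $accuracy")
```
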
/mxnet/src/main/scala/io/brunk/examples/MnistMLP.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk.examples
18 |
19 | import ml.dmlc.mxnet._
20 | import ml.dmlc.mxnet.optimizer.SGD
21 |
22 | /** Simple multilayer perceptron for classifying handwritten digits from the MNIST dataset.
23 | *
24 | * Implemented using MXNet.
25 | * Based on https://mxnet.incubator.apache.org/tutorials/scala/mnist.html
26 | *
27 | * @author Sören Brunk
28 | */
29 | object MnistMLP {
30 |
31 | def main(args: Array[String]): Unit = {
32 |
33 | val numHidden = 512 // size (number of neurons) of our hidden layer
34 | val numOutputs = 10 // digits from 0 to 9
35 | val learningRate = 0.01f
36 | val batchSize = 128
37 | val numEpochs = 10
38 |
39 | // load the MNIST images as tensors
40 | val trainDataIter = IO.MNISTIter(Map(
41 | "image" -> "mnist/train-images-idx3-ubyte",
42 | "label" -> "mnist/train-labels-idx1-ubyte",
43 | "data_shape" -> "(1, 28, 28)",
44 | "label_name" -> "sm_label",
45 | "batch_size" -> batchSize.toString,
46 | "shuffle" -> "1",
47 | "flat" -> "0",
48 | "silent" -> "0"))
49 |
50 | val testDataIter = IO.MNISTIter(Map(
51 | "image" -> "mnist/t10k-images-idx3-ubyte",
52 | "label" -> "mnist/t10k-labels-idx1-ubyte",
53 | "data_shape" -> "(1, 28, 28)",
54 | "label_name" -> "sm_label",
55 | "batch_size" -> batchSize.toString,
56 | "shuffle" -> "1",
57 | "flat" -> "0",
58 | "silent" -> "0"))
59 |
60 | // define the neural network architecture
61 | val data = Symbol.Variable("data")
62 | val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> numHidden))
63 | val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
64 | val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> numOutputs))
65 | val mlp = Symbol.SoftmaxOutput(name = "sm")()(Map("data" -> fc2))
66 |
67 | // create and train the model
68 | val model = FeedForward.newBuilder(mlp)
69 | .setContext(Context.cpu()) // change to gpu if available
70 | .setTrainData(trainDataIter)
71 | .setEvalData(testDataIter)
72 | .setNumEpoch(numEpochs)
73 | .setOptimizer(new SGD(learningRate = learningRate))
74 | .setInitializer(new Xavier()) // random weight initialization
75 | .build()
76 |
77 | // evaluate model performance
78 | def accuracy(dataset: DataIter): Float = {
79 | dataset.reset()
80 | val predictions = model.predict(dataset).head
81 | // get predicted labels
82 | val predictedY = NDArray.argmax_channel(predictions)
83 |
84 | // get real labels
85 | dataset.reset()
86 | val labels = dataset.map(_.label(0).copy()).toVector
87 | val y = NDArray.concatenate(labels)
88 | require(y.shape == predictedY.shape)
89 |
90 | // calculate accuracy
91 | val numCorrect = (y.toArray zip predictedY.toArray).count {
92 | case (labelElem, predElem) => labelElem == predElem
93 | }
94 | numCorrect.toFloat / y.size
95 | }
96 |
97 | println(s"Train accuracy = ${accuracy(trainDataIter)}")
98 | println(s"Test accuracy = ${accuracy(testDataIter)}")
99 | }
100 | }
101 |
--------------------------------------------------------------------------------
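Note: MnistMLP above reads the raw MNIST idx files from a mnist/ directory relative to the working directory (train-images-idx3-ubyte, train-labels-idx1-ubyte, t10k-images-idx3-ubyte and t10k-labels-idx1-ubyte), so these files have to be downloaded and unpacked first. Assuming the MXNet sub-project is simply named mxnet in the build (an assumption, not verified against build.sbt), the example can then be started with something like:

  sbt "mxnet/runMain io.brunk.examples.MnistMLP"
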
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 1.2.1
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("de.heikoseeberger" % "sbt-header" % "4.0.0")
2 | addSbtPlugin("io.get-coursier" % "sbt-coursier" % "1.0.1")
3 | addSbtPlugin("org.bytedeco" % "sbt-javacv" % "1.16")
4 | addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.12")
5 |
6 | libraryDependencies += "com.trueaccord.scalapb" %% "compilerplugin" % "0.6.6"
--------------------------------------------------------------------------------
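Note: sbt-protoc together with the ScalaPB compiler plugin above is what compiles src/main/protobuf/string_int_label_map.proto (further below) into Scala case classes at build time. The actual wiring lives in the build definition; as a reference-only sketch of the standard configuration for these plugin versions (not copied from this repository's build files):

  // build.sbt (sbt-protoc 0.99.x style): generate ScalaPB case classes from src/main/protobuf
  PB.targets in Compile := Seq(
    scalapb.gen() -> (sourceManaged in Compile).value
  )
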
/tensorflow/example_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sbrunk/scala-deeplearn-examples/23edfff79c6a590ba5a5fd896080fb8ac116579a/tensorflow/example_image.jpg
--------------------------------------------------------------------------------
/tensorflow/src/main/protobuf/string_int_label_map.proto:
--------------------------------------------------------------------------------
1 | // Message to store the mapping from class label strings to class id. Datasets
2 | // use string labels to represent classes while the object detection framework
3 | // works with class ids. This message maps them so they can be converted back
4 | // and forth as needed.
5 | syntax = "proto2";
6 |
7 | package object_detection.protos;
8 |
9 | message StringIntLabelMapItem {
10 | // String name. The most common practice is to set this to a MID or synsets
11 | // id.
12 | optional string name = 1;
13 |
14 | // Integer id that maps to the string name above. Label ids should start from
15 | // 1.
16 | optional int32 id = 2;
17 |
18 | // Human readable string label.
19 | optional string display_name = 3;
20 | };
21 |
22 | message StringIntLabelMap {
23 | repeated StringIntLabelMapItem item = 1;
24 | };
25 |
--------------------------------------------------------------------------------
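Note: the proto above only defines the mapping format; the actual COCO labels live in mscoco_label_map.pbtxt (further below) as protobuf text format, which the ScalaPB-generated classes can parse directly. A minimal sketch of turning it into an id-to-display-name lookup; the generated package name and the LabelMapLoader object are illustrative, and the exact package depends on the ScalaPB settings:

  // Illustrative sketch: load the COCO label map via the ScalaPB-generated classes.
  object LabelMapLoader {
    import scala.io.Source
    import object_detection.protos.string_int_label_map.StringIntLabelMap // hypothetical generated package

    def loadLabelMap(resource: String = "/mscoco_label_map.pbtxt"): Map[Int, String] = {
      val text = Source.fromInputStream(getClass.getResourceAsStream(resource)).mkString
      val labelMap = StringIntLabelMap.fromAscii(text)             // parse protobuf text format
      labelMap.item.map(i => i.getId -> i.getDisplayName).toMap    // id -> human-readable label
    }
  }

ObjectDetector.scala presumably does something equivalent when it maps detection class ids back to display names.
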
/tensorflow/src/main/resources/logback.xml:
--------------------------------------------------------------------------------
1 | <configuration>
2 |
3 |   <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
4 |     <filter class="ch.qos.logback.core.filter.EvaluatorFilter">
5 |       <evaluator>
6 |         <expression>
7 |           return message.contains("TF GPU device with id 0 was not registered");
8 |         </expression>
9 |       </evaluator>
10 |       <OnMismatch>NEUTRAL</OnMismatch>
11 |       <OnMatch>DENY</OnMatch>
12 |     </filter>
13 |     <encoder>
14 |       <pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
15 |     </encoder>
16 |   </appender>
17 |
18 |   <timestamp key="bySecond" datePattern="yyyyMMdd'T'HHmmss"/>
19 |
20 |   <appender name="FILE" class="ch.qos.logback.core.FileAppender">
21 |     <file>log-${bySecond}.txt</file>
22 |     <encoder>
23 |       <pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
24 |     </encoder>
25 |   </appender>
26 |
27 |   <root level="INFO">
28 |     <appender-ref ref="STDOUT"/>
29 |     <appender-ref ref="FILE"/>
30 |   </root>
31 |
32 | </configuration>
--------------------------------------------------------------------------------
/tensorflow/src/main/resources/mscoco_label_map.pbtxt:
--------------------------------------------------------------------------------
1 | item {
2 | name: "/m/01g317"
3 | id: 1
4 | display_name: "person"
5 | }
6 | item {
7 | name: "/m/0199g"
8 | id: 2
9 | display_name: "bicycle"
10 | }
11 | item {
12 | name: "/m/0k4j"
13 | id: 3
14 | display_name: "car"
15 | }
16 | item {
17 | name: "/m/04_sv"
18 | id: 4
19 | display_name: "motorcycle"
20 | }
21 | item {
22 | name: "/m/05czz6l"
23 | id: 5
24 | display_name: "airplane"
25 | }
26 | item {
27 | name: "/m/01bjv"
28 | id: 6
29 | display_name: "bus"
30 | }
31 | item {
32 | name: "/m/07jdr"
33 | id: 7
34 | display_name: "train"
35 | }
36 | item {
37 | name: "/m/07r04"
38 | id: 8
39 | display_name: "truck"
40 | }
41 | item {
42 | name: "/m/019jd"
43 | id: 9
44 | display_name: "boat"
45 | }
46 | item {
47 | name: "/m/015qff"
48 | id: 10
49 | display_name: "traffic light"
50 | }
51 | item {
52 | name: "/m/01pns0"
53 | id: 11
54 | display_name: "fire hydrant"
55 | }
56 | item {
57 | name: "/m/02pv19"
58 | id: 13
59 | display_name: "stop sign"
60 | }
61 | item {
62 | name: "/m/015qbp"
63 | id: 14
64 | display_name: "parking meter"
65 | }
66 | item {
67 | name: "/m/0cvnqh"
68 | id: 15
69 | display_name: "bench"
70 | }
71 | item {
72 | name: "/m/015p6"
73 | id: 16
74 | display_name: "bird"
75 | }
76 | item {
77 | name: "/m/01yrx"
78 | id: 17
79 | display_name: "cat"
80 | }
81 | item {
82 | name: "/m/0bt9lr"
83 | id: 18
84 | display_name: "dog"
85 | }
86 | item {
87 | name: "/m/03k3r"
88 | id: 19
89 | display_name: "horse"
90 | }
91 | item {
92 | name: "/m/07bgp"
93 | id: 20
94 | display_name: "sheep"
95 | }
96 | item {
97 | name: "/m/01xq0k1"
98 | id: 21
99 | display_name: "cow"
100 | }
101 | item {
102 | name: "/m/0bwd_0j"
103 | id: 22
104 | display_name: "elephant"
105 | }
106 | item {
107 | name: "/m/01dws"
108 | id: 23
109 | display_name: "bear"
110 | }
111 | item {
112 | name: "/m/0898b"
113 | id: 24
114 | display_name: "zebra"
115 | }
116 | item {
117 | name: "/m/03bk1"
118 | id: 25
119 | display_name: "giraffe"
120 | }
121 | item {
122 | name: "/m/01940j"
123 | id: 27
124 | display_name: "backpack"
125 | }
126 | item {
127 | name: "/m/0hnnb"
128 | id: 28
129 | display_name: "umbrella"
130 | }
131 | item {
132 | name: "/m/080hkjn"
133 | id: 31
134 | display_name: "handbag"
135 | }
136 | item {
137 | name: "/m/01rkbr"
138 | id: 32
139 | display_name: "tie"
140 | }
141 | item {
142 | name: "/m/01s55n"
143 | id: 33
144 | display_name: "suitcase"
145 | }
146 | item {
147 | name: "/m/02wmf"
148 | id: 34
149 | display_name: "frisbee"
150 | }
151 | item {
152 | name: "/m/071p9"
153 | id: 35
154 | display_name: "skis"
155 | }
156 | item {
157 | name: "/m/06__v"
158 | id: 36
159 | display_name: "snowboard"
160 | }
161 | item {
162 | name: "/m/018xm"
163 | id: 37
164 | display_name: "sports ball"
165 | }
166 | item {
167 | name: "/m/02zt3"
168 | id: 38
169 | display_name: "kite"
170 | }
171 | item {
172 | name: "/m/03g8mr"
173 | id: 39
174 | display_name: "baseball bat"
175 | }
176 | item {
177 | name: "/m/03grzl"
178 | id: 40
179 | display_name: "baseball glove"
180 | }
181 | item {
182 | name: "/m/06_fw"
183 | id: 41
184 | display_name: "skateboard"
185 | }
186 | item {
187 | name: "/m/019w40"
188 | id: 42
189 | display_name: "surfboard"
190 | }
191 | item {
192 | name: "/m/0dv9c"
193 | id: 43
194 | display_name: "tennis racket"
195 | }
196 | item {
197 | name: "/m/04dr76w"
198 | id: 44
199 | display_name: "bottle"
200 | }
201 | item {
202 | name: "/m/09tvcd"
203 | id: 46
204 | display_name: "wine glass"
205 | }
206 | item {
207 | name: "/m/08gqpm"
208 | id: 47
209 | display_name: "cup"
210 | }
211 | item {
212 | name: "/m/0dt3t"
213 | id: 48
214 | display_name: "fork"
215 | }
216 | item {
217 | name: "/m/04ctx"
218 | id: 49
219 | display_name: "knife"
220 | }
221 | item {
222 | name: "/m/0cmx8"
223 | id: 50
224 | display_name: "spoon"
225 | }
226 | item {
227 | name: "/m/04kkgm"
228 | id: 51
229 | display_name: "bowl"
230 | }
231 | item {
232 | name: "/m/09qck"
233 | id: 52
234 | display_name: "banana"
235 | }
236 | item {
237 | name: "/m/014j1m"
238 | id: 53
239 | display_name: "apple"
240 | }
241 | item {
242 | name: "/m/0l515"
243 | id: 54
244 | display_name: "sandwich"
245 | }
246 | item {
247 | name: "/m/0cyhj_"
248 | id: 55
249 | display_name: "orange"
250 | }
251 | item {
252 | name: "/m/0hkxq"
253 | id: 56
254 | display_name: "broccoli"
255 | }
256 | item {
257 | name: "/m/0fj52s"
258 | id: 57
259 | display_name: "carrot"
260 | }
261 | item {
262 | name: "/m/01b9xk"
263 | id: 58
264 | display_name: "hot dog"
265 | }
266 | item {
267 | name: "/m/0663v"
268 | id: 59
269 | display_name: "pizza"
270 | }
271 | item {
272 | name: "/m/0jy4k"
273 | id: 60
274 | display_name: "donut"
275 | }
276 | item {
277 | name: "/m/0fszt"
278 | id: 61
279 | display_name: "cake"
280 | }
281 | item {
282 | name: "/m/01mzpv"
283 | id: 62
284 | display_name: "chair"
285 | }
286 | item {
287 | name: "/m/02crq1"
288 | id: 63
289 | display_name: "couch"
290 | }
291 | item {
292 | name: "/m/03fp41"
293 | id: 64
294 | display_name: "potted plant"
295 | }
296 | item {
297 | name: "/m/03ssj5"
298 | id: 65
299 | display_name: "bed"
300 | }
301 | item {
302 | name: "/m/04bcr3"
303 | id: 67
304 | display_name: "dining table"
305 | }
306 | item {
307 | name: "/m/09g1w"
308 | id: 70
309 | display_name: "toilet"
310 | }
311 | item {
312 | name: "/m/07c52"
313 | id: 72
314 | display_name: "tv"
315 | }
316 | item {
317 | name: "/m/01c648"
318 | id: 73
319 | display_name: "laptop"
320 | }
321 | item {
322 | name: "/m/020lf"
323 | id: 74
324 | display_name: "mouse"
325 | }
326 | item {
327 | name: "/m/0qjjc"
328 | id: 75
329 | display_name: "remote"
330 | }
331 | item {
332 | name: "/m/01m2v"
333 | id: 76
334 | display_name: "keyboard"
335 | }
336 | item {
337 | name: "/m/050k8"
338 | id: 77
339 | display_name: "cell phone"
340 | }
341 | item {
342 | name: "/m/0fx9l"
343 | id: 78
344 | display_name: "microwave"
345 | }
346 | item {
347 | name: "/m/029bxz"
348 | id: 79
349 | display_name: "oven"
350 | }
351 | item {
352 | name: "/m/01k6s3"
353 | id: 80
354 | display_name: "toaster"
355 | }
356 | item {
357 | name: "/m/0130jx"
358 | id: 81
359 | display_name: "sink"
360 | }
361 | item {
362 | name: "/m/040b_t"
363 | id: 82
364 | display_name: "refrigerator"
365 | }
366 | item {
367 | name: "/m/0bt_c3"
368 | id: 84
369 | display_name: "book"
370 | }
371 | item {
372 | name: "/m/01x3z"
373 | id: 85
374 | display_name: "clock"
375 | }
376 | item {
377 | name: "/m/02s195"
378 | id: 86
379 | display_name: "vase"
380 | }
381 | item {
382 | name: "/m/01lsmm"
383 | id: 87
384 | display_name: "scissors"
385 | }
386 | item {
387 | name: "/m/0kmg4"
388 | id: 88
389 | display_name: "teddy bear"
390 | }
391 | item {
392 | name: "/m/03wvsk"
393 | id: 89
394 | display_name: "hair drier"
395 | }
396 | item {
397 | name: "/m/012xff"
398 | id: 90
399 | display_name: "toothbrush"
400 | }
401 |
--------------------------------------------------------------------------------
/tensorflow/src/main/scala/io/brunk/DatasetSplitter.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2017 Sören Brunk
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.brunk
18 |
19 | import better.files.File
20 |
21 | import scala.util.Random.shuffle
22 |
23 | /** Script that splits an image dataset into train/validation/test set
24 | *
25 | * Expects the input images in one subdirectory per class below the input directory
26 | * Writes each subset into its own subdirectory for the training, validation and test set
27 | *
28 | * usage: DatasetSplitter