├── .gitignore
├── .travis.yml
├── LICENSE
├── README.adoc
├── pom.xml
└── src
    ├── main
    │   └── java
    │       ├── ml
    │       │   ├── DL4JMLModel.java
    │       │   ├── EncogMLModel.java
    │       │   ├── LoadTensorFlow.java
    │       │   ├── ML.java
    │       │   └── MLModel.java
    │       └── result
    │           ├── VirtualNode.java
    │           └── VirtualRelationship.java
    └── test
        ├── java
        │   └── ml
        │       ├── DL4JTest.java
        │       ├── DL4JXORHelloWorld.java
        │       ├── HelloTF.java
        │       ├── IrisClassification.java
        │       ├── LoadTest.java
        │       ├── MLProcedureTest.java
        │       ├── MLTest.java
        │       └── XORHelloWorld.java
        └── resources
            ├── iris.csv
            ├── linear_data_eval.csv
            ├── linear_data_train.csv
            ├── saved_model.pb
            └── tensorflow_example.pbtxt
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | \#*
3 | target
4 | out
5 | .project
6 | .classpath
7 | .settings
8 | .externalToolBuilders/
9 | .scala_dependencies
10 | .factorypath
11 | .cache
12 | .cache-main
13 | .cache-tests
14 | *.iws
15 | *.ipr
16 | *.iml
17 | .idea
18 | .DS_Store
19 | .shell_history
20 | .mailmap
21 | .java-version
22 | .cache-main
23 | .cache-tests
24 | Thumbs.db
25 | .cache-main
26 | .cache-tests
27 | neo4j-home
28 | dependency-reduced-pom.xml
29 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: java
2 |
3 | jdk:
4 | - oraclejdk8
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.adoc:
--------------------------------------------------------------------------------
1 | = Neo4j Machine Learning Procedures (WIP)
2 |
3 | image:https://travis-ci.org/neo4j-contrib/neo4j-ml-procedures.svg?branch=3.1["Build state", link="https://travis-ci.org/neo4j-contrib/neo4j-ml-procedures"]
4 |
5 | This project provides procedures and functions to support machine learning applications with Neo4j.
6 |
7 | [NOTE]
8 | This project requires Neo4j 3.2.x
9 |
10 | Thanks a lot to https://github.com/encog/encog-java-core[Encog for the great library^] and to https://www.stardog.com/docs/#_machine_learning[Stardog for the idea^].
11 |
12 | == Installation
13 |
14 | 1. Download the jar from the https://github.com/neo4j-contrib/neo4j-ml-procedures/releases/latest[latest release] or build it locally
15 | 2. Copy it into your `$NEO4J_HOME/plugins` directory.
16 | 3. Restart your server.
17 |
 18 | == Built-in classification and regression (WIP)
19 |
20 | [source,cypher]
21 | ----
22 | CALL ml.create("model",{types},"output",{config}) YIELD model, state, info
23 |
24 | CALL ml.add("model", {inputs}, given) YIELD model, state, info
25 |
 26 | CALL ml.train("model") YIELD model, state, info
27 |
28 | CALL ml.predict("model", {inputs}) YIELD value [, confidence]
29 |
30 | CALL ml.remove(model) YIELD model, state
31 | ----
32 |
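     | Judging from `MLModel.Config` and `MLModel.create`, the `{config}` map currently understands `framework` (`encog`, the default, or `dl4j`) plus the training parameters `seed`, `learningRate`, `epochs`, `hidden` and `trainPercent`; the values below are only illustrative:
     | 
     | [source,cypher]
     | ----
     | CALL ml.create("model", {types}, "output", {framework: "dl4j", hidden: 20, epochs: 100});
     | ----
     | 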
33 | Example: IRIS Classification from Encog
34 |
35 | [source,cypher]
36 | ----
37 | CALL ml.create("iris",
38 | {sepalLength: "float", sepalWidth: "float", petalLength: "float", petalWidth: "float", kind: "class"}, "kind",{});
39 | ----
40 |
41 | [source,cypher]
42 | ----
43 | LOAD CSV FROM "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" AS row
44 | CALL ml.add('iris', {sepalLength: row[0], sepalWidth: row[1], petalLength: row[2], petalWidth: row[3]}, row[4])
45 | YIELD state
46 | WITH collect(distinct state) as states, collect(distinct row[4]) as kinds
47 |
48 | CALL ml.train('iris') YIELD state, info
49 |
50 | RETURN state, states, kinds, info;
51 | ----
52 |
53 | ----
54 | ╒═══════╤════════════╤══════════════════════════════════════════════════╤══════════════════════════════════════════════════════════════════════╕
55 | │"state"│"states" │"kinds" │"info" │
56 | ╞═══════╪════════════╪══════════════════════════════════════════════════╪══════════════════════════════════════════════════════════════════════╡
57 | │"ready"│["training"]│["Iris-setosa","Iris-versicolor","Iris-virginica"]│{"trainingSets":150,"methodName":"feedforward","normalization":"[Norma│
58 | │ │ │ │lizationHelper:\n[ColumnDefinition:sepalLength(continuous);low=4,30000│
59 | │ │ │ │0,high=7,900000,mean=5,843333,sd=0,825301]\n[ColumnDefinition:petalLen│
60 | │ │ │ │gth(continuous);low=1,000000,high=6,900000,mean=3,758667,sd=1,758529]\│
61 | │ │ │ │n[ColumnDefinition:sepalWidth(continuous);low=2,000000,high=4,400000,m│
62 | │ │ │ │ean=3,054000,sd=0,432147]\n[ColumnDefinition:petalWidth(continuous);lo│
63 | │ │ │ │w=0,100000,high=2,500000,mean=1,198667,sd=0,760613]\n[ColumnDefinition│
64 | │ │ │ │:kind(nominal);[Iris-setosa, Iris-versicolor, Iris-virginica]]\n]","tr│
65 | │ │ │ │ainingError":0.034672103747075696,"selectedMethod":"[BasicNetwork: Lay│
66 | │ │ │ │ers=3]","validationError":0.05766172747088482} │
67 | └───────┴────────────┴──────────────────────────────────────────────────┴──────────────────────────────────────────────────────────────────────┘
68 | ----
69 |
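     | The `ml.info` procedure (defined in `ML.java`) reports the current state and metadata of a model at any time; the prediction example below also uses it to carry the model state into the results:
     | 
     | [source,cypher]
     | ----
     | CALL ml.info('iris') YIELD model, state, info
     | RETURN model, state, info;
     | ----
     | 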
70 | [source,cypher]
71 | ----
 72 | LOAD CSV FROM "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" AS row WITH row LIMIT 10
 73 | CALL ml.info('iris') YIELD state, info
74 |
75 | CALL ml.predict("iris", {sepalLength: row[0], sepalWidth: row[1], petalLength: row[2], petalWidth: row[3]})
76 | YIELD value as prediction
77 |
78 | RETURN row[4] as correct, prediction, state, row;
79 | ----
80 |
81 | ----
82 | ╒═════════════╤═════════════╤═══════╤═══════════════════════════════════════╕
83 | │"correct" │"prediction" │"state"│"row" │
84 | ╞═════════════╪═════════════╪═══════╪═══════════════════════════════════════╡
85 | │"Iris-setosa"│"Iris-setosa"│"ready"│["5.1","3.5","1.4","0.2","Iris-setosa"]│
86 | ├─────────────┼─────────────┼───────┼───────────────────────────────────────┤
87 | │"Iris-setosa"│"Iris-setosa"│"ready"│["4.9","3.0","1.4","0.2","Iris-setosa"]│
88 | ├─────────────┼─────────────┼───────┼───────────────────────────────────────┤
89 | │"Iris-setosa"│"Iris-setosa"│"ready"│["4.7","3.2","1.3","0.2","Iris-setosa"]│
90 | ├─────────────┼─────────────┼───────┼───────────────────────────────────────┤
91 | ...
92 | ----
93 |
94 | [source,cypher]
95 | ----
96 | CALL ml.remove('iris') YIELD model, state;
97 | ----
98 |
99 | == Manual neural network operations (TODO)
100 |
101 | .Procedures
102 | [source,cypher]
103 | ----
104 | apoc.ml.propagate(network, {inputs}) yield {outputs}
105 | apoc.ml.backprop(network, {output}) yield {network}
106 | ----
107 |
108 | .Functions
109 | [source,cypher]
110 | ----
111 | apoc.ml.sigmoid
112 | apoc.ml.sigmoidPrime
113 | ----
114 |
115 | Future plans include storing networks from the common machine learning libraries (TensorFlow, Deeplearning4j, Encog, etc.) as executable network structures in Neo4j.
116 |
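     | A first building block for this already lives in this repository: the `load.tensorflow` procedure (`LoadTensorFlow.java`) imports a TensorFlow `GraphDef` protobuf as `Neuron` nodes connected by `INPUT` relationships. A sketch of a call, with an example file URL:
     | 
     | [source,cypher]
     | ----
     | CALL load.tensorflow("file:///path/to/graph.pb")
     | YIELD modelName, type, nodes, relationships
     | RETURN modelName, type, nodes, relationships;
     | ----
     | 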
117 | == Building it yourself
118 |
119 | This project uses Maven. To build a jar file with the procedures in this
120 | project, simply package the project with Maven:
121 |
122 | mvn clean package
123 |
124 | This will produce a jar file, `target/neo4j-ml-procedures-*-SNAPSHOT.jar`, that can be copied into the `$NEO4J_HOME/plugins` directory of your Neo4j instance.
125 |
126 | == License
127 |
128 | Apache License V2, see LICENSE
129 |
130 | == Next Steps
131 |
132 | * Normalization / Classification for dl4j
133 | * Push frameworks to separate modules
134 | * Support external frameworks
135 | * Store / Load models from graph
136 | * Expose simple ML functions / propagation/ backprop on graph structures as procedures and functions
137 | * Load PMML
138 | * More fine-grained configuration (JSON) for Networks
139 | * K-Means, Classification, Regression
140 | * Spark ML-Lib
141 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
  1 | <?xml version="1.0" encoding="UTF-8"?>
  2 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  3 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  4 |     <modelVersion>4.0.0</modelVersion>
  5 | 
  6 |     <groupId>org.neo4j.procedures</groupId>
  7 |     <artifactId>neo4j-ml-procedures</artifactId>
  8 |     <version>3.2.2.1-SNAPSHOT</version>
  9 | 
 10 |     <packaging>jar</packaging>
 11 |     <name>Neo4j ML Procedures</name>
 12 | 
 13 |     <properties>
 14 |         <encoding>UTF-8</encoding>
 15 |         <java.version>1.8</java.version>
 16 |         <project.build.sourceEncoding>${encoding}</project.build.sourceEncoding>
 17 |         <project.reporting.outputEncoding>${encoding}</project.reporting.outputEncoding>
 18 |         <maven.compiler.source>${java.version}</maven.compiler.source>
 19 |         <maven.compiler.target>${java.version}</maven.compiler.target>
 20 |         <neo4j.version>3.2.2</neo4j.version>
 21 |     </properties>
 22 | 
 23 |     <dependencies>
 24 |         <dependency>
 25 |             <groupId>org.neo4j</groupId>
 26 |             <artifactId>neo4j</artifactId>
 27 |             <version>${neo4j.version}</version>
 28 |             <scope>provided</scope>
 29 |         </dependency>
 30 | 
 31 |         <dependency>
 32 |             <groupId>org.neo4j</groupId>
 33 |             <artifactId>neo4j-kernel</artifactId>
 34 |             <version>${neo4j.version}</version>
 35 |             <scope>test</scope>
 36 |             <type>test-jar</type>
 37 |         </dependency>
 38 |         <dependency>
 39 |             <groupId>org.neo4j</groupId>
 40 |             <artifactId>neo4j-io</artifactId>
 41 |             <version>${neo4j.version}</version>
 42 |             <scope>test</scope>
 43 |             <type>test-jar</type>
 44 |         </dependency>
 45 | 
 46 |         <dependency>
 47 |             <groupId>net.biville.florent</groupId>
 48 |             <artifactId>neo4j-sproc-compiler</artifactId>
 49 |             <version>1.2</version>
 50 |             <optional>true</optional>
 51 |             <scope>provided</scope>
 52 |         </dependency>
 53 | 
 54 |         <dependency>
 55 |             <groupId>org.apache.commons</groupId>
 56 |             <artifactId>commons-math3</artifactId>
 57 |             <version>3.4.1</version>
 58 |         </dependency>
 59 |         <dependency>
 60 |             <groupId>org.apache.commons</groupId>
 61 |             <artifactId>commons-lang3</artifactId>
 62 |             <version>3.3.1</version>
 63 |         </dependency>
 64 | 
 65 |         <dependency>
 66 |             <groupId>org.encog</groupId>
 67 |             <artifactId>encog-core</artifactId>
 68 |             <version>3.3.0</version>
 69 |             <exclusions>
 70 |                 <exclusion>
 71 |                     <groupId>org.apache.commons</groupId>
 72 |                     <artifactId>commons-math3</artifactId>
 73 |                 </exclusion>
 74 |             </exclusions>
 75 |         </dependency>
 76 | 
 77 |         <dependency>
 78 |             <groupId>org.deeplearning4j</groupId>
 79 |             <artifactId>deeplearning4j-core</artifactId>
 80 |             <version>0.9.1</version>
 81 |         </dependency>
 82 |         <dependency>
 83 |             <groupId>org.nd4j</groupId>
 84 |             <artifactId>nd4j-native-platform</artifactId>
 85 |             <version>0.9.1</version>
 86 |         </dependency>
 87 | 
 88 |         <dependency>
 89 |             <groupId>org.tensorflow</groupId>
 90 |             <artifactId>tensorflow</artifactId>
 91 |             <version>1.4.0</version>
 92 |         </dependency>
 93 |         <dependency>
 94 |             <groupId>org.tensorflow</groupId>
 95 |             <artifactId>proto</artifactId>
 96 |             <version>1.4.0</version>
 97 |         </dependency>
 98 | 
 99 |         <dependency>
100 |             <groupId>junit</groupId>
101 |             <artifactId>junit</artifactId>
102 |             <version>4.12</version>
103 |             <scope>test</scope>
104 |         </dependency>
105 |     </dependencies>
106 | 
107 | 
108 |     <build>
109 |         <plugins>
110 |             <plugin>
111 |                 <artifactId>maven-shade-plugin</artifactId>
112 |                 <version>2.4.3</version>
113 |                 <executions>
114 |                     <execution>
115 |                         <phase>package</phase>
116 |                         <goals>
117 |                             <goal>shade</goal>
118 |                         </goals>
119 |                     </execution>
120 |                 </executions>
121 |             </plugin>
122 |         </plugins>
123 |     </build>
124 | </project>
125 | 
--------------------------------------------------------------------------------
/src/main/java/ml/DL4JMLModel.java:
--------------------------------------------------------------------------------
1 | package ml;
2 |
3 | import org.datavec.api.records.metadata.RecordMetaData;
4 | import org.datavec.api.records.reader.impl.collection.ListStringRecordReader;
5 | import org.datavec.api.split.ListStringSplit;
6 | import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
7 | import org.deeplearning4j.eval.Evaluation;
8 | import org.deeplearning4j.eval.meta.Prediction;
9 | import org.deeplearning4j.nn.api.Layer;
10 | import org.deeplearning4j.nn.api.OptimizationAlgorithm;
11 | import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
12 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
13 | import org.deeplearning4j.nn.conf.Updater;
14 | import org.deeplearning4j.nn.conf.layers.DenseLayer;
15 | import org.deeplearning4j.nn.conf.layers.OutputLayer;
16 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
17 | import org.deeplearning4j.nn.weights.WeightInit;
18 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
19 | import org.nd4j.linalg.activations.Activation;
20 | import org.nd4j.linalg.api.ndarray.INDArray;
21 | import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
22 | import org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory;
23 | import org.nd4j.linalg.dataset.DataSet;
24 | import org.nd4j.linalg.dataset.SplitTestAndTrain;
25 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
26 | import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
27 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
28 | import org.nd4j.linalg.util.NDArrayUtil;
29 | import org.neo4j.graphdb.Label;
30 | import org.neo4j.graphdb.Node;
31 | import org.neo4j.helpers.collection.MapUtil;
32 | import result.VirtualNode;
33 |
34 | import java.util.*;
35 | import java.util.function.Function;
36 | import java.util.stream.Collectors;
37 |
38 | /**
39 | * @author mh
40 | * @since 23.07.17
41 | */
 42 | public class DL4JMLModel extends MLModel<List<String>> {
43 | private MultiLayerNetwork model;
44 |
 45 | public DL4JMLModel(String name, Map<String, String> types, String output, Map<String, Object> config) {
46 | super(name, types, output, config);
47 | }
48 |
49 | @Override
 50 | protected List<String> asRow(Map<String, Object> inputs, Object output) {
 51 | List<String> row = new ArrayList<>(inputs.size() + (output == null ? 0 : 1));
52 | for (String k : inputs.keySet()) {
53 | row.add(offsets.get(k), inputs.get(k).toString());
54 | }
55 | if (output != null) {
56 | row.add(offsets.get(this.output), output.toString());
57 | }
58 | return row;
59 | }
60 |
61 | @Override
 62 | protected Object doPredict(List<String> line) {
63 | try {
64 | ListStringSplit input = new ListStringSplit(Collections.singletonList(line));
65 | ListStringRecordReader rr = new ListStringRecordReader();
66 | rr.initialize(input);
67 | DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1);
68 |
69 | DataSet ds = iterator.next();
70 | INDArray prediction = model.output(ds.getFeatures());
71 |
72 | DataType outputType = types.get(this.output);
73 | switch (outputType) {
74 | case _float : return prediction.getDouble(0);
75 | case _class: {
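     | // argmax: return the index of the highest predicted activation (numClasses is currently hard-coded)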
76 | int numClasses = 2;
77 | double max = 0;
78 | int maxIndex = -1;
 79 | for (int i=0;i<numClasses;i++) {
 80 | if (prediction.getDouble(i) > max) {maxIndex = i; max = prediction.getDouble(i);}
81 | }
82 | return maxIndex;
83 | // return prediction.getInt(0,1); // numberOfClasses
84 | }
85 | default: throw new IllegalArgumentException("Output type not yet supported "+outputType);
86 | }
87 | } catch (Exception e) {
88 | throw new RuntimeException(e);
89 | }
90 | }
91 |
92 | @Override
93 | protected void doTrain() {
94 | try {
95 | long seed = config.seed.get();
96 | double learningRate = config.learningRate.get();
97 | int nEpochs = config.epochs.get();
98 |
99 | int numOutputs = 1;
100 | int numInputs = types.size() - numOutputs;
101 | int outputOffset = offsets.get(output); // last column
102 | int numHiddenNodes = config.hidden.get();
103 | double trainPercent = config.trainPercent.get();
104 | int batchSize = rows.size(); // full dataset size
105 |
106 | Map<String, Set<String>> classes = new HashMap<>();
107 | types.entrySet().stream()
108 | .filter(e -> e.getValue() == DataType._class)
109 | .map(e -> new AbstractMap.SimpleEntry<>(e.getKey(), offsets.get(e.getKey())))
110 | .forEach(e -> classes.put(e.getKey(),rows.parallelStream().map(r -> r.get(e.getValue())).distinct().collect(Collectors.toSet())));
111 |
112 | int numberOfClasses = classes.get(output).size();
113 | System.out.println("labels = " + classes);
114 |
115 | ListStringSplit input = new ListStringSplit(rows);
116 | ListStringRecordReader rr = new ListStringRecordReader();
117 | rr.initialize(input);
118 | RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, outputOffset, numberOfClasses);
119 |
120 | iterator.setCollectMetaData(true); // Instruct the iterator to collect metadata, and store it in the DataSet objects
121 | DataSet allData = iterator.next();
122 | allData.shuffle(seed);
123 | SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(trainPercent); //Use 65% of data for training
124 |
125 | DataSet trainingData = testAndTrain.getTrain();
126 | DataSet testData = testAndTrain.getTest();
127 |
128 | //Normalize data as per basic CSV example
129 | // NormalizerStandardize normalizer = new NormalizerStandardize();
130 | NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler();
131 | normalizer.fitLabel(true);
132 | normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
133 | normalizer.transform(trainingData); //Apply normalization to the training data
134 | normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set
135 |
136 | //Let's view the example metadata in the training and test sets:
137 | List<RecordMetaData> trainMetaData = trainingData.getExampleMetaData(RecordMetaData.class);
138 | List<RecordMetaData> testMetaData = testData.getExampleMetaData(RecordMetaData.class);
139 |
140 | //Let's show specifically which examples are in the training and test sets, using the collected metadata
141 | // System.out.println(" +++++ Training Set Examples MetaData +++++");
142 | // String format = "%-20s\t%s";
143 | // for(RecordMetaData recordMetaData : trainMetaData){
144 | // System.out.println(String.format(format, recordMetaData.getLocation(), recordMetaData.getURI()));
145 | // //Also available: recordMetaData.getReaderClass()
146 | // }
147 | // System.out.println("\n\n +++++ Test Set Examples MetaData +++++");
148 | // for(RecordMetaData recordMetaData : testMetaData){
149 | // System.out.println(recordMetaData.getLocation());
150 | // }
151 |
152 |
153 |
154 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
155 | .seed(seed)
156 | .iterations(1)
157 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
158 | .learningRate(learningRate)
159 | .updater(Updater.NESTEROVS).momentum(0.9)
160 | .list()
161 | .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
162 | .weightInit(WeightInit.XAVIER)
163 | .activation(Activation.RELU)
164 | .build())
165 | .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
166 | .weightInit(WeightInit.XAVIER)
167 | .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER)
168 | .nIn(numHiddenNodes).nOut(numberOfClasses).build())
169 | .pretrain(false).backprop(true).build();
170 |
171 |
172 | MultiLayerNetwork model = new MultiLayerNetwork(conf);
173 | model.init();
174 | model.setListeners(new ScoreIterationListener(10)); //Print score every 10 parameter updates
175 |
176 | for (int n = 0; n < nEpochs; n++) {
177 | model.fit(trainingData);
178 | }
179 |
180 | System.out.println("Evaluate model....");
181 | INDArray output = model.output(testData.getFeatureMatrix(),false);
182 | Evaluation eval = new Evaluation(numberOfClasses);
183 | eval.eval(testData.getLabels(), output, testMetaData); //Note we are passing in the test set metadata here
184 |
185 | List<Prediction> predictionErrors = eval.getPredictionErrors();
186 | System.out.println("\n\n+++++ Prediction Errors +++++");
187 | for(Prediction p : predictionErrors){
188 | System.out.printf("Predicted class: %d, Actual class: %d\t%s%n", p.getPredictedClass(), p.getActualClass(), p.getRecordMetaData(RecordMetaData.class));
189 | }
190 | //Print the evaluation statistics
191 | System.out.println(eval.stats());
192 |
193 | this.model = model;
194 | this.state = State.ready;
195 |
196 | } catch (Exception e) {
197 | throw new RuntimeException(e);
198 | }
199 | }
200 |
201 | @Override
202 | List<Node> show() {
203 | if ( state != State.ready ) throw new IllegalStateException("Model not trained yet");
204 | List<Node> result = new ArrayList<>();
205 | int layerCount = model.getnLayers();
206 | for (Layer layer : model.getLayers()) {
207 | Node node = node("Layer",
208 | "type", layer.type().name(), "index", layer.getIndex(),
209 | "pretrainLayer", layer.isPretrainLayer(), "miniBatchSize", layer.getInputMiniBatchSize(),
210 | "numParams", layer.numParams());
211 | if (layer instanceof DenseLayer) {
212 | DenseLayer dl = (DenseLayer) layer;
213 | node.addLabel(Label.label("DenseLayer"));
214 | node.setProperty("activation",dl.getActivationFn().toString()); // todo parameters
215 | node.setProperty("biasInit",dl.getBiasInit());
216 | node.setProperty("biasLearningRate",dl.getBiasLearningRate());
217 | node.setProperty("l1",dl.getL1());
218 | node.setProperty("l1Bias",dl.getL1Bias());
219 | node.setProperty("l2",dl.getL2());
220 | node.setProperty("l2Bias",dl.getL2Bias());
221 | node.setProperty("distribution",dl.getDist().toString());
222 | node.setProperty("in",dl.getNIn());
223 | node.setProperty("out",dl.getNOut());
224 | }
225 | result.add(node);
226 | // layer.preOutput(allOne, Layer.TrainingMode.TEST);
227 | // layer.p(allOne, Layer.TrainingMode.TEST);
228 | // layer.activate(allOne, Layer.TrainingMode.TEST);
229 | }
230 | return result;
231 | }
232 |
233 | private Node node(String label, Object...keyValues) {
234 | return new VirtualNode(new Label[] {Label.label(label)}, MapUtil.map(keyValues),null);
235 | }
236 | }
237 |
--------------------------------------------------------------------------------
/src/main/java/ml/EncogMLModel.java:
--------------------------------------------------------------------------------
1 | package ml;
2 |
3 | import org.encog.ml.MLRegression;
4 | import org.encog.ml.data.MLData;
5 | import org.encog.ml.data.versatile.NormalizationHelper;
6 | import org.encog.ml.data.versatile.VersatileMLDataSet;
7 | import org.encog.ml.data.versatile.columns.ColumnDefinition;
8 | import org.encog.ml.data.versatile.columns.ColumnType;
9 | import org.encog.ml.data.versatile.sources.VersatileDataSource;
10 | import org.encog.ml.factory.MLMethodFactory;
11 | import org.encog.ml.model.EncogModel;
12 | import org.encog.util.simple.EncogUtility;
13 | import org.neo4j.graphdb.Node;
14 |
15 | import java.util.*;
16 |
17 | /**
18 | * @author mh
19 | * @since 19.07.17
20 | */
 21 | public class EncogMLModel extends MLModel<String[]> {
22 |
23 | private EncogModel model; // todo MLMethod and decide later between regression, classification and others
24 | private MLRegression method;
25 |
26 |
27 | @Override
 28 | protected String[] asRow(Map<String, Object> inputs, Object output) {
29 | String[] row = new String[inputs.size() + (output == null ? 0 : 1)];
30 | for (String k : inputs.keySet()) {
31 | row[offsets.get(k)] = inputs.get(k).toString();
32 | }
33 | if (output != null) {
34 | row[offsets.get(this.output)] = output.toString();
35 | }
36 | return row;
37 | }
38 |
39 |
40 | @Override
41 | protected Object doPredict(String[] line) {
42 | NormalizationHelper helper = model.getDataset().getNormHelper();
43 | MLData input = helper.allocateInputVector();
44 | helper.normalizeInputVector(line, input.getData(), false);
45 | MLData output = method.compute(input);
46 | DataType outputType = types.get(this.output);
47 | switch (outputType) {
48 | case _float : return output.getData(0);
49 | case _class: return helper.denormalizeOutputVectorToString(output)[0];
50 | default: throw new IllegalArgumentException("Output type not yet supported "+outputType);
51 | }
52 | }
53 |
54 | @Override
55 | protected void doTrain() {
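     | // feed the collected training rows to Encog as an in-memory data source (read like a virtual CSV)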
56 | VersatileMLDataSet data = new VersatileMLDataSet(new VersatileDataSource() {
57 | int idx = 0;
58 |
59 | @Override
60 | public String[] readLine() {
61 | return idx >= rows.size() ? null : rows.get(idx++);
62 | }
63 |
64 | @Override
65 | public void rewind() {
66 | idx = 0;
67 | }
68 |
69 | @Override
70 | public int columnIndex(String s) {
71 | return offsets.get(s);
72 | }
73 | });
74 | offsets.entrySet().stream().sorted(Comparator.comparingInt(Map.Entry::getValue)).forEach(e -> {
75 | String k = e.getKey();
76 | ColumnDefinition col = data.defineSourceColumn(k, offsets.get(k), typeOf(types.get(k))); // todo has bug, doesn't work like that, cols have to be in index order
77 | if (k.equals(output)) {
78 | data.defineOutput(col);
79 | } else {
80 | data.defineInput(col);
81 | }
82 | });
83 | // types.forEach((k, v) -> {
84 | // ColumnDefinition col = data.defineSourceColumn(k, offsets.get(k), v); // todo has bug, doesn't work like that, cols have to be in index order
85 | // if (k.equals(output)) {
86 | // data.defineOutput(col);
87 | // } else {
88 | // data.defineInput(col);
89 | // }
90 | // });
91 | // Analyze the data, determine the min/max/mean/sd of every column.
92 | data.analyze();
93 |
 94 | // Create a feedforward neural network as the model type (MLMethodFactory.TYPE_FEEDFORWARD).
 95 | // You could also use other model types, such as:
 96 | // MLMethodFactory.TYPE_SVM: Support Vector Machine (SVM)
 97 | // MLMethodFactory.TYPE_RBFNETWORK: RBF Neural Network
 98 | // MLMethodFactory.TYPE_NEAT: NEAT Neural Network
 99 | // MLMethodFactory.TYPE_PNN: Probabilistic Neural Network
100 | EncogModel model = new EncogModel(data);
101 | model.selectMethod(data, methodFor(methodName)); // todo from config
102 | // Send any output to the console.
103 | // model.setReport(new ConsoleStatusReportable());
104 |
105 | // Now normalize the data. Encog will automatically determine the correct normalization
106 | // type based on the model you chose in the last step.
107 | data.normalize();
108 |
109 | // Hold back some data for a final validation.
110 | // Shuffle the data into a random ordering.
111 | // Use a seed of 1001 so that we always use the same holdback and will get more consistent results.
112 | model.holdBackValidation(0.3, true, 1001); // todo from config
113 |
114 | // Choose whatever is the default training type for this model.
115 | model.selectTrainingType(data);
116 |
117 | // Use a 5-fold cross-validated train. Return the best method found.
118 | MLRegression bestMethod = (MLRegression) model.crossvalidate(5, true); // todo from config
119 | // MLRegression vs. MLClassification
120 |
121 | // Display the training and validation errors.
122 | // System.out.println("Training error: " + EncogUtility.calculateRegressionError(bestMethod, model.getTrainingDataset()));
123 | // System.out.println("Validation error: " + EncogUtility.calculateRegressionError(bestMethod, model.getValidationDataset()));
124 |
125 | // Display our normalization parameters.
126 | // NormalizationHelper helper = data.getNormHelper();
127 | // System.out.println(helper.toString());
128 |
129 | // Display the final model.
130 | // System.out.println("Final model: " + bestMethod);
131 | this.model = model;
132 | this.method = bestMethod;
133 | this.state = State.ready;
134 | }
135 |
136 | private String methodFor(Method method) {
137 | switch (method) {
138 | case ffd: return MLMethodFactory.TYPE_FEEDFORWARD;
139 | case svm: return MLMethodFactory.TYPE_SVM;
140 | case rbf: return MLMethodFactory.TYPE_RBFNETWORK;
141 | case neat: return MLMethodFactory.TYPE_NEAT;
142 | case pnn: return MLMethodFactory.TYPE_PNN;
143 | default: throw new IllegalArgumentException("Unknown method "+method);
144 | }
145 | }
146 |
147 | /*
148 | nominal,
149 | ordinal,
150 | continuous,
151 | ignore
152 | */
153 | private ColumnType typeOf(DataType type) {
154 | switch (type) {
155 | case _class:
156 | return ColumnType.nominal;
157 | case _float:
158 | return ColumnType.continuous;
159 | case _order:
160 | return ColumnType.ordinal;
161 | default:
162 | throw new IllegalArgumentException("Unknown type: " + type);
163 | }
164 | }
165 |
166 | public EncogMLModel(String name, Map<String, String> types, String output, Map<String, Object> config) {
167 | super(name, types, output, config);
168 | }
169 |
170 |
171 | @Override
172 | protected ML.ModelResult resultWithInfo(ML.ModelResult result) {
173 | return result.withInfo(
174 | "trainingError", EncogUtility.calculateRegressionError(method, model.getTrainingDataset()),
175 | "validationError",EncogUtility.calculateRegressionError(method, model.getValidationDataset()),
176 | "selectedMethod",method.toString(),
177 | "normalization",model.getDataset().getNormHelper().toString()
178 | );
179 | }
180 | }
181 |
--------------------------------------------------------------------------------
/src/main/java/ml/LoadTensorFlow.java:
--------------------------------------------------------------------------------
1 | package ml;
2 |
3 | import org.neo4j.graphdb.GraphDatabaseService;
4 | import org.neo4j.graphdb.Label;
5 | import org.neo4j.graphdb.Node;
6 | import org.neo4j.graphdb.RelationshipType;
7 | import org.neo4j.procedure.Context;
8 | import org.neo4j.procedure.Mode;
9 | import org.neo4j.procedure.Name;
10 | import org.neo4j.procedure.Procedure;
11 | import org.tensorflow.framework.AttrValue;
12 | import org.tensorflow.framework.GraphDef;
13 | import org.tensorflow.framework.NodeDef;
14 |
15 | import java.io.BufferedInputStream;
16 | import java.io.IOException;
17 | import java.net.URL;
18 | import java.util.HashMap;
19 | import java.util.Map;
20 | import java.util.stream.Stream;
21 |
22 | /**
23 | * @author mh
24 | * @since 26.07.17
25 | * see: https://www.tensorflow.org/extend/tool_developers/
26 | * see: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
27 | * https://developers.google.com/protocol-buffers/docs/javatutorial
28 | * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/testdata/half_plus_two/00000123/saved_model.pb
29 | *
30 | */
31 | public class LoadTensorFlow {
32 |
33 | @Context
34 | public GraphDatabaseService db;
35 |
36 | enum Types implements Label {
37 | Neuron
38 | }
39 |
40 | enum RelTypes implements RelationshipType {
41 | INPUT
42 | }
43 |
44 | public static class LoadResult {
45 | public String modelName;
46 | public String type;
47 | public long nodes;
48 | public long relationships;
49 |
50 | public LoadResult(String modelName, String type, long nodes, long relationships) {
51 | this.modelName = modelName;
52 | this.type = type;
53 | this.nodes = nodes;
54 | this.relationships = relationships;
55 | }
56 | }
57 | @Procedure(value = "load.tensorflow", mode = Mode.WRITE)
 58 | public Stream<LoadResult> loadTensorFlow(@Name("file") String url) throws IOException {
59 | GraphDef graphDef = GraphDef.parseFrom(new BufferedInputStream(new URL(url).openStream()));
 60 | Map<String, Node> nodes = new HashMap<>();
 61 | // todo model node, layer nodes
62 | for (NodeDef nodeDef : graphDef.getNodeList()) {
63 | Node node = db.createNode(Types.Neuron);
64 | node.setProperty("name", nodeDef.getName());
65 | if (nodeDef.getDevice() != null) node.setProperty("device", nodeDef.getDevice());
66 | node.setProperty("op", nodeDef.getOp());
67 | nodeDef.getAttrMap().forEach((k, v) -> {
68 | Object value = getValue(v);
69 | if (value != null) {
70 | node.setProperty(k, value);
71 | }
72 | });
73 | nodes.put(nodeDef.getName(), node);
74 | }
75 | long rels = 0;
76 | for (NodeDef nodeDef : graphDef.getNodeList()) {
77 | Node target = nodes.get(nodeDef.getName());
78 | nodeDef.getInputList().forEach(name -> nodes.get(name).createRelationshipTo(target, RelTypes.INPUT));
79 | // todo weights
80 | rels += nodeDef.getInputCount();
81 | }
82 | return Stream.of(new LoadResult(url,"tensorflow",nodes.size(), rels));
83 | }
84 |
85 | private Object getValue(AttrValue v) {
86 | switch (v.getValueCase()) {
87 | case S:
88 | return v.getS().toStringUtf8();
89 | case I:
90 | return v.getI();
91 | case F:
92 | return v.getF();
93 | case B:
94 | return v.getB();
95 | case TYPE:
96 | return v.getType().name(); // todo
97 | case SHAPE:
 98 | return v.getShape().toString(); // todo
 99 | case TENSOR:
100 | return v.getTensor().toString(); // todo handle with prefixed properties
101 | case LIST:
102 | return v.getList().toString(); // todo getType/Count(idx) and then handle each type with prefixed property
103 | case FUNC:
104 | return v.getFunc().getAttrMap().toString(); // todo handle recursively
105 | case PLACEHOLDER:
106 | break;
107 | case VALUE_NOT_SET:
108 | return null;
109 | default:
110 | return null;
111 | }
112 | return null;
113 | }
114 |
115 | }
116 |
--------------------------------------------------------------------------------
/src/main/java/ml/ML.java:
--------------------------------------------------------------------------------
1 | package ml;
2 |
3 | import org.neo4j.graphdb.GraphDatabaseService;
4 | import org.neo4j.graphdb.Node;
5 | import org.neo4j.logging.Log;
6 | import org.neo4j.procedure.Context;
7 | import org.neo4j.procedure.Name;
8 | import org.neo4j.procedure.Procedure;
9 |
10 | import java.util.*;
11 | import java.util.stream.Stream;
12 |
13 | public class ML {
14 | @Context
15 | public GraphDatabaseService db;
16 |
17 | @Context
18 | public Log log;
19 |
20 | /*
21 |
22 | apoc.ml.create.classifier({types},['output'],{config}) yield model
23 | apoc.ml.create.regression({params},{config}) yield model
24 |
25 | apoc.ml.train(model, {params}, prediction)
26 |
27 | apoc.ml.classify(model, {params}) yield prediction:string, confidence:float
28 | apoc.ml.regression(model, {params}) yield prediction:float, confidence:float
29 |
30 | apoc.ml.delete(model) yield model
31 |
32 | */
33 |
34 | @Procedure
 35 | public Stream<ModelResult> create(@Name("model") String model, @Name("types") Map<String, String> types, @Name(value="output") String output, @Name(value="params",defaultValue="{}") Map<String, Object> config) {
36 | return Stream.of(MLModel.create(model,types,output,config).asResult());
37 | }
38 |
39 | @Procedure
 40 | public Stream<ModelResult> info(@Name("model") String model) {
41 | return Stream.of(MLModel.from(model).asResult());
42 | }
43 |
44 | @Procedure
 45 | public Stream<ModelResult> remove(@Name("model") String model) {
46 | return Stream.of(MLModel.remove(model));
47 | }
48 |
49 | @Procedure
 50 | public Stream<ModelResult> add(@Name("model") String model, @Name("inputs") Map<String, Object> inputs, @Name("outputs") Object output) {
51 | MLModel mlModel = MLModel.from(model);
52 | mlModel.add(inputs,output);
53 | return Stream.of(mlModel.asResult());
54 | }
55 |
56 | @Procedure
 57 | public Stream<ModelResult> train(@Name("model") String model) {
58 | MLModel mlModel = MLModel.from(model);
59 | mlModel.train();
60 | return Stream.of(mlModel.asResult());
61 | }
62 |
63 | public static class NodeResult {
64 | public final Node node;
65 |
66 | public NodeResult(Node node) {
67 | this.node = node;
68 | }
69 | }
70 | @Procedure
 71 | public Stream<NodeResult> show(@Name("model") String model) {
 72 | List<Node> show = MLModel.from(model).show();
73 | return show.stream().map(NodeResult::new);
74 | }
75 |
76 | @Procedure
 77 | public Stream<PredictionResult> predict(@Name("model") String model, @Name("inputs") Map<String, Object> inputs) {
78 | MLModel mlModel = MLModel.from(model);
79 | Object value = mlModel.predict(inputs);
80 | double confidence = 0.0d;
81 | return Stream.of(new PredictionResult(value, confidence));
82 | }
83 |
84 | public static class PredictionResult {
85 | public Object value;
86 | public double confidence;
87 |
88 | public PredictionResult(Object value, double confidence) {
89 | this.value = value;
90 | this.confidence = confidence;
91 | }
92 | }
93 |
94 | public static class ModelResult {
95 | public final String model;
96 | public final String state;
 97 | public final Map<String, Object> info = new HashMap<>();
98 |
99 | public ModelResult(String model, EncogMLModel.State state) {
100 | this.model = model;
101 | this.state = state.name();
102 | }
103 |
104 | ModelResult withInfo(Object...infos) {
105 | for (int i = 0; i < infos.length; i+=2) {
106 | info.put(infos[i].toString(),infos[i+1]);
107 | }
108 | return this;
109 | }
110 | }
111 |
112 | }
113 |
--------------------------------------------------------------------------------
/src/main/java/ml/MLModel.java:
--------------------------------------------------------------------------------
1 | package ml;
2 |
3 | import org.encog.ml.factory.MLMethodFactory;
4 | import org.neo4j.graphdb.Node;
5 |
6 | import java.util.*;
7 | import java.util.concurrent.ConcurrentHashMap;
8 |
9 | /**
10 | * @author mh
11 | * @since 23.07.17
12 | */
 13 | public abstract class MLModel<ROW> {
14 |
15 | static class Config {
 16 | private final Map<String, Object> config;
 17 |
 18 | public Config(Map<String, Object> config) {
19 | this.config = config;
20 | }
21 |
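     | // typed accessor for a single config entry: falls back to the given default and coerces numbers to the declared type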
 22 | class V<T> {
23 | private final String name;
24 | private final T defaultValue;
25 |
26 | V(String name, T defaultValue) {
27 | this.name = name;
28 | this.defaultValue = defaultValue;
29 | }
30 |
31 | T get(T defaultValue) {
32 | Object value = config.get(name);
33 | if (value == null) return defaultValue;
34 | if (defaultValue instanceof Double) return (T) (Object) ((Number) value).doubleValue();
35 | if (defaultValue instanceof Integer) return (T) (Object) ((Number) value).intValue();
36 | if (defaultValue instanceof Long) return (T) (Object) ((Number) value).longValue();
37 | if (defaultValue instanceof String) return (T) value.toString();
38 | return (T) value;
39 | }
40 |
41 | T get() {
42 | return get(defaultValue);
43 | }
44 | }
45 |
 46 | public final V<Long> seed = new V<>("seed", 123L);
 47 | public final V<Double> learningRate = new V<>("learningRate", 0.01d);
 48 | public final V<Integer> epochs = new V<>("epochs", 50);
 49 | public final V<Integer> hidden = new V<>("hidden", 20);
 50 | public final V<Double> trainPercent = new V<>("trainPercent", 0.75d);
51 | }
52 |
 53 | static ConcurrentHashMap<String, MLModel> models = new ConcurrentHashMap<>();
54 | final String name;
 55 | final Map<String, DataType> types = new HashMap<>();
 56 | final Map<String, Integer> offsets = new HashMap<>();
57 | final String output;
58 | final Config config;
 59 | final List<ROW> rows = new ArrayList<>();
60 | State state;
61 | Method methodName;
62 |
 63 | public MLModel(String name, Map<String, String> types, String output, Map<String, Object> config) {
64 | if (models.containsKey(name))
65 | throw new IllegalArgumentException("Model " + name + " already exists, please remove first");
66 |
67 | this.name = name;
68 | this.state = State.created;
69 | this.output = output;
70 | this.config = new Config(config);
71 | initTypes(types, output);
72 |
73 | this.methodName = Method.ffd;
74 |
75 | models.put(name, this);
76 |
77 | }
78 |
 79 | protected void initTypes(Map<String, String> types, String output) {
80 | if (!types.containsKey(output)) throw new IllegalArgumentException("Outputs not defined: " + output);
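     | // assign column offsets: input columns in iteration order, the output column always last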
81 | int i = 0;
 82 | for (Map.Entry<String, String> entry : types.entrySet()) {
83 | String key = entry.getKey();
84 | this.types.put(key, DataType.from(entry.getValue()));
85 | if (!key.equals(output)) this.offsets.put(key, i++);
86 | }
87 | this.offsets.put(output, i);
88 | }
89 |
90 | public static ML.ModelResult remove(String model) {
91 | MLModel existing = models.remove(model);
92 | return new ML.ModelResult(model, existing == null ? State.unknown : State.removed);
93 | }
94 |
 95 | public static MLModel<?> from(String name) {
96 | MLModel model = models.get(name);
97 | if (model != null) return model;
98 | throw new IllegalArgumentException("No valid ML-Model " + name);
99 | }
100 |
101 | public void add(Map<String, Object> inputs, Object output) {
102 | if (this.state == State.created || this.state == State.training) {
103 | rows.add(asRow(inputs, output));
104 | this.state = State.training;
105 | } else {
106 | throw new IllegalArgumentException(String.format("Model %s not able to accept training data, state is: %s", name, state));
107 | }
108 | }
109 |
110 | protected abstract ROW asRow(Map<String, Object> inputs, Object output);
111 |
112 | public void train() {
113 | if (state != State.ready) {
114 | if (state != State.training) {
115 | throw new IllegalArgumentException(String.format("Model %s is not ready to predict, it has no training data, state is %s", name, state));
116 | }
117 | doTrain();
118 | }
119 | }
120 |
121 | public Object predict(Map<String, Object> inputs) {
122 | if (state != State.ready) {
123 | train();
124 | }
125 | if (state == State.ready) {
126 | ROW line = asRow(inputs, null);
127 |
128 | Object predicted = doPredict(line);
129 | // todo confidence
130 | return predicted;
131 | } else {
132 | throw new IllegalArgumentException(String.format("Model %s is not ready to predict, state is %s", name, state));
133 | }
134 | }
135 |
136 | protected abstract Object doPredict(ROW line);
137 |
138 | protected abstract void doTrain();
139 |
140 | public ML.ModelResult asResult() {
141 | ML.ModelResult result =
142 | new ML.ModelResult(this.name, this.state)
143 | .withInfo("methodName", methodName);
144 |
145 | if (rows.size() > 0) {
146 | result = result.withInfo("trainingSets", (long) rows.size());
147 | }
148 | if (state == State.ready) {
149 | // todo check how expensive this is
150 | result = resultWithInfo(result);
151 | }
152 | return result;
153 | }
154 |
155 | protected ML.ModelResult resultWithInfo(ML.ModelResult result) {
156 | return result;
157 | }
158 |
159 | ;
160 |
161 | public static MLModel create(String name, Map<String, String> types, String output, Map<String, Object> config) {
162 | String framework = config.getOrDefault("framework", "encog").toString().toLowerCase();
163 | switch (framework) {
164 | case "encog":
165 | return new EncogMLModel(name, types, output, config);
166 | case "dl4j":
167 | return new DL4JMLModel(name, types, output, config);
168 | default:
169 | throw new IllegalArgumentException("Unknown framework: " + framework);
170 | }
171 | }
172 |
173 | enum Method {
174 | ffd, svm, rbf, neat, pnn;
175 | }
176 |
177 | enum DataType {
178 | _class, _float, _order;
179 |
180 | public static DataType from(String type) {
181 | switch (type.toUpperCase()) {
182 | case "CLASS":
183 | return DataType._class;
184 | case "FLOAT":
185 | return DataType._float;
186 | case "ORDER":
187 | return DataType._order;
188 | default:
189 | throw new IllegalArgumentException("Unknown type: " + type);
190 | }
191 | }
192 | }
193 |
194 | public enum State {created, training, ready, removed, unknown}
195 |
196 | List<Node> show() {
197 | return Collections.emptyList();
198 | }
199 | }
200 |
--------------------------------------------------------------------------------
/src/main/java/result/VirtualNode.java:
--------------------------------------------------------------------------------
1 | package result;
2 |
3 | import org.neo4j.graphdb.*;
4 | import org.neo4j.helpers.collection.FilteringIterable;
5 | import org.neo4j.helpers.collection.Iterables;
6 |
7 | import java.util.*;
8 | import java.util.concurrent.atomic.AtomicLong;
9 | import java.util.stream.Collectors;
10 |
11 | import static java.util.Arrays.asList;
12 |
13 | /**
14 | * @author mh
15 | * @since 10.08.17
16 | */
17 | public class VirtualNode implements Node {
18 | private static AtomicLong MIN_ID = new AtomicLong(-1);
19 | private final List