├── .gitignore ├── .travis.yml ├── LICENSE ├── README.adoc ├── pom.xml └── src ├── main └── java │ ├── ml │ ├── DL4JMLModel.java │ ├── EncogMLModel.java │ ├── LoadTensorFlow.java │ ├── ML.java │ └── MLModel.java │ └── result │ ├── VirtualNode.java │ └── VirtualRelationship.java └── test ├── java └── ml │ ├── DL4JTest.java │ ├── DL4JXORHelloWorld.java │ ├── HelloTF.java │ ├── IrisClassification.java │ ├── LoadTest.java │ ├── MLProcedureTest.java │ ├── MLTest.java │ └── XORHelloWorld.java └── resources ├── iris.csv ├── linear_data_eval.csv ├── linear_data_train.csv ├── saved_model.pb └── tensorflow_example.pbtxt /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | \#* 3 | target 4 | out 5 | .project 6 | .classpath 7 | .settings 8 | .externalToolBuilders/ 9 | .scala_dependencies 10 | .factorypath 11 | .cache 12 | .cache-main 13 | .cache-tests 14 | *.iws 15 | *.ipr 16 | *.iml 17 | .idea 18 | .DS_Store 19 | .shell_history 20 | .mailmap 21 | .java-version 22 | .cache-main 23 | .cache-tests 24 | Thumbs.db 25 | .cache-main 26 | .cache-tests 27 | neo4j-home 28 | dependency-reduced-pom.xml 29 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java 2 | 3 | jdk: 4 | - oraclejdk8 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.adoc: -------------------------------------------------------------------------------- 1 | = Neo4j Machine Learning Procedures (WIP) 2 | 3 | image:https://travis-ci.org/neo4j-contrib/neo4j-ml-procedures.svg?branch=3.1["Build state", link="https://travis-ci.org/neo4j-contrib/neo4j-ml-procedures"] 4 | 5 | This project provides procedures and functions to support machine learning applications with Neo4j. 6 | 7 | [NOTE] 8 | This project requires Neo4j 3.2.x 9 | 10 | Thanks a lot to https://github.com/encog/encog-java-core[Encog for the great library^] and to https://www.stardog.com/docs/#_machine_learning[Stardog for the idea^]. 11 | 12 | == Installation 13 | 14 | 1. Download the jar from the https://github.com/neo4j-contrib/neo4j-ml-procedures/releases/latest[latest release] or build it locally 15 | 2. Copy it into your `$NEO4J_HOME/plugins` directory. 16 | 3. Restart your server. 
17 | 18 | == Built in classification and regression (WIP) 19 | 20 | [source,cypher] 21 | ---- 22 | CALL ml.create("model",{types},"output",{config}) YIELD model, state, info 23 | 24 | CALL ml.add("model", {inputs}, given) YIELD model, state, info 25 | 26 | CALL ml.train() YIELD model, state, info 27 | 28 | CALL ml.predict("model", {inputs}) YIELD value [, confidence] 29 | 30 | CALL ml.remove(model) YIELD model, state 31 | ---- 32 | 33 | Example: IRIS Classification from Encog 34 | 35 | [source,cypher] 36 | ---- 37 | CALL ml.create("iris", 38 | {sepalLength: "float", sepalWidth: "float", petalLength: "float", petalWidth: "float", kind: "class"}, "kind",{}); 39 | ---- 40 | 41 | [source,cypher] 42 | ---- 43 | LOAD CSV FROM "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" AS row 44 | CALL ml.add('iris', {sepalLength: row[0], sepalWidth: row[1], petalLength: row[2], petalWidth: row[3]}, row[4]) 45 | YIELD state 46 | WITH collect(distinct state) as states, collect(distinct row[4]) as kinds 47 | 48 | CALL ml.train('iris') YIELD state, info 49 | 50 | RETURN state, states, kinds, info; 51 | ---- 52 | 53 | ---- 54 | ╒═══════╤════════════╤══════════════════════════════════════════════════╤══════════════════════════════════════════════════════════════════════╕ 55 | │"state"│"states" │"kinds" │"info" │ 56 | ╞═══════╪════════════╪══════════════════════════════════════════════════╪══════════════════════════════════════════════════════════════════════╡ 57 | │"ready"│["training"]│["Iris-setosa","Iris-versicolor","Iris-virginica"]│{"trainingSets":150,"methodName":"feedforward","normalization":"[Norma│ 58 | │ │ │ │lizationHelper:\n[ColumnDefinition:sepalLength(continuous);low=4,30000│ 59 | │ │ │ │0,high=7,900000,mean=5,843333,sd=0,825301]\n[ColumnDefinition:petalLen│ 60 | │ │ │ │gth(continuous);low=1,000000,high=6,900000,mean=3,758667,sd=1,758529]\│ 61 | │ │ │ │n[ColumnDefinition:sepalWidth(continuous);low=2,000000,high=4,400000,m│ 62 | │ │ │ 
│ean=3,054000,sd=0,432147]\n[ColumnDefinition:petalWidth(continuous);lo│ 63 | │ │ │ │w=0,100000,high=2,500000,mean=1,198667,sd=0,760613]\n[ColumnDefinition│ 64 | │ │ │ │:kind(nominal);[Iris-setosa, Iris-versicolor, Iris-virginica]]\n]","tr│ 65 | │ │ │ │ainingError":0.034672103747075696,"selectedMethod":"[BasicNetwork: Lay│ 66 | │ │ │ │ers=3]","validationError":0.05766172747088482} │ 67 | └───────┴────────────┴──────────────────────────────────────────────────┴──────────────────────────────────────────────────────────────────────┘ 68 | ---- 69 | 70 | [source,cypher] 71 | ---- 72 | LOAD CSV FROM "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data" AS row 73 | WITH state, info, row limit 10 74 | 75 | CALL ml.predict("iris", {sepalLength: row[0], sepalWidth: row[1], petalLength: row[2], petalWidth: row[3]}) 76 | YIELD value as prediction 77 | 78 | RETURN row[4] as correct, prediction, state, row; 79 | ---- 80 | 81 | ---- 82 | ╒═════════════╤═════════════╤═══════╤═══════════════════════════════════════╕ 83 | │"correct" │"prediction" │"state"│"row" │ 84 | ╞═════════════╪═════════════╪═══════╪═══════════════════════════════════════╡ 85 | │"Iris-setosa"│"Iris-setosa"│"ready"│["5.1","3.5","1.4","0.2","Iris-setosa"]│ 86 | ├─────────────┼─────────────┼───────┼───────────────────────────────────────┤ 87 | │"Iris-setosa"│"Iris-setosa"│"ready"│["4.9","3.0","1.4","0.2","Iris-setosa"]│ 88 | ├─────────────┼─────────────┼───────┼───────────────────────────────────────┤ 89 | │"Iris-setosa"│"Iris-setosa"│"ready"│["4.7","3.2","1.3","0.2","Iris-setosa"]│ 90 | ├─────────────┼─────────────┼───────┼───────────────────────────────────────┤ 91 | ... 
92 | ---- 93 | 94 | [source,cypher] 95 | ---- 96 | CALL ml.remove('iris') YIELD model, state; 97 | ---- 98 | 99 | == Manual neural network operations (TODO) 100 | 101 | .Procedures 102 | [source,cypher] 103 | ---- 104 | apoc.ml.propagate(network, {inputs}) yield {outputs} 105 | apoc.ml.backprop(network, {output}) yield {network} 106 | ---- 107 | 108 | .Functions 109 | [source,cypher] 110 | ---- 111 | apoc.ml.sigmoid 112 | apoc.ml.sigmoidPrime 113 | ---- 114 | 115 | Future plans include storing networks from common machine learning libraries (TensorFlow, Deeplearning4j, Encog, etc.) as executable network structures in Neo4j. 116 | 117 | == Building it yourself 118 | 119 | This project uses Maven. To build a jar file containing the procedures in this 120 | project, simply package the project with Maven: 121 | 122 | mvn clean package 123 | 124 | This will produce a jar file, `target/neo4j-ml-procedures-*-SNAPSHOT.jar`, that can be copied into the `$NEO4J_HOME/plugins` directory of your Neo4j instance.
125 | 126 | == License 127 | 128 | Apache License V2, see LICENSE 129 | 130 | == Next Steps 131 | 132 | * Normalization / Classification for dl4j 133 | * Push frameworks to separate modules 134 | * Support external frameworks 135 | * Store / Load models from graph 136 | * Expose simple ML functions / propagation/ backprop on graph structures as procedures and functions 137 | * Load PMML 138 | * More fine-grained configuration (JSON) for Networks 139 | * K-Means, Classification, Regression 140 | * Spark ML-Lib 141 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 4 | 4.0.0 5 | 6 | org.neo4j.procedures 7 | neo4j-ml-procedures 8 | 3.2.2.1-SNAPSHOT 9 | 10 | jar 11 | Neo4j ML Procedures 12 | 13 | 14 | UTF-8 15 | 1.8 16 | ${encoding} 17 | ${encoding} 18 | ${java.version} 19 | ${java.version} 20 | 3.2.2 21 | 22 | 23 | 24 | 25 | org.neo4j 26 | neo4j 27 | ${neo4j.version} 28 | provided 29 | 30 | 31 | 32 | org.neo4j 33 | neo4j-kernel 34 | ${neo4j.version} 35 | test 36 | test-jar 37 | 38 | 39 | org.neo4j 40 | neo4j-io 41 | ${neo4j.version} 42 | test 43 | test-jar 44 | 45 | 46 | 47 | net.biville.florent 48 | neo4j-sproc-compiler 49 | 1.2 50 | true 51 | provided 52 | 53 | 54 | 55 | org.apache.commons 56 | commons-math3 57 | 3.4.1 58 | 59 | 60 | org.apache.commons 61 | commons-lang3 62 | 3.3.1 63 | 64 | 65 | 66 | org.encog 67 | encog-core 68 | 3.3.0 69 | 70 | 71 | org.apache.commons 72 | commons-math3 73 | 74 | 75 | 76 | 77 | 78 | org.deeplearning4j 79 | deeplearning4j-core 80 | 0.9.1 81 | 82 | 83 | org.nd4j 84 | nd4j-native-platform 85 | 0.9.1 86 | 87 | 88 | 89 | org.tensorflow 90 | tensorflow 91 | 1.4.0 92 | 93 | 94 | org.tensorflow 95 | proto 96 | 1.4.0 97 | 98 | 99 | 100 | junit 101 | junit 102 | 4.12 103 | test 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | maven-shade-plugin 112 | 2.4.3 113 | 114 | 115 | package 116 | 117 | shade 118 | 119 | 120 
| 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /src/main/java/ml/DL4JMLModel.java: -------------------------------------------------------------------------------- 1 | package ml; 2 | 3 | import org.datavec.api.records.metadata.RecordMetaData; 4 | import org.datavec.api.records.reader.impl.collection.ListStringRecordReader; 5 | import org.datavec.api.split.ListStringSplit; 6 | import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; 7 | import org.deeplearning4j.eval.Evaluation; 8 | import org.deeplearning4j.eval.meta.Prediction; 9 | import org.deeplearning4j.nn.api.Layer; 10 | import org.deeplearning4j.nn.api.OptimizationAlgorithm; 11 | import org.deeplearning4j.nn.conf.MultiLayerConfiguration; 12 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 13 | import org.deeplearning4j.nn.conf.Updater; 14 | import org.deeplearning4j.nn.conf.layers.DenseLayer; 15 | import org.deeplearning4j.nn.conf.layers.OutputLayer; 16 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; 17 | import org.deeplearning4j.nn.weights.WeightInit; 18 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener; 19 | import org.nd4j.linalg.activations.Activation; 20 | import org.nd4j.linalg.api.ndarray.INDArray; 21 | import org.nd4j.linalg.checkutil.NDArrayCreationUtil; 22 | import org.nd4j.linalg.cpu.nativecpu.CpuNDArrayFactory; 23 | import org.nd4j.linalg.dataset.DataSet; 24 | import org.nd4j.linalg.dataset.SplitTestAndTrain; 25 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; 26 | import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; 27 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; 28 | import org.nd4j.linalg.util.NDArrayUtil; 29 | import org.neo4j.graphdb.Label; 30 | import org.neo4j.graphdb.Node; 31 | import org.neo4j.helpers.collection.MapUtil; 32 | import result.VirtualNode; 33 | 34 | import java.util.*; 35 | import 
java.util.function.Function; 36 | import java.util.stream.Collectors; 37 | 38 | /** Deeplearning4j-backed MLModel: rows are lists of column values, trained as a feed-forward classifier (one ReLU hidden layer, softmax output). 39 | * @author mh 40 | * @since 23.07.17 41 | */ 42 | public class DL4JMLModel extends MLModel> { 43 | private MultiLayerNetwork model; 44 | private NormalizerMinMaxScaler normalizer; // fitted on the training data in doTrain(), reused by doPredict() 45 | public DL4JMLModel(String name, Map types, String output, Map config) { 46 | super(name, types, output, config); 47 | } 48 | 49 | /** Places every input value (and the output, when given) at its column offset. */ @Override 50 | protected List asRow(Map inputs, Object output) { 51 | int width = inputs.size() + (output == null ? 0 : 1); List row = new ArrayList<>(Collections.nCopies(width, null)); // pre-filled so set() works in any key order; add(index, ..) on an empty list threw IndexOutOfBoundsException 52 | for (String k : inputs.keySet()) { 53 | row.set(offsets.get(k), inputs.get(k).toString()); 54 | } 55 | if (output != null) { 56 | row.set(offsets.get(this.output), output.toString()); 57 | } 58 | return row; 59 | } 60 | 61 | /** Predicts one row: prediction.getDouble(0) for float outputs, or the index of the highest-scoring output for class outputs. */ @Override 62 | protected Object doPredict(List line) { 63 | try { 64 | ListStringSplit input = new ListStringSplit(Collections.singletonList(line)); 65 | ListStringRecordReader rr = new ListStringRecordReader(); 66 | rr.initialize(input); 67 | DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1); 68 | 69 | DataSet ds = iterator.next(); 70 | INDArray features = ds.getFeatures(); if (normalizer != null) { normalizer.transform(features); } INDArray prediction = model.output(features); // scale inputs (in place) exactly like the training data before predicting 71 | 72 | DataType outputType = types.get(this.output); 73 | switch (outputType) { 74 | case _float : return prediction.getDouble(0); 75 | case _class: { 76 | int numClasses = (int) prediction.length(); // one softmax output per class instead of a hard-coded 2 77 | double max = Double.NEGATIVE_INFINITY; 78 | int maxIndex = -1; 79 | for (int i = 0; i < numClasses; i++) { 80 | double p = prediction.getDouble(i); if (p > max) { maxIndex = i; max = p; } 81 | } 82 | return maxIndex; 83 | // return prediction.getInt(0,1); // numberOfClasses 84 | } 85 | default: throw new IllegalArgumentException("Output type not yet supported "+outputType); 86 | } 87 | } catch (Exception e) { 88 | throw new RuntimeException(e); 89 | } 90 | } 91 | 92 | /** Fits a min/max normalizer and a feed-forward classifier on the collected rows, prints an evaluation on the held-out split, and marks the model ready. */ @Override 93 | protected void doTrain() { 94 | try { 95 | long seed = config.seed.get(); 96 | double learningRate = config.learningRate.get(); 97 | int nEpochs = config.epochs.get(); 98 | 99 | int numOutputs = 1; 100 | int numInputs = types.size() - numOutputs; 101 | int
outputOffset = offsets.get(output); // last column 102 | int numHiddenNodes = config.hidden.get(); 103 | double trainPercent = config.trainPercent.get(); 104 | int batchSize = rows.size(); // full dataset size 105 | 106 | Map> classes = new HashMap<>(); 107 | types.entrySet().stream() 108 | .filter(e -> e.getValue() == DataType._class) 109 | .map(e -> new HashMap.SimpleEntry<>(e.getKey(), offsets.get(e.getKey()))) 110 | .forEach(e -> classes.put(e.getKey(),rows.parallelStream().map(r -> r.get(e.getValue())).distinct().collect(Collectors.toSet()))); 111 | if (!classes.containsKey(output)) throw new IllegalArgumentException("DL4J models currently need a class output column, got " + types.get(output) + " for " + output); 112 | int numberOfClasses = classes.get(output).size(); // keyed by the configured output column, not the literal "output" 113 | System.out.println("labels = " + classes); 114 | 115 | ListStringSplit input = new ListStringSplit(rows); 116 | ListStringRecordReader rr = new ListStringRecordReader(); 117 | rr.initialize(input); 118 | RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, outputOffset, numberOfClasses); 119 | 120 | iterator.setCollectMetaData(true); // Instruct the iterator to collect metadata, and store it in the DataSet objects 121 | DataSet allData = iterator.next(); 122 | allData.shuffle(seed); 123 | SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(trainPercent); // fraction of rows used for training comes from config.trainPercent 124 | 125 | DataSet trainingData = testAndTrain.getTrain(); 126 | DataSet testData = testAndTrain.getTest(); 127 | 128 | //Normalize data as per basic CSV example 129 | // NormalizerStandardize normalizer = new NormalizerStandardize(); 130 | NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(); 131 | normalizer.fitLabel(true); 132 | normalizer.fit(trainingData); //Collect the statistics (min/max) from the training data. This does not modify the input data 133 | normalizer.transform(trainingData); //Apply normalization to the training data 134 | normalizer.transform(testData); //Apply normalization to the test data. 
This is using statistics calculated from the *training* set 135 | 136 | //Let's view the example metadata in the training and test sets: 137 | List trainMetaData = trainingData.getExampleMetaData(RecordMetaData.class); 138 | List testMetaData = testData.getExampleMetaData(RecordMetaData.class); 139 | 140 | //Let's show specifically which examples are in the training and test sets, using the collected metadata 141 | // System.out.println(" +++++ Training Set Examples MetaData +++++"); 142 | // String format = "%-20s\t%s"; 143 | // for(RecordMetaData recordMetaData : trainMetaData){ 144 | // System.out.println(String.format(format, recordMetaData.getLocation(), recordMetaData.getURI())); 145 | // //Also available: recordMetaData.getReaderClass() 146 | // } 147 | // System.out.println("\n\n +++++ Test Set Examples MetaData +++++"); 148 | // for(RecordMetaData recordMetaData : testMetaData){ 149 | // System.out.println(recordMetaData.getLocation()); 150 | // } 151 | 152 | 153 | 154 | MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 155 | .seed(seed) 156 | .iterations(1) 157 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 158 | .learningRate(learningRate) 159 | .updater(Updater.NESTEROVS).momentum(0.9) 160 | .list() 161 | .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes) 162 | .weightInit(WeightInit.XAVIER) 163 | .activation(Activation.RELU) 164 | .build()) 165 | .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) 166 | .weightInit(WeightInit.XAVIER) 167 | .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER) 168 | .nIn(numHiddenNodes).nOut(numberOfClasses).build()) 169 | .pretrain(false).backprop(true).build(); 170 | 171 | 172 | MultiLayerNetwork model = new MultiLayerNetwork(conf); 173 | model.init(); 174 | model.setListeners(new ScoreIterationListener(10)); //Print score every 10 parameter updates 175 | 176 | for (int n = 0; n < nEpochs; n++) { 177 | model.fit(trainingData); 178 
| } 179 | 180 | System.out.println("Evaluate model...."); 181 | INDArray output = model.output(testData.getFeatureMatrix(),false); 182 | Evaluation eval = new Evaluation(numberOfClasses); 183 | eval.eval(testData.getLabels(), output, testMetaData); //Note we are passing in the test set metadata here 184 | 185 | List predictionErrors = eval.getPredictionErrors(); 186 | System.out.println("\n\n+++++ Prediction Errors +++++"); 187 | for(Prediction p : predictionErrors){ 188 | System.out.printf("Predicted class: %d, Actual class: %d\t%s%n", p.getPredictedClass(), p.getActualClass(), p.getRecordMetaData(RecordMetaData.class)); 189 | } 190 | //Print the evaluation statistics 191 | System.out.println(eval.stats()); 192 | 193 | this.model = model; this.normalizer = normalizer; // doPredict must apply the same scaling 194 | this.state = State.ready; 195 | 196 | } catch (Exception e) { 197 | throw new RuntimeException(e); 198 | } 199 | } 200 | 201 | /** Describes each network layer as a virtual node; DenseLayer settings are read from the layer's configuration. */ @Override 202 | List show() { 203 | if ( state != State.ready ) throw new IllegalStateException("Model not trained yet"); 204 | List result = new ArrayList<>(); 205 | 206 | for (Layer layer : model.getLayers()) { 207 | Node node = node("Layer", 208 | "type", layer.type().name(), "index", layer.getIndex(), 209 | "pretrainLayer", layer.isPretrainLayer(), "miniBatchSize", layer.getInputMiniBatchSize(), 210 | "numParams", layer.numParams()); 211 | Object layerConf = layer.conf().getLayer(); // runtime layers are nn.api.Layer instances, never the conf DenseLayer; its settings live in the layer configuration 212 | if (layerConf instanceof DenseLayer) { DenseLayer dl = (DenseLayer) layerConf; 213 | node.addLabel(Label.label("DenseLayer")); 214 | node.setProperty("activation",dl.getActivationFn().toString()); // todo parameters 215 | node.setProperty("biasInit",dl.getBiasInit()); 216 | node.setProperty("biasLearningRate",dl.getBiasLearningRate()); 217 | node.setProperty("l1",dl.getL1()); 218 | node.setProperty("l1Bias",dl.getL1Bias()); 219 | node.setProperty("l2",dl.getL2()); 220 | node.setProperty("l2Bias",dl.getL2Bias()); 221 | if (dl.getDist() != null) node.setProperty("distribution",dl.getDist().toString()); // null unless a weight distribution was configured (this network uses XAVIER) 222 | node.setProperty("in",dl.getNIn()); 223 |
node.setProperty("out",dl.getNOut()); 224 | } 225 | result.add(node); 226 | // layer.preOutput(allOne, Layer.TrainingMode.TEST); 227 | // layer.p(allOne, Layer.TrainingMode.TEST); 228 | // layer.activate(allOne, Layer.TrainingMode.TEST); 229 | } 230 | return result; 231 | } 232 | /** Builds a VirtualNode with one label and the given alternating key/value properties. */ 233 | private Node node(String label, Object...keyValues) { 234 | return new VirtualNode(new Label[] {Label.label(label)}, MapUtil.map(keyValues),null); 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /src/main/java/ml/EncogMLModel.java: -------------------------------------------------------------------------------- 1 | package ml; 2 | 3 | import org.encog.ml.MLRegression; 4 | import org.encog.ml.data.MLData; 5 | import org.encog.ml.data.versatile.NormalizationHelper; 6 | import org.encog.ml.data.versatile.VersatileMLDataSet; 7 | import org.encog.ml.data.versatile.columns.ColumnDefinition; 8 | import org.encog.ml.data.versatile.columns.ColumnType; 9 | import org.encog.ml.data.versatile.sources.VersatileDataSource; 10 | import org.encog.ml.factory.MLMethodFactory; 11 | import org.encog.ml.model.EncogModel; 12 | import org.encog.util.simple.EncogUtility; 13 | import org.neo4j.graphdb.Node; 14 | 15 | import java.util.*; 16 | 17 | /** 18 | * @author mh 19 | * @since 19.07.17 20 | */ 21 | public class EncogMLModel extends MLModel { 22 | 23 | private EncogModel model; // todo MLMethod and decide later between regression, classification and others 24 | private MLRegression method; 25 | 26 | 27 | @Override 28 | protected String[] asRow(Map inputs, Object output) { 29 | String[] row = new String[inputs.size() + (output == null ?
0 : 1)]; 30 | for (String k : inputs.keySet()) { 31 | row[offsets.get(k)] = inputs.get(k).toString(); 32 | } 33 | if (output != null) { 34 | row[offsets.get(this.output)] = output.toString(); 35 | } 36 | return row; 37 | } 38 | 39 | 40 | @Override 41 | protected Object doPredict(String[] line) { 42 | NormalizationHelper helper = model.getDataset().getNormHelper(); 43 | MLData input = helper.allocateInputVector(); 44 | helper.normalizeInputVector(line, input.getData(), false); 45 | MLData output = method.compute(input); 46 | DataType outputType = types.get(this.output); 47 | switch (outputType) { 48 | case _float : return output.getData(0); 49 | case _class: return helper.denormalizeOutputVectorToString(output)[0]; 50 | default: throw new IllegalArgumentException("Output type not yet supported "+outputType); 51 | } 52 | } 53 | 54 | @Override 55 | protected void doTrain() { 56 | VersatileMLDataSet data = new VersatileMLDataSet(new VersatileDataSource() { 57 | int idx = 0; 58 | 59 | @Override 60 | public String[] readLine() { 61 | return idx >= rows.size() ? 
null : rows.get(idx++); 62 | } 63 | 64 | @Override 65 | public void rewind() { 66 | idx = 0; 67 | } 68 | 69 | @Override 70 | public int columnIndex(String s) { 71 | return offsets.get(s); 72 | } 73 | }); 74 | offsets.entrySet().stream().sorted(Comparator.comparingInt(Map.Entry::getValue)).forEach(e -> { 75 | String k = e.getKey(); 76 | ColumnDefinition col = data.defineSourceColumn(k, offsets.get(k), typeOf(types.get(k))); // todo has bug, doesn't work like that, cols have to be in index order 77 | if (k.equals(output)) { 78 | data.defineOutput(col); 79 | } else { 80 | data.defineInput(col); 81 | } 82 | }); 83 | // types.forEach((k, v) -> { 84 | // ColumnDefinition col = data.defineSourceColumn(k, offsets.get(k), v); // todo has bug, doesn't work like that, cols have to be in index order 85 | // if (k.equals(output)) { 86 | // data.defineOutput(col); 87 | // } else { 88 | // data.defineInput(col); 89 | // } 90 | // }); 91 | // Analyze the data, determine the min/max/mean/sd of every column. 92 | data.analyze(); 93 | 94 | // Create feedforward neural network as the model type. MLMethodFactory.TYPE_FEEDFORWARD. 95 | // You could also other model types, such as: 96 | // MLMethodFactory.SVM: Support Vector Machine (SVM) 97 | // MLMethodFactory.TYPE_RBFNETWORK: RBF Neural Network 98 | // MLMethodFactor.TYPE_NEAT: NEAT Neural Network 99 | // MLMethodFactor.TYPE_PNN: Probabilistic Neural Network 100 | EncogModel model = new EncogModel(data); 101 | model.selectMethod(data, methodFor(methodName)); // todo from config 102 | // Send any output to the console. 103 | // model.setReport(new ConsoleStatusReportable()); 104 | 105 | // Now normalize the data. Encog will automatically determine the correct normalization 106 | // type based on the model you chose in the last step. 107 | data.normalize(); 108 | 109 | // Hold back some data for a final validation. 110 | // Shuffle the data into a random ordering. 
111 | // Use a seed of 1001 so that we always use the same holdback and will get more consistent results. 112 | model.holdBackValidation(0.3, true, 1001); // todo from config 113 | 114 | // Choose whatever is the default training type for this model. 115 | model.selectTrainingType(data); 116 | 117 | // Use a 5-fold cross-validated train. Return the best method found. 118 | MLRegression bestMethod = (MLRegression) model.crossvalidate(5, true); // todo from config 119 | // MLRegression vs. MLClassification 120 | 121 | // Display the training and validation errors. 122 | // System.out.println("Training error: " + EncogUtility.calculateRegressionError(bestMethod, model.getTrainingDataset())); 123 | // System.out.println("Validation error: " + EncogUtility.calculateRegressionError(bestMethod, model.getValidationDataset())); 124 | 125 | // Display our normalization parameters. 126 | // NormalizationHelper helper = data.getNormHelper(); 127 | // System.out.println(helper.toString()); 128 | 129 | // Display the final model. 
130 | // System.out.println("Final model: " + bestMethod); 131 | this.model = model; 132 | this.method = bestMethod; 133 | this.state = State.ready; 134 | } 135 | 136 | private String methodFor(Method method) { 137 | switch (method) { 138 | case ffd: return MLMethodFactory.TYPE_FEEDFORWARD; 139 | case svm: return MLMethodFactory.TYPE_SVM; 140 | case rbf: return MLMethodFactory.TYPE_RBFNETWORK; 141 | case neat: return MLMethodFactory.TYPE_NEAT; 142 | case pnn: return MLMethodFactory.TYPE_PNN; 143 | default: throw new IllegalArgumentException("Unknown method "+method); 144 | } 145 | } 146 | 147 | /* 148 | nominal, 149 | ordinal, 150 | continuous, 151 | ignore 152 | */ 153 | private ColumnType typeOf(DataType type) { 154 | switch (type) { 155 | case _class: 156 | return ColumnType.nominal; 157 | case _float: 158 | return ColumnType.continuous; 159 | case _order: 160 | return ColumnType.ordinal; 161 | default: 162 | throw new IllegalArgumentException("Unknown type: " + type); 163 | } 164 | } 165 | 166 | public EncogMLModel(String name, Map types, String output, Map config) { 167 | super(name, types, output, config); 168 | } 169 | 170 | 171 | @Override 172 | protected ML.ModelResult resultWithInfo(ML.ModelResult result) { 173 | return result.withInfo( 174 | "trainingError", EncogUtility.calculateRegressionError(method, model.getTrainingDataset()), 175 | "validationError",EncogUtility.calculateRegressionError(method, model.getValidationDataset()), 176 | "selectedMethod",method.toString(), 177 | "normalization",model.getDataset().getNormHelper().toString() 178 | ); 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /src/main/java/ml/LoadTensorFlow.java: -------------------------------------------------------------------------------- 1 | package ml; 2 | 3 | import org.neo4j.graphdb.GraphDatabaseService; 4 | import org.neo4j.graphdb.Label; 5 | import org.neo4j.graphdb.Node; 6 | import org.neo4j.graphdb.RelationshipType; 7 
| import org.neo4j.procedure.Context; 8 | import org.neo4j.procedure.Mode; 9 | import org.neo4j.procedure.Name; 10 | import org.neo4j.procedure.Procedure; 11 | import org.tensorflow.framework.AttrValue; 12 | import org.tensorflow.framework.GraphDef; 13 | import org.tensorflow.framework.NodeDef; 14 | 15 | import java.io.BufferedInputStream; 16 | import java.io.IOException; 17 | import java.net.URL; 18 | import java.util.HashMap; 19 | import java.util.Map; 20 | import java.util.stream.Stream; 21 | 22 | /** 23 | * @author mh 24 | * @since 26.07.17 25 | * see: https://www.tensorflow.org/extend/tool_developers/ 26 | * see: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto 27 | * https://developers.google.com/protocol-buffers/docs/javatutorial 28 | * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/testdata/half_plus_two/00000123/saved_model.pb 29 | * 30 | */ 31 | public class LoadTensorFlow { 32 | 33 | @Context 34 | public GraphDatabaseService db; 35 | 36 | enum Types implements Label { 37 | Neuron 38 | } 39 | 40 | enum RelTypes implements RelationshipType { 41 | INPUT 42 | } 43 | 44 | public static class LoadResult { 45 | public String modelName; 46 | public String type; 47 | public long nodes; 48 | public long relationships; 49 | 50 | public LoadResult(String modelName, String type, long nodes, long relationships) { 51 | this.modelName = modelName; 52 | this.type = type; 53 | this.nodes = nodes; 54 | this.relationships = relationships; 55 | } 56 | } 57 | @Procedure(value = "load.tensorflow", mode = Mode.WRITE) 58 | public Stream loadTensorFlow(@Name("file") String url) throws IOException { 59 | GraphDef graphDef = GraphDef.parseFrom(new BufferedInputStream(new URL(url).openStream())); 60 | Map nodes = new HashMap<>(); 61 | // tod model node, layer nodes 62 | for (NodeDef nodeDef : graphDef.getNodeList()) { 63 | Node node = db.createNode(Types.Neuron); 64 | node.setProperty("name", 
nodeDef.getName()); 65 | if (nodeDef.getDevice() != null) node.setProperty("device", nodeDef.getDevice()); 66 | node.setProperty("op", nodeDef.getOp()); 67 | nodeDef.getAttrMap().forEach((k, v) -> { 68 | Object value = getValue(v); 69 | if (value != null) { 70 | node.setProperty(k, value); 71 | } 72 | }); 73 | nodes.put(nodeDef.getName(), node); 74 | } 75 | long rels = 0; 76 | for (NodeDef nodeDef : graphDef.getNodeList()) { 77 | Node target = nodes.get(nodeDef.getName()); 78 | nodeDef.getInputList().forEach(name -> nodes.get(name).createRelationshipTo(target, RelTypes.INPUT)); 79 | // todo weights 80 | rels += nodeDef.getInputCount(); 81 | } 82 | return Stream.of(new LoadResult(url,"tensorflow",nodes.size(), rels)); 83 | } 84 | 85 | private Object getValue(AttrValue v) { 86 | switch (v.getValueCase()) { 87 | case S: 88 | return v.getS().toStringUtf8(); 89 | case I: 90 | return v.getI(); 91 | case F: 92 | return v.getF(); 93 | case B: 94 | return v.getB(); 95 | case TYPE: 96 | return v.getType().name(); // todo 97 | case SHAPE: 98 | return v.getShape().toString(); // tdo 99 | case TENSOR: 100 | return v.getTensor().toString(); // todo handle with prefxied properties 101 | case LIST: 102 | return v.getList().toString(); // todo getType/Count(idx) and then handle each type with prefixed property 103 | case FUNC: 104 | return v.getFunc().getAttrMap().toString(); // todo handle recursively 105 | case PLACEHOLDER: 106 | break; 107 | case VALUE_NOT_SET: 108 | return null; 109 | default: 110 | return null; 111 | } 112 | return null; 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /src/main/java/ml/ML.java: -------------------------------------------------------------------------------- 1 | package ml; 2 | 3 | import org.neo4j.graphdb.GraphDatabaseService; 4 | import org.neo4j.graphdb.Node; 5 | import org.neo4j.logging.Log; 6 | import org.neo4j.procedure.Context; 7 | import org.neo4j.procedure.Name; 8 | import 
org.neo4j.procedure.Procedure; 9 | 10 | import java.util.*; 11 | import java.util.stream.Stream; 12 | 13 | public class ML { 14 | @Context 15 | public GraphDatabaseService db; 16 | 17 | @Context 18 | public Log log; 19 | 20 | /* 21 | 22 | apoc.ml.create.classifier({types},['output'],{config}) yield model 23 | apoc.ml.create.regression({params},{config}) yield model 24 | 25 | apoc.ml.train(model, {params}, prediction) 26 | 27 | apoc.ml.classify(model, {params}) yield prediction:string, confidence:float 28 | apoc.ml.regression(model, {params}) yield prediction:float, confidence:float 29 | 30 | apoc.ml.delete(model) yield model 31 | 32 | */ 33 | 34 | @Procedure 35 | public Stream create(@Name("model") String model, @Name("types") Map types, @Name(value="output") String output, @Name(value="params",defaultValue="{}") Map config) { 36 | return Stream.of(MLModel.create(model,types,output,config).asResult()); 37 | } 38 | 39 | @Procedure 40 | public Stream info(@Name("model") String model) { 41 | return Stream.of(MLModel.from(model).asResult()); 42 | } 43 | 44 | @Procedure 45 | public Stream remove(@Name("model") String model) { 46 | return Stream.of(MLModel.remove(model)); 47 | } 48 | 49 | @Procedure 50 | public Stream add(@Name("model") String model, @Name("inputs") Map inputs, @Name("outputs") Object output) { 51 | MLModel mlModel = MLModel.from(model); 52 | mlModel.add(inputs,output); 53 | return Stream.of(mlModel.asResult()); 54 | } 55 | 56 | @Procedure 57 | public Stream train(@Name("model") String model) { 58 | MLModel mlModel = MLModel.from(model); 59 | mlModel.train(); 60 | return Stream.of(mlModel.asResult()); 61 | } 62 | 63 | public static class NodeResult { 64 | public final Node node; 65 | 66 | public NodeResult(Node node) { 67 | this.node = node; 68 | } 69 | } 70 | @Procedure 71 | public Stream show(@Name("model") String model) { 72 | List show = MLModel.from(model).show(); 73 | return show.stream().map(NodeResult::new); 74 | } 75 | 76 | @Procedure 77 | 
public Stream predict(@Name("model") String model, @Name("inputs") Map inputs) { 78 | MLModel mlModel = MLModel.from(model); 79 | Object value = mlModel.predict(inputs); 80 | double confidence = 0.0d; 81 | return Stream.of(new PredictionResult(value, confidence)); 82 | } 83 | 84 | public static class PredictionResult { 85 | public Object value; 86 | public double confidence; 87 | 88 | public PredictionResult(Object value, double confidence) { 89 | this.value = value; 90 | this.confidence = confidence; 91 | } 92 | } 93 | 94 | public static class ModelResult { 95 | public final String model; 96 | public final String state; 97 | public final Map info = new HashMap<>(); 98 | 99 | public ModelResult(String model, EncogMLModel.State state) { 100 | this.model = model; 101 | this.state = state.name(); 102 | } 103 | 104 | ModelResult withInfo(Object...infos) { 105 | for (int i = 0; i < infos.length; i+=2) { 106 | info.put(infos[i].toString(),infos[i+1]); 107 | } 108 | return this; 109 | } 110 | } 111 | 112 | } 113 | -------------------------------------------------------------------------------- /src/main/java/ml/MLModel.java: -------------------------------------------------------------------------------- 1 | package ml; 2 | 3 | import org.encog.ml.factory.MLMethodFactory; 4 | import org.neo4j.graphdb.Node; 5 | 6 | import java.util.*; 7 | import java.util.concurrent.ConcurrentHashMap; 8 | 9 | /** 10 | * @author mh 11 | * @since 23.07.17 12 | */ 13 | public abstract class MLModel { 14 | 15 | static class Config { 16 | private final Map config; 17 | 18 | public Config(Map config) { 19 | this.config = config; 20 | } 21 | 22 | class V { 23 | private final String name; 24 | private final T defaultValue; 25 | 26 | V(String name, T defaultValue) { 27 | this.name = name; 28 | this.defaultValue = defaultValue; 29 | } 30 | 31 | T get(T defaultValue) { 32 | Object value = config.get(name); 33 | if (value == null) return defaultValue; 34 | if (defaultValue instanceof Double) return 
(T) (Object) ((Number) value).doubleValue(); 35 | if (defaultValue instanceof Integer) return (T) (Object) ((Number) value).intValue(); 36 | if (defaultValue instanceof Long) return (T) (Object) ((Number) value).longValue(); 37 | if (defaultValue instanceof String) return (T) value.toString(); 38 | return (T) value; 39 | } 40 | 41 | T get() { 42 | return get(defaultValue); 43 | } 44 | } 45 | 46 | public final V seed = new V<>("seed", 123L); 47 | public final V learningRate = new V<>("learningRate", 0.01d); 48 | public final V epochs = new V<>("epochs", 50); 49 | public final V hidden = new V<>("hidden", 20); 50 | public final V trainPercent = new V<>("trainPercent", 0.75d); 51 | } 52 | 53 | static ConcurrentHashMap models = new ConcurrentHashMap<>(); 54 | final String name; 55 | final Map types = new HashMap<>(); 56 | final Map offsets = new HashMap<>(); 57 | final String output; 58 | final Config config; 59 | final List rows = new ArrayList<>(); 60 | State state; 61 | Method methodName; 62 | 63 | public MLModel(String name, Map types, String output, Map config) { 64 | if (models.containsKey(name)) 65 | throw new IllegalArgumentException("Model " + name + " already exists, please remove first"); 66 | 67 | this.name = name; 68 | this.state = State.created; 69 | this.output = output; 70 | this.config = new Config(config); 71 | initTypes(types, output); 72 | 73 | this.methodName = Method.ffd; 74 | 75 | models.put(name, this); 76 | 77 | } 78 | 79 | protected void initTypes(Map types, String output) { 80 | if (!types.containsKey(output)) throw new IllegalArgumentException("Outputs not defined: " + output); 81 | int i = 0; 82 | for (Map.Entry entry : types.entrySet()) { 83 | String key = entry.getKey(); 84 | this.types.put(key, DataType.from(entry.getValue())); 85 | if (!key.equals(output)) this.offsets.put(key, i++); 86 | } 87 | this.offsets.put(output, i); 88 | } 89 | 90 | public static ML.ModelResult remove(String model) { 91 | MLModel existing = models.remove(model); 
92 | return new ML.ModelResult(model, existing == null ? State.unknown : State.removed); 93 | } 94 | 95 | public static MLModel from(String name) { 96 | MLModel model = models.get(name); 97 | if (model != null) return model; 98 | throw new IllegalArgumentException("No valid ML-Model " + name); 99 | } 100 | 101 | public void add(Map inputs, Object output) { 102 | if (this.state == State.created || this.state == State.training) { 103 | rows.add(asRow(inputs, output)); 104 | this.state = State.training; 105 | } else { 106 | throw new IllegalArgumentException(String.format("Model %s not able to accept training data, state is: %s", name, state)); 107 | } 108 | } 109 | 110 | protected abstract ROW asRow(Map inputs, Object output); 111 | 112 | public void train() { 113 | if (state != State.ready) { 114 | if (state != State.training) { 115 | throw new IllegalArgumentException(String.format("Model %s is not ready to predict, it has no training data, state is %s", name, state)); 116 | } 117 | doTrain(); 118 | } 119 | } 120 | 121 | public Object predict(Map inputs) { 122 | if (state != State.ready) { 123 | train(); 124 | } 125 | if (state == State.ready) { 126 | ROW line = asRow(inputs, null); 127 | 128 | Object predicted = doPredict(line); 129 | // todo confidence 130 | return predicted; 131 | } else { 132 | throw new IllegalArgumentException(String.format("Model %s is not ready to predict, state is %s", name, state)); 133 | } 134 | } 135 | 136 | protected abstract Object doPredict(ROW line); 137 | 138 | protected abstract void doTrain(); 139 | 140 | public ML.ModelResult asResult() { 141 | ML.ModelResult result = 142 | new ML.ModelResult(this.name, this.state) 143 | .withInfo("methodName", methodName); 144 | 145 | if (rows.size() > 0) { 146 | result = result.withInfo("trainingSets", (long) rows.size()); 147 | } 148 | if (state == State.ready) { 149 | // todo check how expensive this is 150 | result = resultWithInfo(result); 151 | } 152 | return result; 153 | } 154 | 155 | 
protected ML.ModelResult resultWithInfo(ML.ModelResult result) { 156 | return result; 157 | } 158 | 159 | ; 160 | 161 | public static MLModel create(String name, Map types, String output, Map config) { 162 | String framework = config.getOrDefault("framework", "encog").toString().toLowerCase(); 163 | switch (framework) { 164 | case "encog": 165 | return new EncogMLModel(name, types, output, config); 166 | case "dl4j": 167 | return new DL4JMLModel(name, types, output, config); 168 | default: 169 | throw new IllegalArgumentException("Unknown framework: " + framework); 170 | } 171 | } 172 | 173 | enum Method { 174 | ffd, svm, rbf, neat, pnn; 175 | } 176 | 177 | enum DataType { 178 | _class, _float, _order; 179 | 180 | public static DataType from(String type) { 181 | switch (type.toUpperCase()) { 182 | case "CLASS": 183 | return DataType._class; 184 | case "FLOAT": 185 | return DataType._float; 186 | case "ORDER": 187 | return DataType._order; 188 | default: 189 | throw new IllegalArgumentException("Unknown type: " + type); 190 | } 191 | } 192 | } 193 | 194 | public enum State {created, training, ready, removed, unknown} 195 | 196 | List show() { 197 | return Collections.emptyList(); 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /src/main/java/result/VirtualNode.java: -------------------------------------------------------------------------------- 1 | package result; 2 | 3 | import org.neo4j.graphdb.*; 4 | import org.neo4j.helpers.collection.FilteringIterable; 5 | import org.neo4j.helpers.collection.Iterables; 6 | 7 | import java.util.*; 8 | import java.util.concurrent.atomic.AtomicLong; 9 | import java.util.stream.Collectors; 10 | 11 | import static java.util.Arrays.asList; 12 | 13 | /** 14 | * @author mh 15 | * @since 10.08.17 16 | */ 17 | public class VirtualNode implements Node { 18 | private static AtomicLong MIN_ID = new AtomicLong(-1); 19 | private final List