├── LICENSE.txt ├── NOTICE.txt ├── README.md ├── pom.xml └── src ├── main ├── java │ └── org │ │ └── jpmml │ │ └── tensorflow │ │ ├── DNNClassifier.java │ │ ├── DNNEstimator.java │ │ ├── DNNRegressor.java │ │ ├── Estimator.java │ │ ├── EstimatorFactory.java │ │ ├── LinearClassifier.java │ │ ├── LinearEstimator.java │ │ ├── LinearRegressor.java │ │ ├── Main.java │ │ ├── SavedModel.java │ │ ├── ShapeUtil.java │ │ ├── TensorFlowEncoder.java │ │ ├── TensorUtil.java │ │ ├── Trail.java │ │ └── TypeUtil.java └── proto │ └── tensorflow │ └── core │ ├── framework │ ├── attr_value.proto │ ├── function.proto │ ├── graph.proto │ ├── node_def.proto │ ├── op_def.proto │ ├── resource_handle.proto │ ├── tensor.proto │ ├── tensor_shape.proto │ ├── types.proto │ └── versions.proto │ └── protobuf │ ├── meta_graph.proto │ └── saver.proto └── test ├── java └── org │ └── jpmml │ └── tensorflow │ ├── DNNClassifierTest.java │ ├── DNNRegressorTest.java │ ├── EstimatorTest.java │ ├── LinearClassifierTest.java │ └── LinearRegressorTest.java └── resources ├── csv ├── Audit.csv ├── Auto.csv ├── DNNClassificationAudit.csv ├── DNNClassificationIris.csv ├── DNNRegressionAuto.csv ├── Iris.csv ├── LinearClassificationAudit.csv ├── LinearClassificationIris.csv └── LinearRegressionAuto.csv ├── main.py └── savedmodel ├── DNNClassificationAudit ├── saved_model.pbtxt └── variables │ ├── variables.data-00000-of-00001 │ └── variables.index ├── DNNClassificationIris ├── saved_model.pbtxt └── variables │ ├── variables.data-00000-of-00001 │ └── variables.index ├── DNNRegressionAuto ├── saved_model.pbtxt └── variables │ ├── variables.data-00000-of-00001 │ └── variables.index ├── LinearClassificationAudit ├── saved_model.pbtxt └── variables │ ├── variables.data-00000-of-00001 │ └── variables.index ├── LinearClassificationIris ├── saved_model.pbtxt └── variables │ ├── variables.data-00000-of-00001 │ └── variables.index └── LinearRegressionAuto ├── saved_model.pbtxt └── variables ├── 
variables.data-00000-of-00001 └── variables.index /NOTICE.txt: -------------------------------------------------------------------------------- 1 | JPMML-TensorFlow includes third-party dependencies that are released under the Apache License, Version 2.0: 2 | * Guava - https://github.com/google/guava 3 | * JCommander - http://jcommander.org 4 | * Protocol Buffers - https://github.com/google/protobuf 5 | * TensorFlow - https://github.com/tensorflow/tensorflow 6 | 7 | Apache License 8 | Version 2.0, January 2004 9 | http://www.apache.org/licenses/ 10 | 11 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 12 | 13 | 1. Definitions. 14 | 15 | "License" shall mean the terms and conditions for use, reproduction, 16 | and distribution as defined by Sections 1 through 9 of this document. 17 | 18 | "Licensor" shall mean the copyright owner or entity authorized by 19 | the copyright owner that is granting the License. 20 | 21 | "Legal Entity" shall mean the union of the acting entity and all 22 | other entities that control, are controlled by, or are under common 23 | control with that entity. For the purposes of this definition, 24 | "control" means (i) the power, direct or indirect, to cause the 25 | direction or management of such entity, whether by contract or 26 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 27 | outstanding shares, or (iii) beneficial ownership of such entity. 28 | 29 | "You" (or "Your") shall mean an individual or Legal Entity 30 | exercising permissions granted by this License. 31 | 32 | "Source" form shall mean the preferred form for making modifications, 33 | including but not limited to software source code, documentation 34 | source, and configuration files. 35 | 36 | "Object" form shall mean any form resulting from mechanical 37 | transformation or translation of a Source form, including but 38 | not limited to compiled object code, generated documentation, 39 | and conversions to other media types. 
40 | 41 | "Work" shall mean the work of authorship, whether in Source or 42 | Object form, made available under the License, as indicated by a 43 | copyright notice that is included in or attached to the work 44 | (an example is provided in the Appendix below). 45 | 46 | "Derivative Works" shall mean any work, whether in Source or Object 47 | form, that is based on (or derived from) the Work and for which the 48 | editorial revisions, annotations, elaborations, or other modifications 49 | represent, as a whole, an original work of authorship. For the purposes 50 | of this License, Derivative Works shall not include works that remain 51 | separable from, or merely link (or bind by name) to the interfaces of, 52 | the Work and Derivative Works thereof. 53 | 54 | "Contribution" shall mean any work of authorship, including 55 | the original version of the Work and any modifications or additions 56 | to that Work or Derivative Works thereof, that is intentionally 57 | submitted to Licensor for inclusion in the Work by the copyright owner 58 | or by an individual or Legal Entity authorized to submit on behalf of 59 | the copyright owner. For the purposes of this definition, "submitted" 60 | means any form of electronic, verbal, or written communication sent 61 | to the Licensor or its representatives, including but not limited to 62 | communication on electronic mailing lists, source code control systems, 63 | and issue tracking systems that are managed by, or on behalf of, the 64 | Licensor for the purpose of discussing and improving the Work, but 65 | excluding communication that is conspicuously marked or otherwise 66 | designated in writing by the copyright owner as "Not a Contribution." 67 | 68 | "Contributor" shall mean Licensor and any individual or Legal Entity 69 | on behalf of whom a Contribution has been received by Licensor and 70 | subsequently incorporated within the Work. 71 | 72 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 73 | this License, each Contributor hereby grants to You a perpetual, 74 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 75 | copyright license to reproduce, prepare Derivative Works of, 76 | publicly display, publicly perform, sublicense, and distribute the 77 | Work and such Derivative Works in Source or Object form. 78 | 79 | 3. Grant of Patent License. Subject to the terms and conditions of 80 | this License, each Contributor hereby grants to You a perpetual, 81 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 82 | (except as stated in this section) patent license to make, have made, 83 | use, offer to sell, sell, import, and otherwise transfer the Work, 84 | where such license applies only to those patent claims licensable 85 | by such Contributor that are necessarily infringed by their 86 | Contribution(s) alone or by combination of their Contribution(s) 87 | with the Work to which such Contribution(s) was submitted. If You 88 | institute patent litigation against any entity (including a 89 | cross-claim or counterclaim in a lawsuit) alleging that the Work 90 | or a Contribution incorporated within the Work constitutes direct 91 | or contributory patent infringement, then any patent licenses 92 | granted to You under this License for that Work shall terminate 93 | as of the date such litigation is filed. 94 | 95 | 4. Redistribution. 
You may reproduce and distribute copies of the 96 | Work or Derivative Works thereof in any medium, with or without 97 | modifications, and in Source or Object form, provided that You 98 | meet the following conditions: 99 | 100 | (a) You must give any other recipients of the Work or 101 | Derivative Works a copy of this License; and 102 | 103 | (b) You must cause any modified files to carry prominent notices 104 | stating that You changed the files; and 105 | 106 | (c) You must retain, in the Source form of any Derivative Works 107 | that You distribute, all copyright, patent, trademark, and 108 | attribution notices from the Source form of the Work, 109 | excluding those notices that do not pertain to any part of 110 | the Derivative Works; and 111 | 112 | (d) If the Work includes a "NOTICE" text file as part of its 113 | distribution, then any Derivative Works that You distribute must 114 | include a readable copy of the attribution notices contained 115 | within such NOTICE file, excluding those notices that do not 116 | pertain to any part of the Derivative Works, in at least one 117 | of the following places: within a NOTICE text file distributed 118 | as part of the Derivative Works; within the Source form or 119 | documentation, if provided along with the Derivative Works; or, 120 | within a display generated by the Derivative Works, if and 121 | wherever such third-party notices normally appear. The contents 122 | of the NOTICE file are for informational purposes only and 123 | do not modify the License. You may add Your own attribution 124 | notices within Derivative Works that You distribute, alongside 125 | or as an addendum to the NOTICE text from the Work, provided 126 | that such additional attribution notices cannot be construed 127 | as modifying the License. 
128 | 129 | You may add Your own copyright statement to Your modifications and 130 | may provide additional or different license terms and conditions 131 | for use, reproduction, or distribution of Your modifications, or 132 | for any such Derivative Works as a whole, provided Your use, 133 | reproduction, and distribution of the Work otherwise complies with 134 | the conditions stated in this License. 135 | 136 | 5. Submission of Contributions. Unless You explicitly state otherwise, 137 | any Contribution intentionally submitted for inclusion in the Work 138 | by You to the Licensor shall be under the terms and conditions of 139 | this License, without any additional terms or conditions. 140 | Notwithstanding the above, nothing herein shall supersede or modify 141 | the terms of any separate license agreement you may have executed 142 | with Licensor regarding such Contributions. 143 | 144 | 6. Trademarks. This License does not grant permission to use the trade 145 | names, trademarks, service marks, or product names of the Licensor, 146 | except as required for reasonable and customary use in describing the 147 | origin of the Work and reproducing the content of the NOTICE file. 148 | 149 | 7. Disclaimer of Warranty. Unless required by applicable law or 150 | agreed to in writing, Licensor provides the Work (and each 151 | Contributor provides its Contributions) on an "AS IS" BASIS, 152 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 153 | implied, including, without limitation, any warranties or conditions 154 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 155 | PARTICULAR PURPOSE. You are solely responsible for determining the 156 | appropriateness of using or redistributing the Work and assume any 157 | risks associated with Your exercise of permissions under this License. 158 | 159 | 8. Limitation of Liability. 
In no event and under no legal theory, 160 | whether in tort (including negligence), contract, or otherwise, 161 | unless required by applicable law (such as deliberate and grossly 162 | negligent acts) or agreed to in writing, shall any Contributor be 163 | liable to You for damages, including any direct, indirect, special, 164 | incidental, or consequential damages of any character arising as a 165 | result of this License or out of the use or inability to use the 166 | Work (including but not limited to damages for loss of goodwill, 167 | work stoppage, computer failure or malfunction, or any and all 168 | other commercial damages or losses), even if such Contributor 169 | has been advised of the possibility of such damages. 170 | 171 | 9. Accepting Warranty or Additional Liability. While redistributing 172 | the Work or Derivative Works thereof, You may choose to offer, 173 | and charge a fee for, acceptance of support, warranty, indemnity, 174 | or other liability obligations and/or rights consistent with this 175 | License. However, in accepting such obligations, You may act only 176 | on Your own behalf and on Your sole responsibility, not on behalf 177 | of any other Contributor, and only if You agree to indemnify, 178 | defend, and hold each Contributor harmless for any liability 179 | incurred by, or claims asserted against, such Contributor by reason 180 | of your accepting any such warranty or additional liability. 181 | 182 | END OF TERMS AND CONDITIONS 183 | 184 | APPENDIX: How to apply the Apache License to your work. 185 | 186 | To apply the Apache License to your work, attach the following 187 | boilerplate notice, with the fields enclosed by brackets "[]" 188 | replaced with your own identifying information. (Don't include 189 | the brackets!) The text should be enclosed in the appropriate 190 | comment syntax for the file format. 
We also recommend that a 191 | file or class name and description of purpose be included on the 192 | same "printed page" as the copyright notice for easier 193 | identification within third-party archives. 194 | 195 | Copyright [yyyy] [name of copyright owner] 196 | 197 | Licensed under the Apache License, Version 2.0 (the "License"); 198 | you may not use this file except in compliance with the License. 199 | You may obtain a copy of the License at 200 | 201 | http://www.apache.org/licenses/LICENSE-2.0 202 | 203 | Unless required by applicable law or agreed to in writing, software 204 | distributed under the License is distributed on an "AS IS" BASIS, 205 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 206 | See the License for the specific language governing permissions and 207 | limitations under the License. 208 | 209 | -------------------------------------------------------------------------------- 210 | 211 | Additionally, JPMML-TensorFlow includes third-party dependencies that are released under the MIT License: 212 | * Simple Logging Facade for Java (SLF4J) - http://www.slf4j.org/ 213 | 214 | Copyright (c) 2004-2017 QOS.ch 215 | All rights reserved. 216 | 217 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 218 | 219 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
220 | 221 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 222 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | JPMML-TensorFlow 2 | ================ 3 | 4 | Java library and command-line application for converting [TensorFlow](http://tensorflow.org) models to PMML. 5 | 6 | # Features # 7 | 8 | * Supported Estimator types: 9 | * [`learn.DNNClassifier`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNClassifier) 10 | * [`learn.DNNRegressor`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNRegressor) 11 | * [`learn.LinearClassifier`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/LinearClassifier) 12 | * [`learn.LinearRegressor`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/LinearRegressor) 13 | * Supported Feature column types: 14 | * [`layers.one_hot_column`](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/one_hot_column) 15 | * [`layers.real_valued_column`](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/real_valued_column) 16 | * [`layers.sparse_column_with_keys`](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/sparse_column_with_keys) 17 | * Production quality: 18 | * Complete test coverage. 19 | * Fully compliant with the [JPMML-Evaluator](https://github.com/jpmml/jpmml-evaluator) library. 
20 | 21 | # Prerequisites # 22 | 23 | ### The TensorFlow side of operations 24 | 25 | * Protocol Buffers 3.2.0 or newer 26 | * TensorFlow 1.1.0 or newer 27 | 28 | ### The Java side of operations 29 | 30 | * Java 1.8 or newer 31 | 32 | # Installation # 33 | 34 | Enter the project root directory and build using [Apache Maven](http://maven.apache.org/); use the `protoc.exe` system property to specify the location of the Protocol Buffers compiler: 35 | ``` 36 | mvn -Dprotoc.exe=/usr/local/bin/protoc clean install 37 | ``` 38 | 39 | The build produces an executable uber-JAR file `target/converter-executable-1.0-SNAPSHOT.jar`. 40 | 41 | # Usage # 42 | 43 | A typical workflow can be summarized as follows: 44 | 45 | 1. Use TensorFlow to train an estimator. 46 | 2. Export the estimator in `SavedModel` data format to a directory on the local filesystem. 47 | 3. Use the JPMML-TensorFlow command-line converter application to convert the SavedModel directory into a PMML file. 48 | 49 | ### The TensorFlow side of operations 50 | 51 | Please see the test script file [main.py](https://github.com/jpmml/jpmml-tensorflow/blob/master/src/test/resources/main.py) for sample workflows. 52 | 53 | ### The Java side of operations 54 | 55 | Converting the estimator SavedModel directory `estimator/` to a PMML file `estimator.pmml`: 56 | ``` 57 | java -jar target/converter-executable-1.0-SNAPSHOT.jar --tf-savedmodel-input estimator/ --pmml-output estimator.pmml 58 | ``` 59 | 60 | Getting help: 61 | ``` 62 | java -jar target/converter-executable-1.0-SNAPSHOT.jar --help 63 | ``` 64 | 65 | # License # 66 | 67 | JPMML-TensorFlow is licensed under the [GNU Affero General Public License (AGPL) version 3.0](http://www.gnu.org/licenses/agpl-3.0.html). Other licenses are available on request.
68 | 69 | # Additional information # 70 | 71 | Please contact [info@openscoring.io](mailto:info@openscoring.io) 72 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | org.jpmml 6 | jpmml-tensorflow 7 | 1.0-SNAPSHOT 8 | 9 | JPMML-TensorFlow 10 | Java library and command-line application for converting TensorFlow models to PMML 11 | https://github.com/jpmml/jpmml-tensorflow 12 | 13 | 14 | 15 | GNU Affero General Public License (AGPL) version 3.0 16 | http://www.gnu.org/licenses/agpl-3.0.html 17 | repo 18 | 19 | 20 | 21 | 22 | 23 | villu.ruusmann 24 | Villu Ruusmann 25 | 26 | 27 | 28 | 29 | scm:git:git@github.com:jpmml/jpmml-tensorflow.git 30 | scm:git:git@github.com:jpmml/jpmml-tensorflow.git 31 | git://github.com/jpmml/jpmml-tensorflow.git 32 | HEAD 33 | 34 | 35 | GitHub 36 | https://github.com/jpmml/jpmml-tensorflow/issues 37 | 38 | 39 | 40 | protoc 41 | 42 | 43 | 44 | 45 | com.beust 46 | jcommander 47 | 1.48 48 | 49 | 50 | 51 | org.jpmml 52 | jpmml-converter 53 | 1.2.5 54 | 55 | 56 | com.sun.xml.fastinfoset 57 | FastInfoset 58 | 59 | 60 | javax.xml.bind 61 | jaxb-api 62 | 63 | 64 | org.glassfish.jaxb 65 | txw2 66 | 67 | 68 | org.jvnet.staxex 69 | stax-ex 70 | 71 | 72 | 73 | 74 | 75 | org.slf4j 76 | slf4j-api 77 | 1.7.25 78 | 79 | 80 | org.slf4j 81 | slf4j-jdk14 82 | 1.7.25 83 | 84 | 85 | 86 | org.tensorflow 87 | proto 88 | [1.3.0, ) 89 | 90 | 91 | org.tensorflow 92 | tensorflow 93 | [1.1.0, ) 94 | 95 | 96 | 97 | junit 98 | junit 99 | 4.12 100 | test 101 | 102 | 103 | 104 | org.jpmml 105 | pmml-evaluator 106 | 1.3.8 107 | test 108 | 109 | 110 | org.jpmml 111 | pmml-evaluator-test 112 | 1.3.8 113 | test 114 | 115 | 116 | 117 | 118 | 119 | 120 | org.apache.maven.plugins 121 | maven-compiler-plugin 122 | 3.5.1 123 | 124 | 1.8 125 | 1.8 126 | 127 | 128 | 129 | org.apache.maven.plugins 130 | maven-enforcer-plugin 131 | 1.4.1 
132 | 133 | 134 | 135 | enforce 136 | 137 | 138 | 139 | 140 | 1.8 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | org.apache.maven.plugins 149 | maven-jar-plugin 150 | 3.0.2 151 | 152 | 153 | 154 | true 155 | 156 | 157 | 158 | 159 | 160 | org.apache.maven.plugins 161 | maven-shade-plugin 162 | 2.4.3 163 | 164 | 165 | package 166 | 167 | shade 168 | 169 | 170 | converter-executable-${project.version} 171 | 172 | 173 | org.jpmml.tensorflow.Main 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | org.apache.maven.plugins 182 | maven-source-plugin 183 | 3.0.1 184 | 185 | 186 | attach-sources 187 | 188 | jar 189 | 190 | 191 | 192 | 193 | 194 | org.apache.maven.plugins 195 | maven-surefire-plugin 196 | 2.19.1 197 | 198 | ${jacoco.agent} 199 | false 200 | 201 | 202 | 203 | org.jacoco 204 | jacoco-maven-plugin 205 | 0.7.9 206 | 207 | 208 | pre-unit-test 209 | 210 | prepare-agent 211 | 212 | 213 | jacoco.agent 214 | 215 | 216 | 217 | post-unit-test 218 | prepare-package 219 | 220 | report 221 | 222 | 223 | 224 | 225 | 226 | org.xolstice.maven.plugins 227 | protobuf-maven-plugin 228 | 0.5.0 229 | 230 | ${protoc.exe} 231 | 232 | 233 | 234 | 235 | compile 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/DNNClassifier.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.util.ArrayList; 22 | import java.util.Arrays; 23 | import java.util.List; 24 | 25 | import com.google.common.collect.Iterables; 26 | import org.dmg.pmml.DataField; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.FieldName; 29 | import org.dmg.pmml.MiningFunction; 30 | import org.dmg.pmml.OpType; 31 | import org.dmg.pmml.neural_network.Connection; 32 | import org.dmg.pmml.neural_network.NeuralLayer; 33 | import org.dmg.pmml.neural_network.NeuralNetwork; 34 | import org.dmg.pmml.neural_network.Neuron; 35 | import org.jpmml.converter.CategoricalLabel; 36 | import org.jpmml.converter.ModelUtil; 37 | import org.jpmml.converter.ValueUtil; 38 | import org.jpmml.converter.neural_network.NeuralNetworkUtil; 39 | 40 | public class DNNClassifier extends DNNEstimator { 41 | 42 | public DNNClassifier(SavedModel savedModel, String head){ 43 | super(savedModel, head); 44 | } 45 | 46 | @Override 47 | public NeuralNetwork encodeModel(TensorFlowEncoder encoder){ 48 | DataField dataField = encoder.createDataField(FieldName.create("_target"), OpType.CATEGORICAL, DataType.INTEGER); 49 | 50 | NeuralNetwork neuralNetwork = encodeNeuralNetwork(encoder); 51 | 52 | List neuralLayers = neuralNetwork.getNeuralLayers(); 53 | 54 | NeuralLayer neuralLayer = Iterables.getLast(neuralLayers); 55 | 56 | List neurons = neuralLayer.getNeurons(); 57 | 58 | List categories; 59 | 60 | if(neurons.size() == 1){ 61 | neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC); 62 | 63 | Neuron neuron = 
Iterables.getOnlyElement(neurons); 64 | 65 | neuralLayer = new NeuralLayer() 66 | .setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY); 67 | 68 | categories = Arrays.asList("0", "1"); 69 | 70 | // p(no event) = 1 - p(event) 71 | Neuron passiveNeuron = new Neuron() 72 | .setId(String.valueOf(neuralLayers.size() + 1) + "/" + categories.get(0)) 73 | .setBias(ValueUtil.floatToDouble(1f)) 74 | .addConnections(new Connection(neuron.getId(), -1f)); 75 | 76 | // p(event) 77 | Neuron activeNeuron = new Neuron() 78 | .setId(String.valueOf(neuralLayers.size() + 1) + "/" + categories.get(1)) 79 | .setBias(null) 80 | .addConnections(new Connection(neuron.getId(), 1f)); 81 | 82 | neuralLayer.addNeurons(passiveNeuron, activeNeuron); 83 | 84 | neuralNetwork.addNeuralLayers(neuralLayer); 85 | 86 | neurons = neuralLayer.getNeurons(); 87 | } else 88 | 89 | if(neurons.size() > 2){ 90 | neuralLayer 91 | .setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY) 92 | .setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX); 93 | 94 | categories = new ArrayList<>(); 95 | 96 | for(int i = 0; i < neurons.size(); i++){ 97 | String category = String.valueOf(i); 98 | 99 | categories.add(category); 100 | } 101 | } else 102 | 103 | { 104 | throw new IllegalArgumentException(); 105 | } 106 | 107 | dataField = encoder.toCategorical(dataField.getName(), categories); 108 | 109 | CategoricalLabel categoricalLabel = new CategoricalLabel(dataField); 110 | 111 | neuralNetwork 112 | .setMiningFunction(MiningFunction.CLASSIFICATION) 113 | .setMiningSchema(ModelUtil.createMiningSchema(categoricalLabel)) 114 | .setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(neurons, categoricalLabel)) 115 | .setOutput(ModelUtil.createProbabilityOutput(DataType.FLOAT, categoricalLabel)); 116 | 117 | return neuralNetwork; 118 | } 119 | 120 | public static final String BINARY_LOGISTIC_HEAD = "dnn/binary_logistic_head/predictions/probabilities"; 121 | public static final 
String MULTI_CLASS_HEAD = "dnn/multi_class_head/predictions/probabilities"; 122 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/DNNEstimator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | import java.util.Map; 24 | 25 | import com.google.common.collect.Lists; 26 | import com.google.common.primitives.Floats; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.Entity; 29 | import org.dmg.pmml.MathContext; 30 | import org.dmg.pmml.neural_network.NeuralInputs; 31 | import org.dmg.pmml.neural_network.NeuralLayer; 32 | import org.dmg.pmml.neural_network.NeuralNetwork; 33 | import org.dmg.pmml.neural_network.Neuron; 34 | import org.jpmml.converter.BinaryFeature; 35 | import org.jpmml.converter.CMatrixUtil; 36 | import org.jpmml.converter.Feature; 37 | import org.jpmml.converter.ValueUtil; 38 | import org.jpmml.converter.neural_network.NeuralNetworkUtil; 39 | import org.tensorflow.Operation; 40 | import org.tensorflow.Output; 41 | import org.tensorflow.Tensor; 42 | import org.tensorflow.framework.NodeDef; 43 | 44 | abstract 45 | public class DNNEstimator extends Estimator { 46 | 47 | public DNNEstimator(SavedModel savedModel, String head){ 48 | super(savedModel, head); 49 | } 50 | 51 | protected NeuralNetwork encodeNeuralNetwork(TensorFlowEncoder encoder){ 52 | SavedModel savedModel = getSavedModel(); 53 | 54 | NeuralNetwork neuralNetwork = new NeuralNetwork() 55 | .setActivationFunction(NeuralNetwork.ActivationFunction.RECTIFIER) 56 | .setMathContext(MathContext.FLOAT); 57 | 58 | List biasAdds = Lists.newArrayList(savedModel.getInputs(getHead(), "BiasAdd")); 59 | 60 | biasAdds = Lists.reverse(biasAdds); 61 | 62 | List entities; 63 | 64 | { 65 | NodeDef biasAdd = biasAdds.get(0); 66 | 67 | NodeDef matMul = savedModel.getNodeDef(biasAdd.getInput(0)); 68 | if(!("MatMul").equals(matMul.getOp())){ 69 | throw new IllegalArgumentException(); 70 | } 71 | 72 | NodeDef concat = savedModel.getNodeDef(matMul.getInput(0)); 73 | if(!("ConcatV2").equals(concat.getOp())){ 74 | throw new IllegalArgumentException(); 75 | } 76 | 77 | List features = new 
ArrayList<>(); 78 | 79 | List inputNames = concat.getInputList(); 80 | for(int i = 0; i < inputNames.size() - 1; i++){ 81 | String inputName = inputNames.get(i); 82 | 83 | NodeDef term = savedModel.getNodeDef(inputName); 84 | 85 | // "real_valued_column" 86 | if(("Cast").equals(term.getOp()) || ("Placeholder").equals(term.getOp())){ 87 | NodeDef placeholder = term; 88 | 89 | Feature feature = encoder.createContinuousFeature(savedModel, placeholder); 90 | 91 | features.add(feature); 92 | } else 93 | 94 | // "one_hot_column(sparse_column_with_keys)" 95 | if(("Sum").equals(term.getOp())){ 96 | NodeDef oneHot = savedModel.getOnlyInput(term.getInput(0), "OneHot"); 97 | 98 | NodeDef placeholder = savedModel.getOnlyInput(oneHot.getInput(0), "Placeholder"); 99 | NodeDef findTable = savedModel.getOnlyInput(oneHot.getInput(0), "LookupTableFind"); 100 | 101 | Map table = savedModel.getTable(findTable.getInput(0)); 102 | 103 | List categories = (List)new ArrayList<>(table.keySet()); 104 | 105 | List binaryFeatures = encoder.createBinaryFeatures(savedModel, placeholder, categories); 106 | 107 | features.addAll(binaryFeatures); 108 | } else 109 | 110 | { 111 | throw new IllegalArgumentException(term.getName()); 112 | } 113 | } 114 | 115 | NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.FLOAT); 116 | 117 | neuralNetwork.setNeuralInputs(neuralInputs); 118 | 119 | entities = neuralInputs.getNeuralInputs(); 120 | } 121 | 122 | for(int i = 0; i < biasAdds.size(); i++){ 123 | NodeDef biasAdd = biasAdds.get(i); 124 | 125 | NodeDef matMul = savedModel.getNodeDef(biasAdd.getInput(0)); 126 | if(!("MatMul").equals(matMul.getOp())){ 127 | throw new IllegalArgumentException(); 128 | } 129 | 130 | int count; 131 | 132 | { 133 | Operation operation = savedModel.getOperation(matMul.getName()); 134 | 135 | Output output = operation.output(0); 136 | 137 | long[] shape = ShapeUtil.toArray(output.shape()); 138 | if(shape.length != 2 || shape[0] != -1){ 139 | 
throw new IllegalArgumentException(); 140 | } 141 | 142 | count = (int)shape[1]; 143 | } 144 | 145 | NodeDef weights = savedModel.getOnlyInput(matMul.getInput(1), "VariableV2"); 146 | 147 | float[] weightValues; 148 | 149 | try(Tensor tensor = savedModel.run(weights.getName())){ 150 | weightValues = TensorUtil.toFloatArray(tensor); 151 | } 152 | 153 | NodeDef bias = savedModel.getOnlyInput(biasAdd.getInput(1), "VariableV2"); 154 | 155 | float[] biasValues; 156 | 157 | try(Tensor tensor = savedModel.run(bias.getName())){ 158 | biasValues = TensorUtil.toFloatArray(tensor); 159 | } 160 | 161 | NeuralLayer neuralLayer = new NeuralLayer(); 162 | 163 | for(int j = 0; j < count; j++){ 164 | List entityWeights = CMatrixUtil.getColumn(Floats.asList(weightValues), entities.size(), count, j); 165 | 166 | Neuron neuron = NeuralNetworkUtil.createNeuron(entities, ValueUtil.floatsToDoubles(entityWeights), ValueUtil.floatToDouble(biasValues[j])) 167 | .setId(String.valueOf(i + 1) + "/" + String.valueOf(j + 1)); 168 | 169 | neuralLayer.addNeurons(neuron); 170 | } 171 | 172 | neuralNetwork.addNeuralLayers(neuralLayer); 173 | 174 | entities = neuralLayer.getNeurons(); 175 | } 176 | 177 | return neuralNetwork; 178 | } 179 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/DNNRegressor.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.util.List; 22 | 23 | import com.google.common.collect.Iterables; 24 | import org.dmg.pmml.DataField; 25 | import org.dmg.pmml.DataType; 26 | import org.dmg.pmml.FieldName; 27 | import org.dmg.pmml.MiningFunction; 28 | import org.dmg.pmml.OpType; 29 | import org.dmg.pmml.neural_network.NeuralLayer; 30 | import org.dmg.pmml.neural_network.NeuralNetwork; 31 | import org.dmg.pmml.neural_network.Neuron; 32 | import org.jpmml.converter.ContinuousLabel; 33 | import org.jpmml.converter.ModelUtil; 34 | import org.jpmml.converter.neural_network.NeuralNetworkUtil; 35 | 36 | public class DNNRegressor extends DNNEstimator { 37 | 38 | public DNNRegressor(SavedModel savedModel, String head){ 39 | super(savedModel, head); 40 | } 41 | 42 | @Override 43 | public NeuralNetwork encodeModel(TensorFlowEncoder encoder){ 44 | DataField dataField = encoder.createDataField(FieldName.create("_target"), OpType.CONTINUOUS, DataType.FLOAT); 45 | 46 | NeuralNetwork neuralNetwork = encodeNeuralNetwork(encoder); 47 | 48 | List neuralLayers = neuralNetwork.getNeuralLayers(); 49 | 50 | NeuralLayer neuralLayer = Iterables.getLast(neuralLayers); 51 | 52 | neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY); 53 | 54 | List neurons = neuralLayer.getNeurons(); 55 | 56 | ContinuousLabel continuousLabel = new ContinuousLabel(dataField); 57 | 58 | neuralNetwork 59 | .setMiningFunction(MiningFunction.REGRESSION) 60 | .setMiningSchema(ModelUtil.createMiningSchema(continuousLabel)) 61 | 
.setNeuralOutputs(NeuralNetworkUtil.createRegressionNeuralOutputs(neurons, continuousLabel)); 62 | 63 | return neuralNetwork; 64 | } 65 | 66 | public static final String REGRESSION_HEAD = "dnn/regression_head/predictions/scores"; 67 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/Estimator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import org.dmg.pmml.Model; 22 | import org.dmg.pmml.PMML; 23 | 24 | abstract 25 | public class Estimator { 26 | 27 | private SavedModel savedModel = null; 28 | 29 | private String head = null; 30 | 31 | 32 | public Estimator(SavedModel savedModel, String head){ 33 | setSavedModel(savedModel); 34 | setHead(head); 35 | } 36 | 37 | abstract 38 | public Model encodeModel(TensorFlowEncoder encoder); 39 | 40 | public PMML encodePMML(){ 41 | TensorFlowEncoder encoder = new TensorFlowEncoder(); 42 | 43 | Model model = encodeModel(encoder); 44 | 45 | PMML pmml = encoder.encodePMML(model); 46 | 47 | return pmml; 48 | } 49 | 50 | public SavedModel getSavedModel(){ 51 | return this.savedModel; 52 | } 53 | 54 | private void setSavedModel(SavedModel savedModel){ 55 | this.savedModel = savedModel; 56 | } 57 | 58 | public String getHead(){ 59 | return this.head; 60 | } 61 | 62 | private void setHead(String head){ 63 | this.head = head; 64 | } 65 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/EstimatorFactory.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.util.Map; 22 | 23 | import org.tensorflow.framework.NodeDef; 24 | 25 | public class EstimatorFactory { 26 | 27 | protected EstimatorFactory(){ 28 | } 29 | 30 | public Estimator newEstimator(SavedModel savedModel){ 31 | Map nodeMap = savedModel.getNodeMap(); 32 | 33 | if(nodeMap.containsKey(DNNClassifier.BINARY_LOGISTIC_HEAD)){ 34 | return new DNNClassifier(savedModel, DNNClassifier.BINARY_LOGISTIC_HEAD); 35 | } else 36 | 37 | if(nodeMap.containsKey(DNNClassifier.MULTI_CLASS_HEAD)){ 38 | return new DNNClassifier(savedModel, DNNClassifier.MULTI_CLASS_HEAD); 39 | } else 40 | 41 | if(nodeMap.containsKey(DNNRegressor.REGRESSION_HEAD)){ 42 | return new DNNRegressor(savedModel, DNNRegressor.REGRESSION_HEAD); 43 | } else 44 | 45 | if(nodeMap.containsKey(LinearClassifier.BINARY_LOGISTIC_HEAD)){ 46 | return new LinearClassifier(savedModel, LinearClassifier.BINARY_LOGISTIC_HEAD); 47 | } else 48 | 49 | if(nodeMap.containsKey(LinearClassifier.MULTI_CLASS_HEAD)){ 50 | return new LinearClassifier(savedModel, LinearClassifier.MULTI_CLASS_HEAD); 51 | } else 52 | 53 | if(nodeMap.containsKey(LinearRegressor.REGRESSION_HEAD)){ 54 | return new LinearRegressor(savedModel, LinearRegressor.REGRESSION_HEAD); 55 | } else 56 | 57 | { 58 | throw new IllegalArgumentException(); 59 | } 60 | } 61 | 62 | static 63 | public EstimatorFactory newInstance(){ 64 | return new EstimatorFactory(); 65 | } 66 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/LinearClassifier.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 
7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.util.ArrayList; 22 | import java.util.Arrays; 23 | import java.util.List; 24 | 25 | import org.dmg.pmml.DataField; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.FieldName; 28 | import org.dmg.pmml.MiningFunction; 29 | import org.dmg.pmml.OpType; 30 | import org.dmg.pmml.regression.RegressionModel; 31 | import org.dmg.pmml.regression.RegressionTable; 32 | import org.jpmml.converter.CategoricalLabel; 33 | import org.jpmml.converter.ModelUtil; 34 | 35 | public class LinearClassifier extends LinearEstimator { 36 | 37 | public LinearClassifier(SavedModel savedModel, String head){ 38 | super(savedModel, head); 39 | } 40 | 41 | @Override 42 | public RegressionModel encodeModel(TensorFlowEncoder encoder){ 43 | DataField dataField = encoder.createDataField(FieldName.create("_target"), OpType.CATEGORICAL, DataType.INTEGER); 44 | 45 | RegressionModel regressionModel = encodeRegressionModel(encoder); 46 | 47 | List regressionTables = regressionModel.getRegressionTables(); 48 | 49 | List categories; 50 | 51 | if(regressionTables.size() == 1){ 52 | categories = Arrays.asList("0", "1"); 53 | 54 | RegressionTable activeRegressionTable = regressionTables.get(0) 55 | .setTargetCategory(categories.get(1)); 56 | 57 | RegressionTable passiveRegressionTable = new RegressionTable(0) 58 | 
.setTargetCategory(categories.get(0)); 59 | 60 | regressionModel.addRegressionTables(passiveRegressionTable); 61 | } else 62 | 63 | if(regressionTables.size() > 2){ 64 | categories = new ArrayList<>(); 65 | 66 | for(int i = 0; i < regressionTables.size(); i++){ 67 | RegressionTable regressionTable = regressionTables.get(i); 68 | String category = String.valueOf(i); 69 | 70 | regressionTable.setTargetCategory(category); 71 | 72 | categories.add(category); 73 | } 74 | } else 75 | 76 | { 77 | throw new IllegalArgumentException(); 78 | } 79 | 80 | dataField = encoder.toCategorical(dataField.getName(), categories); 81 | 82 | CategoricalLabel categoricalLabel = new CategoricalLabel(dataField); 83 | 84 | regressionModel 85 | .setMiningFunction(MiningFunction.CLASSIFICATION) 86 | .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX) 87 | .setMiningSchema(ModelUtil.createMiningSchema(categoricalLabel)) 88 | .setOutput(ModelUtil.createProbabilityOutput(DataType.FLOAT, categoricalLabel)); 89 | 90 | return regressionModel; 91 | } 92 | 93 | public static final String BINARY_LOGISTIC_HEAD = "linear/binary_logistic_head/predictions/probabilities"; 94 | public static final String MULTI_CLASS_HEAD = "linear/multi_class_head/predictions/probabilities"; 95 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/LinearEstimator.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | import java.util.Map; 24 | 25 | import com.google.common.primitives.Floats; 26 | import org.dmg.pmml.MathContext; 27 | import org.dmg.pmml.regression.RegressionModel; 28 | import org.dmg.pmml.regression.RegressionTable; 29 | import org.jpmml.converter.CMatrixUtil; 30 | import org.jpmml.converter.Feature; 31 | import org.jpmml.converter.ValueUtil; 32 | import org.jpmml.converter.regression.RegressionModelUtil; 33 | import org.tensorflow.Operation; 34 | import org.tensorflow.Output; 35 | import org.tensorflow.Tensor; 36 | import org.tensorflow.framework.NodeDef; 37 | 38 | abstract 39 | public class LinearEstimator extends Estimator { 40 | 41 | public LinearEstimator(SavedModel savedModel, String head){ 42 | super(savedModel, head); 43 | } 44 | 45 | public RegressionModel encodeRegressionModel(TensorFlowEncoder encoder){ 46 | SavedModel savedModel = getSavedModel(); 47 | 48 | NodeDef biasAdd = savedModel.getOnlyInput(getHead(), "BiasAdd"); 49 | 50 | int count; 51 | 52 | { 53 | Operation operation = savedModel.getOperation(biasAdd.getName()); 54 | 55 | Output output = operation.output(0); 56 | 57 | long[] shape = ShapeUtil.toArray(output.shape()); 58 | if((shape.length != 2) || (shape[0] != -1)){ 59 | throw new IllegalArgumentException(); 60 | } 61 | 62 | count = (int)shape[1]; 63 | } 64 | 65 | List equations = new ArrayList<>(); 66 | 67 | for(int i = 0; i < count; i++){ 68 | Equation equation = new Equation(); 69 | 70 | equations.add(equation); 71 
| } 72 | 73 | NodeDef addN = savedModel.getOnlyInput(biasAdd.getInput(0), "AddN"); 74 | 75 | List inputNames = addN.getInputList(); 76 | for(String inputName : inputNames){ 77 | NodeDef term = savedModel.getOnlyInput(inputName, "MatMul", "Select"); 78 | 79 | // "real_valued_column" 80 | if(("MatMul").equals(term.getOp())){ 81 | NodeDef placeholder = savedModel.getNodeDef(term.getInput(0)); 82 | NodeDef multiplier = savedModel.getOnlyInput(term.getInput(1), "VariableV2"); 83 | 84 | Feature feature = encoder.createContinuousFeature(savedModel, placeholder); 85 | 86 | try(Tensor tensor = savedModel.run(multiplier.getName())){ 87 | float[] values = TensorUtil.toFloatArray(tensor); 88 | 89 | for(int i = 0; i < count; i++){ 90 | Equation equation = equations.get(i); 91 | 92 | equation.addTerm(feature, ValueUtil.floatToDouble(values[i])); 93 | } 94 | } 95 | } else 96 | 97 | // "sparse_column_with_keys" 98 | if(("Select").equals(term.getOp())){ 99 | NodeDef placeholder = savedModel.getOnlyInput(term.getInput(0), "Placeholder"); 100 | NodeDef findTable = savedModel.getOnlyInput(term.getInput(1), "LookupTableFind"); 101 | NodeDef multiplier = savedModel.getOnlyInput(term.getInput(2), "VariableV2"); 102 | 103 | Map table = savedModel.getTable(findTable.getInput(0)); 104 | 105 | List categories = (List)new ArrayList<>(table.keySet()); 106 | 107 | List features = encoder.createBinaryFeatures(savedModel, placeholder, categories); 108 | 109 | float[] values; 110 | 111 | try(Tensor tensor = savedModel.run(multiplier.getName())){ 112 | values = TensorUtil.toFloatArray(tensor); 113 | } 114 | 115 | for(int i = 0; i < equations.size(); i++){ 116 | Equation equation = equations.get(i); 117 | 118 | List categoryValues = CMatrixUtil.getColumn(Floats.asList(values), features.size(), equations.size(), i); 119 | 120 | for(int j = 0; j < features.size(); j++){ 121 | Feature feature = features.get(j); 122 | 123 | int index = ValueUtil.asInt((Number)table.get(categories.get(j))); 124 | 125 | 
equation.addTerm(feature, ValueUtil.floatToDouble(categoryValues.get(index))); 126 | } 127 | } 128 | } else 129 | 130 | { 131 | throw new IllegalArgumentException(term.getName()); 132 | } 133 | } 134 | 135 | NodeDef bias = savedModel.getOnlyInput(biasAdd.getInput(1), "VariableV2"); 136 | 137 | try(Tensor tensor = savedModel.run(bias.getName())){ 138 | float[] values = TensorUtil.toFloatArray(tensor); 139 | 140 | for(int i = 0; i < count; i++){ 141 | Equation equation = equations.get(i); 142 | 143 | equation.setIntercept(ValueUtil.floatToDouble(values[i])); 144 | } 145 | } 146 | 147 | RegressionModel regressionModel = new RegressionModel() 148 | .setMathContext(MathContext.FLOAT); 149 | 150 | for(Equation equation : equations){ 151 | RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(equation.getFeatures(), equation.getCoefficients(), equation.getIntercept()); 152 | 153 | regressionModel.addRegressionTables(regressionTable); 154 | } 155 | 156 | return regressionModel; 157 | } 158 | 159 | static 160 | private class Equation { 161 | 162 | private List features = new ArrayList<>(); 163 | 164 | private List coefficients = new ArrayList<>(); 165 | 166 | private Double intercept = null; 167 | 168 | 169 | private Equation(){ 170 | } 171 | 172 | public void addTerm(Feature feature, Double coefficient){ 173 | this.features.add(feature); 174 | this.coefficients.add(coefficient); 175 | } 176 | 177 | public List getFeatures(){ 178 | return this.features; 179 | } 180 | 181 | public List getCoefficients(){ 182 | return this.coefficients; 183 | } 184 | 185 | public Double getIntercept(){ 186 | return this.intercept; 187 | } 188 | 189 | public void setIntercept(Double intercept){ 190 | this.intercept = intercept; 191 | } 192 | } 193 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/LinearRegressor.java: -------------------------------------------------------------------------------- 1 | 
/* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import org.dmg.pmml.DataField; 22 | import org.dmg.pmml.DataType; 23 | import org.dmg.pmml.FieldName; 24 | import org.dmg.pmml.MiningFunction; 25 | import org.dmg.pmml.OpType; 26 | import org.dmg.pmml.regression.RegressionModel; 27 | import org.jpmml.converter.ContinuousLabel; 28 | import org.jpmml.converter.Label; 29 | import org.jpmml.converter.ModelUtil; 30 | 31 | public class LinearRegressor extends LinearEstimator { 32 | 33 | public LinearRegressor(SavedModel savedModel, String head){ 34 | super(savedModel, head); 35 | } 36 | 37 | @Override 38 | public RegressionModel encodeModel(TensorFlowEncoder encoder){ 39 | DataField dataField = encoder.createDataField(FieldName.create("_target"), OpType.CONTINUOUS, DataType.FLOAT); 40 | 41 | Label label = new ContinuousLabel(dataField); 42 | 43 | RegressionModel regressionModel = encodeRegressionModel(encoder) 44 | .setMiningFunction(MiningFunction.REGRESSION) 45 | .setMiningSchema(ModelUtil.createMiningSchema(label)); 46 | 47 | return regressionModel; 48 | } 49 | 50 | public static final String REGRESSION_HEAD = "linear/regression_head/predictions/scores"; 51 | } 
-------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/Main.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.io.File; 22 | import java.io.FileOutputStream; 23 | import java.io.OutputStream; 24 | 25 | import com.beust.jcommander.JCommander; 26 | import com.beust.jcommander.Parameter; 27 | import com.beust.jcommander.ParameterException; 28 | import org.dmg.pmml.PMML; 29 | import org.jpmml.model.MetroJAXBUtil; 30 | import org.slf4j.Logger; 31 | import org.slf4j.LoggerFactory; 32 | import org.tensorflow.SavedModelBundle; 33 | 34 | public class Main { 35 | 36 | @Parameter ( 37 | names = "--help", 38 | description = "Show the list of configuration options and exit", 39 | help = true 40 | ) 41 | private boolean help = false; 42 | 43 | @Parameter ( 44 | names = {"--tf-input", "--tf-savedmodel-input"}, 45 | description = "TF SavedModel input directory", 46 | required = true 47 | ) 48 | private File input = null; 49 | 50 | @Parameter ( 51 | names = "--pmml-output", 52 | description = "PMML output file", 53 | required = true 54 | ) 55 | private File output = null; 56 | 57 | 58 | static 59 | public void main(String[] args) throws Exception { 60 | Main main = new Main(); 61 | 62 | JCommander commander = new JCommander(main); 63 | commander.setProgramName(Main.class.getName()); 64 | 65 | try { 66 | commander.parse(args); 67 | } catch(ParameterException pe){ 68 | StringBuilder sb = new StringBuilder(); 69 | 70 | sb.append(pe.toString()); 71 | sb.append("\n"); 72 | 73 | commander.usage(sb); 74 | 75 | System.err.println(sb.toString()); 76 | 77 | System.exit(-1); 78 | } 79 | 80 | if(main.help){ 81 | StringBuilder sb = new StringBuilder(); 82 | 83 | commander.usage(sb); 84 | 85 | System.out.println(sb.toString()); 86 | 87 | System.exit(0); 88 | } 89 | 90 | main.run(); 91 | } 92 | 93 | private void run() throws Exception { 94 | SavedModelBundle bundle; 95 | 96 | try { 97 | logger.info("Parsing SavedModel.."); 98 | 99 | long begin = System.currentTimeMillis(); 100 | bundle = SavedModelBundle.load(this.input.getAbsolutePath(), 
"serve"); 101 | long end = System.currentTimeMillis(); 102 | 103 | logger.info("Parsed SavedModel in {} ms.", (end - begin)); 104 | } catch(Exception e){ 105 | logger.error("Failed to parse SavedModel", e); 106 | 107 | throw e; 108 | } 109 | 110 | PMML pmml; 111 | 112 | try(SavedModel savedModel = new SavedModel(bundle)){ 113 | logger.info("Converting.."); 114 | 115 | EstimatorFactory estimatorFactory = EstimatorFactory.newInstance(); 116 | 117 | Estimator estimator = estimatorFactory.newEstimator(savedModel); 118 | 119 | long begin = System.currentTimeMillis(); 120 | pmml = estimator.encodePMML(); 121 | long end = System.currentTimeMillis(); 122 | 123 | logger.info("Converted in {} ms.", (end - begin)); 124 | } catch(Exception e){ 125 | logger.error("Failed to convert", e); 126 | 127 | throw e; 128 | } 129 | 130 | try(OutputStream os = new FileOutputStream(this.output)){ 131 | logger.info("Marshalling PMML.."); 132 | 133 | long begin = System.currentTimeMillis(); 134 | MetroJAXBUtil.marshalPMML(pmml, os); 135 | long end = System.currentTimeMillis(); 136 | 137 | logger.info("Marshalled PMML in {}", (end - begin)); 138 | } catch(Exception e){ 139 | logger.error("Failed to marshal PMML", e); 140 | 141 | throw e; 142 | } 143 | } 144 | 145 | private static final Logger logger = LoggerFactory.getLogger(Main.class); 146 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/SavedModel.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.util.ArrayDeque; 22 | import java.util.Arrays; 23 | import java.util.Collection; 24 | import java.util.Collections; 25 | import java.util.Deque; 26 | import java.util.HashSet; 27 | import java.util.LinkedHashMap; 28 | import java.util.LinkedHashSet; 29 | import java.util.List; 30 | import java.util.Map; 31 | import java.util.Set; 32 | 33 | import com.google.common.base.Function; 34 | import com.google.common.collect.Iterables; 35 | import com.google.protobuf.InvalidProtocolBufferException; 36 | import org.tensorflow.Graph; 37 | import org.tensorflow.Operation; 38 | import org.tensorflow.SavedModelBundle; 39 | import org.tensorflow.Session; 40 | import org.tensorflow.Session.Runner; 41 | import org.tensorflow.Tensor; 42 | import org.tensorflow.framework.CollectionDef; 43 | import org.tensorflow.framework.GraphDef; 44 | import org.tensorflow.framework.MetaGraphDef; 45 | import org.tensorflow.framework.NodeDef; 46 | 47 | public class SavedModel implements AutoCloseable { 48 | 49 | private SavedModelBundle bundle = null; 50 | 51 | private MetaGraphDef metaGraphDef = null; 52 | 53 | private Map nodeMap = null; 54 | 55 | private Map> tableMap = new LinkedHashMap<>(); 56 | 57 | 58 | public SavedModel(SavedModelBundle bundle) throws InvalidProtocolBufferException { 59 | setBundle(bundle); 60 | 61 | byte[] metaGraphDefBytes = bundle.metaGraphDef(); 62 | 63 | MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(metaGraphDefBytes); 64 | 65 | setMetaGraphDef(metaGraphDef); 66 | 67 | GraphDef graphDef = 
metaGraphDef.getGraphDef(); 68 | 69 | Map nodeMap = new LinkedHashMap<>(); 70 | 71 | List nodeDefs = graphDef.getNodeList(); 72 | for(NodeDef nodeDef : nodeDefs){ 73 | nodeMap.put(nodeDef.getName(), nodeDef); 74 | } 75 | 76 | setNodeMap(nodeMap); 77 | 78 | initializeTables(); 79 | } 80 | 81 | private void initializeTables(){ 82 | Collection tableInitializerNames = Collections.emptyList(); 83 | 84 | try { 85 | CollectionDef collectionDef = getCollectionDef("table_initializer"); 86 | 87 | CollectionDef.NodeList nodeList = collectionDef.getNodeList(); 88 | 89 | tableInitializerNames = nodeList.getValueList(); 90 | } catch(IllegalArgumentException iae){ 91 | // Ignored 92 | } 93 | 94 | for(String tableInitializerName : tableInitializerNames){ 95 | NodeDef tableInitializer = getNodeDef(tableInitializerName); 96 | 97 | String name = tableInitializer.getInput(0); 98 | 99 | List keys; 100 | List values; 101 | 102 | try(Tensor tensor = run(tableInitializer.getInput(1))){ 103 | keys = TensorUtil.getValues(tensor); 104 | } // End try 105 | 106 | try(Tensor tensor = run(tableInitializer.getInput(2))){ 107 | values = TensorUtil.getValues(tensor); 108 | } 109 | 110 | Map table = new LinkedHashMap<>(); 111 | 112 | if(keys.size() != values.size()){ 113 | throw new IllegalArgumentException(); 114 | } 115 | 116 | for(int i = 0; i < keys.size(); i++){ 117 | table.put(keys.get(i), values.get(i)); 118 | } 119 | 120 | putTable(name, table); 121 | } 122 | } 123 | 124 | @Override 125 | public void close(){ 126 | SavedModelBundle bundle = getBundle(); 127 | 128 | bundle.close(); 129 | } 130 | 131 | public Tensor run(String name){ 132 | Session session = getSession(); 133 | 134 | Runner runner = (session.runner()).fetch(name); 135 | 136 | List tensors = runner.run(); 137 | 138 | return Iterables.getOnlyElement(tensors); 139 | } 140 | 141 | public Operation getOperation(String name){ 142 | Graph graph = getGraph(); 143 | 144 | return graph.operation(name); 145 | } 146 | 147 | public NodeDef 
getNodeDef(String name){ 148 | Map nodeMap = getNodeMap(); 149 | 150 | int colon = name.indexOf(':'); 151 | 152 | NodeDef nodeDef = nodeMap.get(colon > -1 ? name.substring(0, colon) : name); 153 | if(nodeDef == null){ 154 | throw new IllegalArgumentException(name); 155 | } 156 | 157 | return nodeDef; 158 | } 159 | 160 | public CollectionDef getCollectionDef(String key){ 161 | MetaGraphDef metaGraphDef = getMetaGraphDef(); 162 | 163 | Map collectionMap = metaGraphDef.getCollectionDefMap(); 164 | 165 | CollectionDef collectionDef = collectionMap.get(key); 166 | if(collectionDef == null){ 167 | throw new IllegalArgumentException(key); 168 | } 169 | 170 | return collectionDef; 171 | } 172 | 173 | public NodeDef getOnlyInput(String name, String... ops){ 174 | Iterable inputs = getInputs(name, ops); 175 | 176 | return Iterables.getOnlyElement(inputs); 177 | } 178 | 179 | public Iterable getInputs(String name, String... ops){ 180 | NodeDef nodeDef = getNodeDef(name); 181 | 182 | Collection trails = new LinkedHashSet<>(); 183 | 184 | collectInputs(new ArrayDeque<>(), nodeDef, new HashSet<>(Arrays.asList(ops)), trails); 185 | 186 | Function function = new Function(){ 187 | 188 | @Override 189 | public NodeDef apply(Trail trail){ 190 | return trail.getNodeDef(); 191 | } 192 | }; 193 | 194 | Collection inputs = new LinkedHashSet<>(); 195 | 196 | Iterables.addAll(inputs, Iterables.transform(trails, function)); 197 | 198 | return inputs; 199 | } 200 | 201 | private void collectInputs(Deque parentNodeDefs, NodeDef nodeDef, Set ops, Collection trails){ 202 | 203 | if(ops.contains(nodeDef.getOp())){ 204 | trails.add(new Trail(parentNodeDefs, nodeDef)); 205 | } 206 | 207 | List inputNames = nodeDef.getInputList(); 208 | for(String inputName : inputNames){ 209 | NodeDef inputNodeDef = getNodeDef(inputName); 210 | 211 | parentNodeDefs.addFirst(inputNodeDef); 212 | 213 | collectInputs(parentNodeDefs, inputNodeDef, ops, trails); 214 | 215 | parentNodeDefs.removeFirst(); 216 | } 217 | } 
218 | 219 | public Map getTable(String name){ 220 | Map table = this.tableMap.get(name); 221 | 222 | if(table == null){ 223 | throw new IllegalArgumentException(name); 224 | } 225 | 226 | return table; 227 | } 228 | 229 | private void putTable(String name, Map table){ 230 | this.tableMap.put(name, table); 231 | } 232 | 233 | public Session getSession(){ 234 | SavedModelBundle bundle = getBundle(); 235 | 236 | return bundle.session(); 237 | } 238 | 239 | public Graph getGraph(){ 240 | SavedModelBundle bundle = getBundle(); 241 | 242 | return bundle.graph(); 243 | } 244 | 245 | public SavedModelBundle getBundle(){ 246 | return this.bundle; 247 | } 248 | 249 | private void setBundle(SavedModelBundle bundle){ 250 | this.bundle = bundle; 251 | } 252 | 253 | public MetaGraphDef getMetaGraphDef(){ 254 | return this.metaGraphDef; 255 | } 256 | 257 | private void setMetaGraphDef(MetaGraphDef metaGraphDef){ 258 | this.metaGraphDef = metaGraphDef; 259 | } 260 | 261 | public Map getNodeMap(){ 262 | return this.nodeMap; 263 | } 264 | 265 | private void setNodeMap(Map nodeMap){ 266 | this.nodeMap = nodeMap; 267 | } 268 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/ShapeUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import org.tensorflow.Shape; 22 | 23 | public class ShapeUtil { 24 | 25 | private ShapeUtil(){ 26 | } 27 | 28 | static 29 | public long[] toArray(Shape shape){ 30 | int length = shape.numDimensions(); 31 | 32 | if(length < 0){ 33 | return null; 34 | } 35 | 36 | long[] result = new long[length]; 37 | 38 | for(int i = 0; i < length; i++){ 39 | result[i] = shape.size(i); 40 | } 41 | 42 | return result; 43 | } 44 | } -------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/TensorFlowEncoder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.dmg.pmml.DataField; 25 | import org.dmg.pmml.FieldName; 26 | import org.jpmml.converter.BinaryFeature; 27 | import org.jpmml.converter.ContinuousFeature; 28 | import org.jpmml.converter.ModelEncoder; 29 | import org.tensorflow.Operation; 30 | import org.tensorflow.Output; 31 | import org.tensorflow.framework.NodeDef; 32 |

/**
 * Model encoder that maps TensorFlow graph nodes to PMML data fields and features.
 */
public class TensorFlowEncoder extends ModelEncoder {

	/**
	 * Returns the data field for a "Placeholder" node, creating it if no field with that name exists yet.
	 *
	 * The field is named after the node. Its op type and data type are derived (via {@link TypeUtil})
	 * from the first output of the node's operation.
	 *
	 * @throws IllegalArgumentException (message = node name) if the node's op is not "Placeholder"
	 */
	public DataField ensureDataField(SavedModel savedModel, NodeDef placeholder){
		String nodeName = placeholder.getName();

		if(!("Placeholder").equals(placeholder.getOp())){
			throw new IllegalArgumentException(nodeName);
		}

		FieldName name = FieldName.create(nodeName);

		DataField existing = getDataField(name);
		if(existing != null){
			return existing;
		}

		Operation operation = savedModel.getOperation(nodeName);
		Output output = operation.output(0);

		return createDataField(name, TypeUtil.getOpType(output), TypeUtil.getDataType(output));
	}

	/**
	 * Ensures the placeholder's data field exists, then returns the result of {@code toContinuous} for it.
	 */
	public DataField ensureContinuousDataField(SavedModel savedModel, NodeDef placeholder){
		FieldName name = ensureDataField(savedModel, placeholder).getName();

		return toContinuous(name);
	}

	/**
	 * Ensures the placeholder's data field exists, then returns the result of {@code toCategorical} for it with the given values.
	 */
	public DataField ensureCategoricalDataField(SavedModel savedModel, NodeDef placeholder, List values){
		FieldName name = ensureDataField(savedModel, placeholder).getName();

		return toCategorical(name, values);
	}

	/**
	 * Creates a continuous feature for a placeholder node, optionally wrapped in a "Cast" node.
	 *
	 * When {@code placeholder} is a "Cast" node, the data field is created for the node named by its
	 * first input. The resulting feature is then converted with {@code toContinuousFeature} to the
	 * data type of the Cast node's first output.
	 */
	public ContinuousFeature createContinuousFeature(SavedModel savedModel, NodeDef placeholder){
		boolean cast = ("Cast").equals(placeholder.getOp());

		// NOTE(review): per node_def.proto, input names may carry a ":N" suffix or "^" prefix;
		// presumably SavedModel.getNodeDef resolves such names - confirm
		NodeDef source = (cast ? savedModel.getNodeDef(placeholder.getInput(0)) : placeholder);

		DataField dataField = ensureContinuousDataField(savedModel, source);

		ContinuousFeature feature = new ContinuousFeature(this, dataField);
		if(!cast){
			return feature;
		}

		Operation castOperation = savedModel.getOperation(placeholder.getName());
		Output castOutput = castOperation.output(0);

		return feature.toContinuousFeature(TypeUtil.getDataType(castOutput));
	}

	/**
	 * Creates one binary feature per category, all backed by the placeholder's categorical data field.
	 *
	 * @return the features, in the iteration order of {@code categories}
	 */
	public List createBinaryFeatures(SavedModel savedModel, NodeDef placeholder, List categories){
		DataField dataField = ensureCategoricalDataField(savedModel, placeholder, categories);

		List result = new ArrayList<>();

		for(String category : categories){
			result.add(new BinaryFeature(this, dataField, category));
		}

		return result;
	}
}
-------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/TensorUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see .
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.nio.ByteBuffer; 22 | import java.nio.DoubleBuffer; 23 | import java.nio.FloatBuffer; 24 | import java.nio.IntBuffer; 25 | import java.nio.LongBuffer; 26 | import java.util.Arrays; 27 | import java.util.List; 28 | 29 | import com.google.common.primitives.Booleans; 30 | import com.google.common.primitives.Doubles; 31 | import com.google.common.primitives.Floats; 32 | import com.google.common.primitives.Ints; 33 | import com.google.common.primitives.Longs; 34 | import org.tensorflow.DataType; 35 | import org.tensorflow.Tensor; 36 |

/**
 * Static helpers for copying TensorFlow {@link Tensor} contents into Java arrays and lists.
 */
public class TensorUtil {

	private TensorUtil(){
	}

	/**
	 * Reads all elements of a tensor into a {@link List}.
	 *
	 * FLOAT, DOUBLE, INT32, INT64, STRING and BOOL tensors are supported. The values are first
	 * copied into an array by the matching {@code toXxxArray} method, and that array is then
	 * wrapped as a list.
	 *
	 * @throws IllegalArgumentException if the tensor has any other data type
	 */
	static
	public List getValues(Tensor tensor){
		DataType dataType = tensor.dataType();

		switch(dataType){
			case FLOAT:
				return Floats.asList(TensorUtil.toFloatArray(tensor));
			case DOUBLE:
				return Doubles.asList(TensorUtil.toDoubleArray(tensor));
			case INT32:
				return Ints.asList(TensorUtil.toIntArray(tensor));
			case INT64:
				return Longs.asList(TensorUtil.toLongArray(tensor));
			case STRING:
				return Arrays.asList(TensorUtil.toStringArray(tensor));
			case BOOL:
				return Booleans.asList(TensorUtil.toBooleanArray(tensor));
			default:
				throw new IllegalArgumentException("Unsupported data type " + dataType);
		}
	}

	/**
	 * Reads a single float value from a tensor.
	 *
	 * First tries {@link Tensor#floatValue()}. If that throws, the tensor is copied into a float
	 * array instead, which must contain exactly one element.
	 *
	 * @throws IllegalArgumentException if the fallback array does not have exactly one element
	 */
	static
	public float toFloatScalar(Tensor tensor){

		try {
			return tensor.floatValue();
		} catch(Exception e){
			float[] values = toFloatArray(tensor);

			if(values.length != 1){
				throw new IllegalArgumentException("Expected 1-element array, got " + Arrays.toString(values));
			}

			return values[0];
		}
	}

	/**
	 * Copies the tensor's elements into a new float array of length {@code numElements()}.
	 */
	static
	public float[] toFloatArray(Tensor tensor){
		FloatBuffer floatBuffer = FloatBuffer.allocate(tensor.numElements());

		tensor.writeTo(floatBuffer);

		return floatBuffer.array();
	}

	/**
	 * Copies the tensor's elements into a new double array of length {@code numElements()}.
	 */
	static
	public double[] toDoubleArray(Tensor tensor){
		DoubleBuffer doubleBuffer = DoubleBuffer.allocate(tensor.numElements());

		tensor.writeTo(doubleBuffer);

		return doubleBuffer.array();
	}

	/**
	 * Copies the tensor's elements into a new int array of length {@code numElements()}.
	 */
	static
	public int[] toIntArray(Tensor tensor){
		IntBuffer intBuffer = IntBuffer.allocate(tensor.numElements());

		tensor.writeTo(intBuffer);

		return intBuffer.array();
	}

	/**
	 * Copies the tensor's elements into a new long array of length {@code numElements()}.
	 */
	static
	public long[] toLongArray(Tensor tensor){
		LongBuffer longBuffer = LongBuffer.allocate(tensor.numElements());

		tensor.writeTo(longBuffer);

		return longBuffer.array();
	}

	/**
	 * Decodes a STRING tensor into a Java array.
	 *
	 * The raw buffer is expected to follow the TensorFlow C API TF_STRING layout:
	 * {@code numElements()} 8-byte offsets, followed by the element payloads, each
	 * prefixed by its byte length encoded as a base-128 varint. The offsets are skipped,
	 * and the payloads are read sequentially, i.e. assumed to be stored in element order.
	 * Payload bytes are decoded as UTF-8.
	 *
	 * @throws IllegalArgumentException if a length prefix is malformed or does not fit into an int
	 */
	static
	public String[] toStringArray(Tensor tensor){
		ByteBuffer byteBuffer = ByteBuffer.allocate(tensor.numBytes());

		tensor.writeTo(byteBuffer);

		// Skip the offset table (8 bytes per element) that precedes the payloads
		byteBuffer.position(tensor.numElements() * 8);

		String[] result = new String[tensor.numElements()];

		for(int i = 0; i < result.length; i++){
			// The length is a varint: a single byte only for lengths below 128
			int length = readVarint(byteBuffer);

			byte[] buffer = new byte[length];

			byteBuffer.get(buffer);

			// Decode explicitly as UTF-8 instead of the platform default charset
			result[i] = new String(buffer, java.nio.charset.StandardCharsets.UTF_8);
		}

		return result;
	}

	/**
	 * Reads an unsigned base-128 varint starting at the buffer's current position, advancing the position past it.
	 *
	 * Each byte contributes its low 7 bits, least significant group first. A set high bit means another byte follows.
	 *
	 * @throws IllegalArgumentException if the varint is longer than 5 bytes or its value exceeds {@link Integer#MAX_VALUE}
	 */
	static
	private int readVarint(ByteBuffer byteBuffer){
		long result = 0L;

		for(int shift = 0; shift < 35; shift += 7){
			int b = (byteBuffer.get() & 0xFF);

			result |= ((long)(b & 0x7F)) << shift;

			if((b & 0x80) == 0){

				if(result > Integer.MAX_VALUE){
					throw new IllegalArgumentException("String length " + result + " does not fit into an int");
				}

				return (int)result;
			}
		}

		throw new IllegalArgumentException("Malformed varint-encoded string length");
	}

	/**
	 * Decodes a BOOL tensor, which is read as one byte per element, where any non-zero byte means {@code true}.
	 */
	static
	public boolean[] toBooleanArray(Tensor tensor){
		ByteBuffer byteBuffer = ByteBuffer.allocate(tensor.numElements());

		tensor.writeTo(byteBuffer);

		boolean[] result = new boolean[tensor.numElements()];

		for(int i = 0; i < result.length; i++){
			result[i] = (byteBuffer.get(i) != 0);
		}

		return result;
	}
}
-------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/Trail.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU
Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.util.ArrayList; 22 | import java.util.Deque; 23 | 24 | import org.tensorflow.framework.NodeDef; 25 |

/**
 * A list of nodes consisting of a copy of the given parent nodes followed by one final node.
 */
class Trail extends ArrayList {

	/**
	 * @param parentNodeDefs nodes copied into this list first, in the deque's iteration order
	 * @param nodeDef node appended after the parents
	 */
	Trail(Deque parentNodeDefs, NodeDef nodeDef){
		super(parentNodeDefs);

		add(nodeDef);
	}

	/**
	 * @return the last element of this list (the node appended by the constructor, unless the list was modified afterwards)
	 */
	public NodeDef getNodeDef(){
		int lastIndex = size() - 1;

		return get(lastIndex);
	}
}
-------------------------------------------------------------------------------- /src/main/java/org/jpmml/tensorflow/TypeUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see .
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import org.dmg.pmml.DataType; 22 | import org.dmg.pmml.OpType; 23 | import org.tensorflow.Output; 24 |

/**
 * Maps TensorFlow output data types onto PMML operational and data types.
 */
public class TypeUtil {

	private TypeUtil(){
	}

	/**
	 * Derives the PMML operational type of a TensorFlow output.
	 *
	 * @return CONTINUOUS for FLOAT, DOUBLE, INT32 and INT64; CATEGORICAL for STRING and BOOL
	 * @throws IllegalArgumentException for any other TensorFlow data type
	 */
	static
	public OpType getOpType(Output output){
		org.tensorflow.DataType tfType = output.dataType();

		OpType opType;

		switch(tfType){
			case STRING:
			case BOOL:
				opType = OpType.CATEGORICAL;
				break;
			case FLOAT:
			case DOUBLE:
			case INT32:
			case INT64:
				opType = OpType.CONTINUOUS;
				break;
			default:
				throw new IllegalArgumentException();
		}

		return opType;
	}

	/**
	 * Derives the PMML data type of a TensorFlow output.
	 *
	 * @return FLOAT, DOUBLE, STRING and BOOLEAN for the like-named TensorFlow types; INTEGER for both INT32 and INT64
	 * @throws IllegalArgumentException for any other TensorFlow data type
	 */
	static
	public DataType getDataType(Output output){
		org.tensorflow.DataType tfType = output.dataType();

		DataType pmmlType;

		switch(tfType){
			case INT32:
			case INT64:
				pmmlType = DataType.INTEGER;
				break;
			case FLOAT:
				pmmlType = DataType.FLOAT;
				break;
			case DOUBLE:
				pmmlType = DataType.DOUBLE;
				break;
			case STRING:
				pmmlType = DataType.STRING;
				break;
			case BOOL:
				pmmlType = DataType.BOOLEAN;
				break;
			default:
				throw new IllegalArgumentException();
		}

		return pmmlType;
	}
}
-------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/attr_value.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "AttrValueProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorflow/core/framework/tensor.proto"; 10 | import "tensorflow/core/framework/tensor_shape.proto"; 11 | import "tensorflow/core/framework/types.proto"; 12 | 13 | // Protocol buffer representing the value for an attr used to configure an Op. 14 | // Comment indicates the corresponding attr type. Only the field matching the 15 | // attr type may be filled.
// NOTE(review): AttrValue is a tagged union - at most one member of the `value` oneof is set.
// List-typed attrs use the `list` member (field 1). Its nested ListValue repeats the scalar members
// with the same field numbers (s=2 ... tensor=8) and adds `func` as field 9; in the oneof itself,
// field 9 is `placeholder` and `func` is field 10.
16 | message AttrValue { 17 | // LINT.IfChange 18 | message ListValue { 19 | repeated bytes s = 2; // "list(string)" 20 | repeated int64 i = 3 [packed = true]; // "list(int)" 21 | repeated float f = 4 [packed = true]; // "list(float)" 22 | repeated bool b = 5 [packed = true]; // "list(bool)" 23 | repeated DataType type = 6 [packed = true]; // "list(type)" 24 | repeated TensorShapeProto shape = 7; // "list(shape)" 25 | repeated TensorProto tensor = 8; // "list(tensor)" 26 | repeated NameAttrList func = 9; // "list(attr)" 27 | } 28 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) 29 | 30 | oneof value { 31 | bytes s = 2; // "string" 32 | int64 i = 3; // "int" 33 | float f = 4; // "float" 34 | bool b = 5; // "bool" 35 | DataType type = 6; // "type" 36 | TensorShapeProto shape = 7; // "shape" 37 | TensorProto tensor = 8; // "tensor" 38 | ListValue list = 1; // any "list(...)" 39 | 40 | // "func" represents a function. func.name is a function's name or 41 | // a primitive op's name. func.attr.first is the name of an attr 42 | // defined for that function. func.attr.second is the value for 43 | // that attr in the instantiation. 44 | NameAttrList func = 10; 45 | 46 | // This is a placeholder only used in nodes defined inside a 47 | // function. It indicates the attr value will be supplied when 48 | // the function is instantiated. For example, let us suppose a 49 | // node "N" in function "FN". "N" has an attr "A" with value 50 | // placeholder = "foo". When FN is instantiated with attr "foo" 51 | // set to "bar", the instantiated node N's attr A will have been 52 | // given the value "bar". 53 | string placeholder = 9; 54 | } 55 | } 56 | 57 | // A list of attr names and their values. The whole list is attached 58 | // with a string name. E.g., MatMul[T=float].
// NOTE(review): in this copy the `map` fields below are missing their key/value type arguments:
// `map attr = 2;` in NameAttrList, and `map attr = 5;` / `map ret = 4;` in FunctionDef (function.proto).
// As written, these are not valid proto3 map declarations. Restore them from the upstream TensorFlow
// sources before compiling (presumably map<string, AttrValue> for the attr fields and
// map<string, string> for ret - confirm).
59 | message NameAttrList { 60 | string name = 1; 61 | map attr = 2; 62 | } 63 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/function.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "FunctionProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorflow/core/framework/attr_value.proto"; 10 | import "tensorflow/core/framework/node_def.proto"; 11 | import "tensorflow/core/framework/op_def.proto"; 12 | 13 | // A library is a set of named functions. 14 | message FunctionDefLibrary { 15 | repeated FunctionDef function = 1; 16 | repeated GradientDef gradient = 2; 17 | } 18 | 19 | // A function can be instantiated when the runtime can bind every attr 20 | // with a value. When a GraphDef has a call to a function, it must 21 | // have binding for every attr defined in the signature. 22 | // 23 | // TODO(zhifengc): 24 | // * device spec, etc. 25 | message FunctionDef { 26 | // The definition of the function's name, arguments, return values, 27 | // attrs etc. 28 | OpDef signature = 1; 29 | 30 | // Attributes specific to this function definition. 31 | map attr = 5; 32 | 33 | // NOTE: field id 2 deleted on Jan 11, 2016, GraphDef version 21. 34 | 35 | // In both of the following fields, there is the need to specify an 36 | // output that is used as either the input to another node (in 37 | // `node_def`) or as a return value of the function (in `ret`). 38 | // Unlike the NodeDefs in GraphDef, we need to be able to specify a 39 | // list in some cases (instead of just single outputs). Also, we 40 | // need to be able to deal with lists of unknown length (so the 41 | // output index may not be known at function definition time).
So 42 | // we use the following format instead: 43 | // * "fun_in" where "fun_in" is the name of a function input arg in 44 | // the `signature` field above. This represents that input, whether 45 | // it is a single tensor or a list. 46 | // * "fun_in:0" gives the first element of a function input arg (a 47 | // non-list input is considered a list of length 1 for these 48 | // purposes). 49 | // * "node:out" where "node" is the name of a node in `node_def` and 50 | // "out" is the name one of its op's output arguments (the name 51 | // comes from the OpDef of the node's op). This represents that 52 | // node's output, whether it is a single tensor or a list. 53 | // Note: We enforce that an op's output arguments are never 54 | // renamed in the backwards-compatibility test. 55 | // * "node:out:0" gives the first element of a node output arg (a 56 | // non-list output is considered a list of length 1 for these 57 | // purposes). 58 | // 59 | // NOT CURRENTLY SUPPORTED (but may be in the future): 60 | // * "node:out:-1" gives last element in a node output list 61 | // * "node:out:1:" gives a list with all but the first element in a 62 | // node output list 63 | // * "node:out::-1" gives a list with all but the last element in a 64 | // node output list 65 | 66 | // The body of the function. Unlike the NodeDefs in a GraphDef, attrs 67 | // may have values of type `placeholder` and the `input` field uses 68 | // the "output" format above. 69 | 70 | // By convention, "op" in node_def is resolved by consulting with a 71 | // user-defined library first. If not resolved, "func" is assumed to 72 | // be a builtin op. 73 | repeated NodeDef node_def = 3; 74 | 75 | // A mapping from the output arg names from `signature` to the 76 | // outputs from `node_def` that should be returned by the function. 77 | map ret = 4; 78 | } 79 | 80 | // GradientDef defines the gradient function of a function defined in 81 | // a function library.
82 | // 83 | // A gradient function g (specified by gradient_func) for a function f 84 | // (specified by function_name) must follow the following: 85 | // 86 | // The function 'f' must be a numerical function which takes N inputs 87 | // and produces M outputs. Its gradient function 'g', which is a 88 | // function taking N + M inputs and produces N outputs. 89 | // 90 | // I.e. if we have 91 | // (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), 92 | // then, g is 93 | // (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, 94 | // dL/dy1, dL/dy2, ..., dL/dy_M), 95 | // where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the 96 | // loss function). dL/dx_i is the partial derivative of L with respect 97 | // to x_i. 98 | message GradientDef { 99 | string function_name = 1; // The function name. 100 | string gradient_func = 2; // The gradient function's name. 101 | } 102 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "GraphProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorflow/core/framework/node_def.proto"; 10 | import "tensorflow/core/framework/function.proto"; 11 | import "tensorflow/core/framework/versions.proto"; 12 | 13 | // Represents the graph of operations 14 | message GraphDef { 15 | repeated NodeDef node = 1; 16 | 17 | // Compatibility versions of the graph. See core/public/version.h for version 18 | // history. The GraphDef version is distinct from the TensorFlow version, and 19 | // each release of TensorFlow will support a range of GraphDef versions. 20 | VersionDef versions = 4; 21 | 22 | // Deprecated single version field; use versions above instead. 
Since all 23 | // GraphDef changes before "versions" was introduced were forward 24 | // compatible, this field is entirely ignored. 25 | int32 version = 3 [deprecated = true]; 26 | 27 | // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. 28 | // 29 | // "library" provides user-defined functions. 30 | // 31 | // Naming: 32 | // * library.function.name are in a flat namespace. 33 | // NOTE: We may need to change it to be hierarchical to support 34 | // different orgs. E.g., 35 | // { "/google/nn", { ... }}, 36 | // { "/google/vision", { ... }} 37 | // { "/org_foo/module_bar", { ... }} 38 | // map named_lib; 39 | // * If node[i].op is the name of one function in "library", 40 | // node[i] is deemed as a function call. Otherwise, node[i].op 41 | // must be a primitive operation supported by the runtime. 42 | // 43 | // 44 | // Function call semantics: 45 | // 46 | // * The callee may start execution as soon as some of its inputs 47 | // are ready. The caller may want to use Tuple() mechanism to 48 | // ensure all inputs are ready in the same time. 49 | // 50 | // * The consumer of return values may start executing as soon as 51 | // the return values the consumer depends on are ready. The 52 | // consumer may want to use Tuple() mechanism to ensure the 53 | // consumer does not start until all return values of the callee 54 | // function are ready. 55 | FunctionDefLibrary library = 2; 56 | }; 57 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/node_def.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "NodeProto"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorflow/core/framework/attr_value.proto"; 10 | 11 | message NodeDef { 12 | // The name given to this operator. 
Used for naming inputs, 13 | // logging, visualization, etc. Unique within a single GraphDef. 14 | // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". 15 | string name = 1; 16 | 17 | // The operation name. There may be custom parameters in attrs. 18 | // Op names starting with an underscore are reserved for internal use. 19 | string op = 2; 20 | 21 | // Each input is "node:src_output" with "node" being a string name and 22 | // "src_output" indicating which output tensor to use from "node". If 23 | // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs 24 | // may optionally be followed by control inputs that have the format 25 | // "^node". 26 | repeated string input = 3; 27 | 28 | // A (possibly partial) specification for the device on which this 29 | // node should be placed. 30 | // The expected syntax for this string is as follows: 31 | // 32 | // DEVICE_SPEC ::= PARTIAL_SPEC 33 | // 34 | // PARTIAL_SPEC ::= ("/" CONSTRAINT) * 35 | // CONSTRAINT ::= ("job:" JOB_NAME) 36 | // | ("replica:" [1-9][0-9]*) 37 | // | ("task:" [1-9][0-9]*) 38 | // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) 39 | // 40 | // Valid values for this string include: 41 | // * "/job:worker/replica:0/task:1/gpu:3" (full specification) 42 | // * "/job:worker/gpu:3" (partial specification) 43 | // * "" (no specification) 44 | // 45 | // If the constraints do not resolve to a single device (or if this 46 | // field is empty or not present), the runtime will attempt to 47 | // choose a device automatically. 48 | string device = 4; 49 | 50 | // Operation-specific graph-construction-time configuration. 51 | // Note that this should include all attrs defined in the 52 | // corresponding OpDef, including those with a value matching 53 | // the default -- this allows the default to change and makes 54 | // NodeDefs easier to interpret on their own. However, if 55 | // an attr with a default is not specified in this list, the 56 | // default will be used. 
57 | // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and 58 | // one of the names from the corresponding OpDef's attr field). 59 | // The values must have a type matching the corresponding OpDef 60 | // attr's type field. 61 | // TODO(josh11b): Add some examples here showing best practices. 62 | map attr = 5; 63 | }; 64 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/op_def.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "OpDefProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorflow/core/framework/attr_value.proto"; 10 | import "tensorflow/core/framework/types.proto"; 11 | 12 | // Defines an operation. A NodeDef in a GraphDef specifies an Op by 13 | // using the "op" field which should match the name of a OpDef. 14 | message OpDef { 15 | // Op names starting with an underscore are reserved for internal use. 16 | // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". 17 | string name = 1; 18 | 19 | // For describing inputs and outputs. 20 | message ArgDef { 21 | // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". 22 | string name = 1; 23 | 24 | // Human readable description. 25 | string description = 2; 26 | 27 | // Describes the type of one or more tensors that are accepted/produced 28 | // by this input/output arg. The only legal combinations are: 29 | // * For a single tensor: either the "type" field is set or the 30 | // "type_attr" field is set to the name of an attr with type "type". 
31 | // * For a sequence of tensors with the same type: the "number_attr" 32 | // field will be set to the name of an attr with type "int", and 33 | // either the "type" or "type_attr" field will be set as for 34 | // single tensors. 35 | // * For a sequence of tensors, the "type_list_attr" field will be set 36 | // to the name of an attr with type "list(type)". 37 | DataType type = 3; 38 | string type_attr = 4; // if specified, attr must have type "type" 39 | string number_attr = 5; // if specified, attr must have type "int" 40 | // If specified, attr must have type "list(type)", and none of 41 | // type, type_attr, and number_attr may be specified. 42 | string type_list_attr = 6; 43 | 44 | // For inputs: if true, the inputs are required to be refs. 45 | // By default, inputs can be either refs or non-refs. 46 | // For outputs: if true, outputs are refs, otherwise they are not. 47 | bool is_ref = 16; 48 | }; 49 | 50 | // Description of the input(s). 51 | repeated ArgDef input_arg = 2; 52 | 53 | // Description of the output(s). 54 | repeated ArgDef output_arg = 3; 55 | 56 | // Description of the graph-construction-time configuration of this 57 | // Op. That is to say, this describes the attr fields that will 58 | // be specified in the NodeDef. 59 | message AttrDef { 60 | // A descriptive name for the argument. May be used, e.g. by the 61 | // Python client, as a keyword argument name, and so should match 62 | // the regexp "[a-z][a-z0-9_]+". 63 | string name = 1; 64 | 65 | // One of the type names from attr_value.proto ("string", "list(string)", 66 | // "int", etc.). 67 | string type = 2; 68 | 69 | // A reasonable default for this attribute if the user does not supply 70 | // a value. If not specified, the user must supply a value. 71 | AttrValue default_value = 3; 72 | 73 | // Human-readable description. 74 | string description = 4; 75 | 76 | // TODO(josh11b): bool is_optional? 
77 | 78 | // --- Constraints --- 79 | // These constraints are only in effect if specified. Default is no 80 | // constraints. 81 | 82 | // For type == "int", this is a minimum value. For "list(___)" 83 | // types, this is the minimum length. 84 | bool has_minimum = 5; 85 | int64 minimum = 6; 86 | 87 | // The set of allowed values. Has type that is the "list" version 88 | // of the "type" field above (uses the "list" field of AttrValue). 89 | // If type == "type" or "list(type)" above, then the "type" field 90 | // of "allowed_values.list" has the set of allowed DataTypes. 91 | // If type == "string" or "list(string)", then the "s" field of 92 | // "allowed_values.list" has the set of allowed strings. 93 | AttrValue allowed_values = 7; 94 | } 95 | repeated AttrDef attr = 4; 96 | 97 | // Optional deprecation based on GraphDef versions. 98 | OpDeprecation deprecation = 8; 99 | 100 | // One-line human-readable description of what the Op does. 101 | string summary = 5; 102 | 103 | // Additional, longer human-readable description of what the Op does. 104 | string description = 6; 105 | 106 | // ------------------------------------------------------------------------- 107 | // Which optimizations this operation can participate in. 108 | 109 | // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) 110 | bool is_commutative = 18; 111 | 112 | // If is_aggregate is true, then this operation accepts N >= 2 113 | // inputs and produces 1 output all of the same type. Should be 114 | // associative and commutative, and produce output with the same 115 | // shape as the input. The optimizer may replace an aggregate op 116 | // taking input from multiple devices with a tree of aggregate ops 117 | // that aggregate locally within each device (and possibly within 118 | // groups of nearby devices) before communicating. 119 | // TODO(josh11b): Implement that optimization. 
120 | bool is_aggregate = 16; // for things like add 121 | 122 | // Other optimizations go here, like 123 | // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. 124 | 125 | // ------------------------------------------------------------------------- 126 | // Optimization constraints. 127 | 128 | // By default Ops may be moved between devices. Stateful ops should 129 | // either not be moved, or should only be moved if that state can also 130 | // be moved (e.g. via some sort of save / restore). 131 | // Stateful ops are guaranteed to never be optimized away by Common 132 | // Subexpression Elimination (CSE). 133 | bool is_stateful = 17; // for things like variables, queue 134 | 135 | // ------------------------------------------------------------------------- 136 | // Non-standard options. 137 | 138 | // By default, all inputs to an Op must be initialized Tensors. Ops 139 | // that may initialize tensors for the first time should set this 140 | // field to true, to allow the Op to take an uninitialized Tensor as 141 | // input. 142 | bool allows_uninitialized_input = 19; // for Assign, etc. 143 | }; 144 | 145 | // Information about version-dependent deprecation of an op 146 | message OpDeprecation { 147 | // First GraphDef version at which the op is disallowed. 148 | int32 version = 1; 149 | 150 | // Explanation of why it was deprecated and what to use instead. 
151 | string explanation = 2; 152 | }; 153 | 154 | // A collection of OpDefs 155 | message OpList { 156 | repeated OpDef op = 1; 157 | }; 158 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/resource_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "ResourceHandleProto"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // Protocol buffer representing a handle to a tensorflow resource. Handles are 10 | // not valid across executions, but can be serialized back and forth from within 11 | // a single run. 12 | message ResourceHandle { 13 | // Unique name for the device containing the resource. 14 | string device = 1; 15 | 16 | // Container in which this resource is placed. 17 | string container = 2; 18 | 19 | // Unique name of this resource. 20 | string name = 3; 21 | 22 | // Hash code for the type of the resource. Is only valid in the same device 23 | // and in the same execution. 24 | uint64 hash_code = 4; 25 | 26 | // For debug-only, the name of the type pointed to by this handle, if 27 | // available. 
28 | string maybe_type_name = 5; 29 | }; 30 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/tensor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorflow/core/framework/resource_handle.proto"; 10 | import "tensorflow/core/framework/tensor_shape.proto"; 11 | import "tensorflow/core/framework/types.proto"; 12 | 13 | // Protocol buffer representing a tensor. 14 | message TensorProto { 15 | DataType dtype = 1; 16 | 17 | // Shape of the tensor. TODO(touts): sort out the 0-rank issues. 18 | TensorShapeProto tensor_shape = 2; 19 | 20 | // Only one of the representations below is set, one of "tensor_contents" and 21 | // the "xxx_val" attributes. We are not using oneof because as oneofs cannot 22 | // contain repeated fields it would require another extra set of messages. 23 | 24 | // Version number. 25 | // 26 | // In version 0, if the "repeated xxx" representations contain only one 27 | // element, that element is repeated to fill the shape. This makes it easy 28 | // to represent a constant Tensor with a single value. 29 | int32 version_number = 3; 30 | 31 | // Serialized raw tensor content from either Tensor::AsProtoTensorContent or 32 | // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation 33 | // can be used for all tensor types. The purpose of this representation is to 34 | // reduce serialization overhead during RPC call by avoiding serialization of 35 | // many repeated small items. 36 | bytes tensor_content = 4; 37 | 38 | // Type specific representations that make it easy to create tensor protos in 39 | // all languages. 
Only the representation corresponding to "dtype" can 40 | // be set. The values hold the flattened representation of the tensor in 41 | // row major order. 42 | 43 | // DT_HALF. Note that since protobuf has no int16 type, we'll have some 44 | // pointless zero padding for each value here. 45 | repeated int32 half_val = 13 [packed = true]; 46 | 47 | // DT_FLOAT. 48 | repeated float float_val = 5 [packed = true]; 49 | 50 | // DT_DOUBLE. 51 | repeated double double_val = 6 [packed = true]; 52 | 53 | // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. 54 | repeated int32 int_val = 7 [packed = true]; 55 | 56 | // DT_STRING 57 | repeated bytes string_val = 8; 58 | 59 | // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real 60 | // and imaginary parts of i-th single precision complex. 61 | repeated float scomplex_val = 9 [packed = true]; 62 | 63 | // DT_INT64 64 | repeated int64 int64_val = 10 [packed = true]; 65 | 66 | // DT_BOOL 67 | repeated bool bool_val = 11 [packed = true]; 68 | 69 | // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real 70 | // and imaginary parts of i-th double precision complex. 71 | repeated double dcomplex_val = 12 [packed = true]; 72 | 73 | // DT_RESOURCE 74 | repeated ResourceHandle resource_handle_val = 14; 75 | }; 76 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/tensor_shape.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer representing the shape of tensors. 2 | 3 | syntax = "proto3"; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorShapeProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | package tensorflow; 10 | 11 | // Dimensions of a tensor. 12 | message TensorShapeProto { 13 | // One dimension of the tensor. 14 | message Dim { 15 | // Size of the tensor in that dimension. 
16 | // This value must be >= -1, but values of -1 are reserved for "unknown" 17 | // shapes (a value of -1 means an "unknown" dimension). Certain wrappers 18 | // that work with TensorShapeProto may fail at runtime when deserializing 19 | // a TensorShapeProto containing a dim value of -1. 20 | int64 size = 1; 21 | 22 | // Optional name of the tensor dimension. 23 | string name = 2; 24 | }; 25 | 26 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40} 27 | // for a 30 x 40 2D tensor. If an entry has size -1, this 28 | // corresponds to a dimension of unknown size. The names are 29 | // optional. 30 | // 31 | // The order of entries in "dim" matters: It indicates the layout of the 32 | // values in the tensor's in-memory representation. 33 | // 34 | // The first entry in "dim" is the outermost dimension used to lay out the 35 | // values; the last entry is the innermost dimension. This matches the 36 | // in-memory layout of RowMajor Eigen tensors. 37 | // 38 | // If "dim.size()" > 0, "unknown_rank" must be false. 39 | repeated Dim dim = 2; 40 | 41 | // If true, the number of dimensions in the shape is unknown. 42 | // 43 | // If true, "dim.size()" must be 0. 44 | bool unknown_rank = 3; 45 | }; 46 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/types.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TypesProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // LINT.IfChange 10 | enum DataType { 11 | // Not a legal value for DataType. Used to indicate a DataType field 12 | // has not been set. 13 | DT_INVALID = 0; 14 | 15 | // Data types that all computation devices are expected to be 16 | // capable of supporting.
17 | DT_FLOAT = 1; 18 | DT_DOUBLE = 2; 19 | DT_INT32 = 3; 20 | DT_UINT8 = 4; 21 | DT_INT16 = 5; 22 | DT_INT8 = 6; 23 | DT_STRING = 7; 24 | DT_COMPLEX64 = 8; // Single-precision complex 25 | DT_INT64 = 9; 26 | DT_BOOL = 10; 27 | DT_QINT8 = 11; // Quantized int8 28 | DT_QUINT8 = 12; // Quantized uint8 29 | DT_QINT32 = 13; // Quantized int32 30 | DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. 31 | DT_QINT16 = 15; // Quantized int16 32 | DT_QUINT16 = 16; // Quantized uint16 33 | DT_UINT16 = 17; 34 | DT_COMPLEX128 = 18; // Double-precision complex 35 | DT_HALF = 19; 36 | DT_RESOURCE = 20; 37 | 38 | // TODO(josh11b): DT_GENERIC_PROTO = ??; 39 | // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? 40 | 41 | // Do not use! These are only for parameters. Every enum above 42 | // should have a corresponding value below (verified by types_test). 43 | DT_FLOAT_REF = 101; 44 | DT_DOUBLE_REF = 102; 45 | DT_INT32_REF = 103; 46 | DT_UINT8_REF = 104; 47 | DT_INT16_REF = 105; 48 | DT_INT8_REF = 106; 49 | DT_STRING_REF = 107; 50 | DT_COMPLEX64_REF = 108; 51 | DT_INT64_REF = 109; 52 | DT_BOOL_REF = 110; 53 | DT_QINT8_REF = 111; 54 | DT_QUINT8_REF = 112; 55 | DT_QINT32_REF = 113; 56 | DT_BFLOAT16_REF = 114; 57 | DT_QINT16_REF = 115; 58 | DT_QUINT16_REF = 116; 59 | DT_UINT16_REF = 117; 60 | DT_COMPLEX128_REF = 118; 61 | DT_HALF_REF = 119; 62 | DT_RESOURCE_REF = 120; 63 | } 64 | // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) 65 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/framework/versions.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "VersionsProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // Version 
information for a piece of serialized data 10 | // 11 | // There are different types of versions for each type of data 12 | // (GraphDef, etc.), but they all have the same common shape 13 | // described here. 14 | // 15 | // Each consumer has "consumer" and "min_producer" versions (specified 16 | // elsewhere). A consumer is allowed to consume this data if 17 | // 18 | // producer >= min_producer 19 | // consumer >= min_consumer 20 | // consumer not in bad_consumers 21 | // 22 | message VersionDef { 23 | // The version of the code that produced this data. 24 | int32 producer = 1; 25 | 26 | // Any consumer below this version is not allowed to consume this data. 27 | int32 min_consumer = 2; 28 | 29 | // Specific consumer versions which are disallowed (e.g. due to bugs). 30 | repeated int32 bad_consumers = 3; 31 | }; 32 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/protobuf/meta_graph.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "MetaGraphProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "google/protobuf/any.proto"; 10 | 11 | import "tensorflow/core/framework/graph.proto"; 12 | import "tensorflow/core/framework/op_def.proto"; 13 | import "tensorflow/core/framework/tensor_shape.proto"; 14 | import "tensorflow/core/framework/types.proto"; 15 | import "tensorflow/core/protobuf/saver.proto"; 16 | 17 | // NOTE: This protocol buffer is evolving, and will go through revisions in the 18 | // coming months. 19 | // 20 | // Protocol buffer containing the following which are necessary to restart 21 | // training, run inference. It can be used to serialize/de-serialize memory 22 | // objects necessary for running computation in a graph when crossing the 23 | // process boundary. 
It can be used for long term storage of graphs, 24 | // cross-language execution of graphs, etc. 25 | // MetaInfoDef 26 | // GraphDef 27 | // SaverDef 28 | // CollectionDef 29 | // TensorInfo 30 | // SignatureDef 31 | message MetaGraphDef { 32 | // Meta information regarding the graph to be exported. To be used by users 33 | // of this protocol buffer to encode information regarding their meta graph. 34 | message MetaInfoDef { 35 | // User specified Version string. Can be the name of the model and revision, 36 | // steps this model has been trained to, etc. 37 | string meta_graph_version = 1; 38 | 39 | // A copy of the OpDefs used by the producer of this graph_def. 40 | // Descriptions and Ops not used in graph_def are stripped out. 41 | OpList stripped_op_list = 2; 42 | 43 | // A serialized protobuf. Can be the time this meta graph is created, or 44 | // modified, or name of the model. 45 | google.protobuf.Any any_info = 3; 46 | 47 | // User supplied tag(s) on the meta_graph and included graph_def. 48 | // 49 | // MetaGraphDefs should be tagged with their capabilities or use-cases. 50 | // Examples: "train", "serve", "gpu", "tpu", etc. 51 | // These tags enable loaders to access the MetaGraph(s) appropriate for a 52 | // specific use-case or runtime environment. 53 | repeated string tags = 4; 54 | 55 | // The __version__ string of the tensorflow build used to write this graph. 56 | // This will be populated by the framework, which will overwrite any user 57 | // supplied value. 58 | string tensorflow_version = 5; 59 | 60 | // The __git_version__ string of the tensorflow build used to write this 61 | // graph. This will be populated by the framework, which will overwrite any 62 | // user supplied value. 63 | string tensorflow_git_version = 6; 64 | } 65 | MetaInfoDef meta_info_def = 1; 66 | 67 | // GraphDef. 68 | GraphDef graph_def = 2; 69 | 70 | // SaverDef. 71 | SaverDef saver_def = 3; 72 | 73 | // collection_def: Map from collection name to collections. 
74 | // See CollectionDef section for details. 75 | map collection_def = 4; 76 | 77 | // signature_def: Map from user supplied key for a signature to a single 78 | // SignatureDef. 79 | map signature_def = 5; 80 | 81 | // Asset file def to be used with the defined graph. 82 | repeated AssetFileDef asset_file_def = 6; 83 | } 84 | 85 | // CollectionDef should cover most collections. 86 | // To add a user-defined collection, do one of the following: 87 | // 1. For simple data types, such as string, int, float: 88 | // tf.add_to_collection("your_collection_name", your_simple_value) 89 | // strings will be stored as bytes_list. 90 | // 91 | // 2. For Protobuf types, there are three ways to add them: 92 | // 1) tf.add_to_collection("your_collection_name", 93 | // your_proto.SerializeToString()) 94 | // 95 | // collection_def { 96 | // key: "user_defined_bytes_collection" 97 | // value { 98 | // bytes_list { 99 | // value: "queue_name: \"test_queue\"\n" 100 | // } 101 | // } 102 | // } 103 | // 104 | // or 105 | // 106 | // 2) tf.add_to_collection("your_collection_name", str(your_proto)) 107 | // 108 | // collection_def { 109 | // key: "user_defined_string_collection" 110 | // value { 111 | // bytes_list { 112 | // value: "\n\ntest_queue" 113 | // } 114 | // } 115 | // } 116 | // 117 | // or 118 | // 119 | // 3) any_buf = any_pb2.Any() 120 | // tf.add_to_collection("your_collection_name", 121 | // any_buf.Pack(your_proto)) 122 | // 123 | // collection_def { 124 | // key: "user_defined_any_collection" 125 | // value { 126 | // any_list { 127 | // value { 128 | // type_url: "type.googleapis.com/tensorflow.QueueRunnerDef" 129 | // value: "\n\ntest_queue" 130 | // } 131 | // } 132 | // } 133 | // } 134 | // 135 | // 3. 
For Python objects, implement to_proto() and from_proto(), and register 136 | // them in the following manner: 137 | // ops.register_proto_function("your_collection_name", 138 | // proto_type, 139 | // to_proto=YourPythonObject.to_proto, 140 | // from_proto=YourPythonObject.from_proto) 141 | // These functions will be invoked to serialize and de-serialize the 142 | // collection. For example, 143 | // ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES, 144 | // proto_type=variable_pb2.VariableDef, 145 | // to_proto=Variable.to_proto, 146 | // from_proto=Variable.from_proto) 147 | message CollectionDef { 148 | // NodeList is used for collecting nodes in graph. For example 149 | // collection_def { 150 | // key: "summaries" 151 | // value { 152 | // node_list { 153 | // value: "input_producer/ScalarSummary:0" 154 | // value: "shuffle_batch/ScalarSummary:0" 155 | // value: "ImageSummary:0" 156 | // } 157 | // } 158 | message NodeList { 159 | repeated string value = 1; 160 | } 161 | 162 | // BytesList is used for collecting strings and serialized protobufs. For 163 | // example: 164 | // collection_def { 165 | // key: "trainable_variables" 166 | // value { 167 | // bytes_list { 168 | // value: "\n\017conv1/weights:0\022\024conv1/weights/Assign 169 | // \032\024conv1/weights/read:0" 170 | // value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032 171 | // \023conv1/biases/read:0" 172 | // } 173 | // } 174 | // } 175 | message BytesList { 176 | repeated bytes value = 1; 177 | } 178 | 179 | // Int64List is used for collecting int, int64 and long values. 180 | message Int64List { 181 | repeated int64 value = 1 [packed = true]; 182 | } 183 | 184 | // FloatList is used for collecting float values. 185 | message FloatList { 186 | repeated float value = 1 [packed = true]; 187 | } 188 | 189 | // AnyList is used for collecting Any protos. 
190 | message AnyList { 191 | repeated google.protobuf.Any value = 1; 192 | } 193 | 194 | oneof kind { 195 | NodeList node_list = 1; 196 | BytesList bytes_list = 2; 197 | Int64List int64_list = 3; 198 | FloatList float_list = 4; 199 | AnyList any_list = 5; 200 | } 201 | } 202 | 203 | // Information about a Tensor necessary for feeding or retrieval. 204 | message TensorInfo { 205 | string name = 1; 206 | DataType dtype = 2; 207 | TensorShapeProto tensor_shape = 3; 208 | } 209 | 210 | // SignatureDef defines the signature of a computation supported by a TensorFlow 211 | // graph. 212 | // 213 | // For example, a model with two loss computations, sharing a single input, 214 | // might have the following signature_def map. 215 | // 216 | // Note that across the two SignatureDefs "loss_A" and "loss_B", the input key, 217 | // output key, and method_name are identical, and will be used by system(s) that 218 | // implement or rely upon this particular loss method. The output tensor names 219 | // differ, demonstrating how different outputs can exist for the same method. 220 | // 221 | // signature_def { 222 | // key: "loss_A" 223 | // value { 224 | // inputs { 225 | // key: "input" 226 | // value { 227 | // name: "input:0" 228 | // dtype: DT_STRING 229 | // tensor_shape: ... 230 | // } 231 | // } 232 | // outputs { 233 | // key: "loss_output" 234 | // value { 235 | // name: "loss_output_A:0" 236 | // dtype: DT_FLOAT 237 | // tensor_shape: ... 238 | // } 239 | // } 240 | // } 241 | // ... 242 | // method_name: "some/package/compute_loss" 243 | // } 244 | // signature_def { 245 | // key: "loss_B" 246 | // value { 247 | // inputs { 248 | // key: "input" 249 | // value { 250 | // name: "input:0" 251 | // dtype: DT_STRING 252 | // tensor_shape: ... 253 | // } 254 | // } 255 | // outputs { 256 | // key: "loss_output" 257 | // value { 258 | // name: "loss_output_B:0" 259 | // dtype: DT_FLOAT 260 | // tensor_shape: ... 261 | // } 262 | // } 263 | // } 264 | // ... 
265 | // method_name: "some/package/compute_loss" 266 | // } 267 | message SignatureDef { 268 | // Named input parameters. 269 | map inputs = 1; 270 | // Named output parameters. 271 | map outputs = 2; 272 | // Extensible method_name information enabling third-party users to mark a 273 | // SignatureDef as supporting a particular method. This enables producers and 274 | // consumers of SignatureDefs, e.g. a model definition library and a serving 275 | // library to have a clear hand-off regarding the semantics of a computation. 276 | // 277 | // Note that multiple SignatureDefs in a single MetaGraphDef may have the same 278 | // method_name. This is commonly used to support multi-headed computation, 279 | // where a single graph computation may return multiple results. 280 | string method_name = 3; 281 | } 282 | 283 | // An asset file def for a single file or a set of sharded files with the same 284 | // name. 285 | message AssetFileDef { 286 | // The tensor to bind the asset filename to. 287 | TensorInfo tensor_info = 1; 288 | // The filename within an assets directory. Note: does not include the path 289 | // prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename 290 | // would be "vocab.txt". 291 | string filename = 2; 292 | } 293 | -------------------------------------------------------------------------------- /src/main/proto/tensorflow/core/protobuf/saver.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "SaverProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.util"; 8 | 9 | // Protocol buffer representing the configuration of a Saver. 10 | message SaverDef { 11 | // The name of the tensor in which to specify the filename when saving or 12 | // restoring a model checkpoint. 
13 | string filename_tensor_name = 1; 14 | 15 | // The operation to run when saving a model checkpoint. 16 | string save_tensor_name = 2; 17 | 18 | // The operation to run when restoring a model checkpoint. 19 | string restore_op_name = 3; 20 | 21 | // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. 22 | int32 max_to_keep = 4; 23 | 24 | // Shard the save files, one per device that has Variable nodes. 25 | bool sharded = 5; 26 | 27 | // How often to keep an additional checkpoint. If not specified, only the last 28 | // "max_to_keep" checkpoints are kept; if specified, in addition to keeping 29 | // the last "max_to_keep" checkpoints, an additional checkpoint will be kept 30 | // for every n hours of training. 31 | float keep_checkpoint_every_n_hours = 6; 32 | 33 | // A version number that identifies a different on-disk checkpoint format. 34 | // Usually, each subclass of BaseSaverBuilder works with a particular 35 | // version/format. However, it is possible that the same builder may be 36 | // upgraded to support a newer checkpoint format in the future. 37 | enum CheckpointFormatVersion { 38 | // Internal legacy format. 39 | LEGACY = 0; 40 | // Current format: tf.Saver() which works with tensorflow::table::Table. 41 | V1 = 1; 42 | // Experimental format under development. 43 | V2 = 2; 44 | } 45 | CheckpointFormatVersion version = 7; 46 | } 47 | -------------------------------------------------------------------------------- /src/test/java/org/jpmml/tensorflow/DNNClassifierTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import org.jpmml.evaluator.PMMLEquivalence; 22 | import org.junit.Test; 23 | 24 | public class DNNClassifierTest extends EstimatorTest { 25 | 26 | public DNNClassifierTest(){ 27 | super(new PMMLEquivalence(6e-3, 1e-6)); 28 | } 29 | 30 | @Test 31 | public void evaluateAudit() throws Exception { 32 | evaluate("DNNClassification", "Audit"); 33 | } 34 | 35 | @Test 36 | public void evaluateIris() throws Exception { 37 | evaluate("DNNClassification", "Iris"); 38 | } 39 | } -------------------------------------------------------------------------------- /src/test/java/org/jpmml/tensorflow/DNNRegressorTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import org.jpmml.evaluator.PMMLEquivalence; 22 | import org.junit.Test; 23 | 24 | public class DNNRegressorTest extends EstimatorTest { 25 | 26 | public DNNRegressorTest(){ 27 | super(new PMMLEquivalence(1e-5, 1e-6)); 28 | } 29 | 30 | @Test 31 | public void evaluateAuto() throws Exception { 32 | evaluate("DNNRegression", "Auto"); 33 | } 34 | } -------------------------------------------------------------------------------- /src/test/java/org/jpmml/tensorflow/EstimatorTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import java.io.File; 22 | import java.io.IOException; 23 | import java.net.URISyntaxException; 24 | import java.net.URL; 25 | import java.nio.file.NoSuchFileException; 26 | import java.nio.file.Paths; 27 | 28 | import com.google.common.base.Equivalence; 29 | import com.google.common.base.Predicate; 30 | import org.dmg.pmml.FieldName; 31 | import org.dmg.pmml.PMML; 32 | import org.jpmml.evaluator.ArchiveBatch; 33 | import org.jpmml.evaluator.IntegrationTest; 34 | import org.jpmml.evaluator.IntegrationTestBatch; 35 | import org.tensorflow.SavedModelBundle; 36 | 37 | abstract 38 | public class EstimatorTest extends IntegrationTest { 39 | 40 | public EstimatorTest(Equivalence equivalence){ 41 | super(equivalence); 42 | } 43 | 44 | @Override 45 | protected ArchiveBatch createBatch(String name, String dataset, Predicate predicate){ 46 | ArchiveBatch result = new IntegrationTestBatch(name, dataset, predicate){ 47 | 48 | @Override 49 | public IntegrationTest getIntegrationTest(){ 50 | return EstimatorTest.this; 51 | } 52 | 53 | @Override 54 | public PMML getPMML() throws Exception { 55 | File savedModelDir = getSavedModelDir(); 56 | 57 | SavedModelBundle bundle = SavedModelBundle.load(savedModelDir.getAbsolutePath(), "serve"); 58 | 59 | try(SavedModel savedModel = new SavedModel(bundle)){ 60 | EstimatorFactory estimatorFactory = EstimatorFactory.newInstance(); 61 | 62 | Estimator estimator = estimatorFactory.newEstimator(savedModel); 63 | 64 | PMML pmml = estimator.encodePMML(); 65 | 66 | ensureValidity(pmml); 67 | 68 | return pmml; 69 | } 70 | } 71 | 72 | private File getSavedModelDir() throws IOException, URISyntaxException { 73 | ClassLoader classLoader = (EstimatorTest.this.getClass()).getClassLoader(); 74 | 75 | String protoPath = ("savedmodel/" + getName() + getDataset() + "/saved_model.pbtxt"); 76 | 77 | URL protoResource = classLoader.getResource(protoPath); 78 | if(protoResource == null){ 79 | throw new 
NoSuchFileException(protoPath); 80 | } 81 | 82 | File protoFile = (Paths.get(protoResource.toURI())).toFile(); 83 | 84 | return protoFile.getParentFile(); 85 | } 86 | }; 87 | 88 | return result; 89 | } 90 | } -------------------------------------------------------------------------------- /src/test/java/org/jpmml/tensorflow/LinearClassifierTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import org.jpmml.evaluator.PMMLEquivalence; 22 | import org.junit.Test; 23 | 24 | public class LinearClassifierTest extends EstimatorTest { 25 | 26 | public LinearClassifierTest(){ 27 | super(new PMMLEquivalence(1e-6, 1e-6)); 28 | } 29 | 30 | @Test 31 | public void evaluateAudit() throws Exception { 32 | evaluate("LinearClassification", "Audit"); 33 | } 34 | 35 | @Test 36 | public void evaluateIris() throws Exception { 37 | evaluate("LinearClassification", "Iris"); 38 | } 39 | } -------------------------------------------------------------------------------- /src/test/java/org/jpmml/tensorflow/LinearRegressorTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-TensorFlow 5 | * 6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-TensorFlow is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-TensorFlow. If not, see . 
18 | */ 19 | package org.jpmml.tensorflow; 20 | 21 | import org.jpmml.evaluator.FloatEquivalence; 22 | import org.junit.Test; 23 | 24 | public class LinearRegressorTest extends EstimatorTest { 25 | 26 | public LinearRegressorTest(){ 27 | super(new FloatEquivalence(0)); 28 | } 29 | 30 | @Test 31 | public void evaluateAuto() throws Exception { 32 | evaluate("LinearRegression", "Auto"); 33 | } 34 | } -------------------------------------------------------------------------------- /src/test/resources/csv/Auto.csv: -------------------------------------------------------------------------------- 1 | cylinders,displacement,horsepower,weight,acceleration,model_year,origin,mpg 2 | 8,307,130,3504,12,70,1,18 3 | 8,350,165,3693,11.5,70,1,15 4 | 8,318,150,3436,11,70,1,18 5 | 8,304,150,3433,12,70,1,16 6 | 8,302,140,3449,10.5,70,1,17 7 | 8,429,198,4341,10,70,1,15 8 | 8,454,220,4354,9,70,1,14 9 | 8,440,215,4312,8.5,70,1,14 10 | 8,455,225,4425,10,70,1,14 11 | 8,390,190,3850,8.5,70,1,15 12 | 8,383,170,3563,10,70,1,15 13 | 8,340,160,3609,8,70,1,14 14 | 8,400,150,3761,9.5,70,1,15 15 | 8,455,225,3086,10,70,1,14 16 | 4,113,95,2372,15,70,3,24 17 | 6,198,95,2833,15.5,70,1,22 18 | 6,199,97,2774,15.5,70,1,18 19 | 6,200,85,2587,16,70,1,21 20 | 4,97,88,2130,14.5,70,3,27 21 | 4,97,46,1835,20.5,70,2,26 22 | 4,110,87,2672,17.5,70,2,25 23 | 4,107,90,2430,14.5,70,2,24 24 | 4,104,95,2375,17.5,70,2,25 25 | 4,121,113,2234,12.5,70,2,26 26 | 6,199,90,2648,15,70,1,21 27 | 8,360,215,4615,14,70,1,10 28 | 8,307,200,4376,15,70,1,10 29 | 8,318,210,4382,13.5,70,1,11 30 | 8,304,193,4732,18.5,70,1,9 31 | 4,97,88,2130,14.5,71,3,27 32 | 4,140,90,2264,15.5,71,1,28 33 | 4,113,95,2228,14,71,3,25 34 | 6,232,100,2634,13,71,1,19 35 | 6,225,105,3439,15.5,71,1,16 36 | 6,250,100,3329,15.5,71,1,17 37 | 6,250,88,3302,15.5,71,1,19 38 | 6,232,100,3288,15.5,71,1,18 39 | 8,350,165,4209,12,71,1,14 40 | 8,400,175,4464,11.5,71,1,14 41 | 8,351,153,4154,13.5,71,1,14 42 | 8,318,150,4096,13,71,1,14 43 | 8,383,180,4955,11.5,71,1,12 44 
| 8,400,170,4746,12,71,1,13 45 | 8,400,175,5140,12,71,1,13 46 | 6,258,110,2962,13.5,71,1,18 47 | 4,140,72,2408,19,71,1,22 48 | 6,250,100,3282,15,71,1,19 49 | 6,250,88,3139,14.5,71,1,18 50 | 4,122,86,2220,14,71,1,23 51 | 4,116,90,2123,14,71,2,28 52 | 4,79,70,2074,19.5,71,2,30 53 | 4,88,76,2065,14.5,71,2,30 54 | 4,71,65,1773,19,71,3,31 55 | 4,72,69,1613,18,71,3,35 56 | 4,97,60,1834,19,71,2,27 57 | 4,91,70,1955,20.5,71,1,26 58 | 4,113,95,2278,15.5,72,3,24 59 | 4,97.5,80,2126,17,72,1,25 60 | 4,97,54,2254,23.5,72,2,23 61 | 4,140,90,2408,19.5,72,1,20 62 | 4,122,86,2226,16.5,72,1,21 63 | 8,350,165,4274,12,72,1,13 64 | 8,400,175,4385,12,72,1,14 65 | 8,318,150,4135,13.5,72,1,15 66 | 8,351,153,4129,13,72,1,14 67 | 8,304,150,3672,11.5,72,1,17 68 | 8,429,208,4633,11,72,1,11 69 | 8,350,155,4502,13.5,72,1,13 70 | 8,350,160,4456,13.5,72,1,12 71 | 8,400,190,4422,12.5,72,1,13 72 | 3,70,97,2330,13.5,72,3,19 73 | 8,304,150,3892,12.5,72,1,15 74 | 8,307,130,4098,14,72,1,13 75 | 8,302,140,4294,16,72,1,13 76 | 8,318,150,4077,14,72,1,14 77 | 4,121,112,2933,14.5,72,2,18 78 | 4,121,76,2511,18,72,2,22 79 | 4,120,87,2979,19.5,72,2,21 80 | 4,96,69,2189,18,72,2,26 81 | 4,122,86,2395,16,72,1,22 82 | 4,97,92,2288,17,72,3,28 83 | 4,120,97,2506,14.5,72,3,23 84 | 4,98,80,2164,15,72,1,28 85 | 4,97,88,2100,16.5,72,3,27 86 | 8,350,175,4100,13,73,1,13 87 | 8,304,150,3672,11.5,73,1,14 88 | 8,350,145,3988,13,73,1,13 89 | 8,302,137,4042,14.5,73,1,14 90 | 8,318,150,3777,12.5,73,1,15 91 | 8,429,198,4952,11.5,73,1,12 92 | 8,400,150,4464,12,73,1,13 93 | 8,351,158,4363,13,73,1,13 94 | 8,318,150,4237,14.5,73,1,14 95 | 8,440,215,4735,11,73,1,13 96 | 8,455,225,4951,11,73,1,12 97 | 8,360,175,3821,11,73,1,13 98 | 6,225,105,3121,16.5,73,1,18 99 | 6,250,100,3278,18,73,1,16 100 | 6,232,100,2945,16,73,1,18 101 | 6,250,88,3021,16.5,73,1,18 102 | 6,198,95,2904,16,73,1,23 103 | 4,97,46,1950,21,73,2,26 104 | 8,400,150,4997,14,73,1,11 105 | 8,400,167,4906,12.5,73,1,12 106 | 8,360,170,4654,13,73,1,13 107 | 
8,350,180,4499,12.5,73,1,12 108 | 6,232,100,2789,15,73,1,18 109 | 4,97,88,2279,19,73,3,20 110 | 4,140,72,2401,19.5,73,1,21 111 | 4,108,94,2379,16.5,73,3,22 112 | 3,70,90,2124,13.5,73,3,18 113 | 4,122,85,2310,18.5,73,1,19 114 | 6,155,107,2472,14,73,1,21 115 | 4,98,90,2265,15.5,73,2,26 116 | 8,350,145,4082,13,73,1,15 117 | 8,400,230,4278,9.5,73,1,16 118 | 4,68,49,1867,19.5,73,2,29 119 | 4,116,75,2158,15.5,73,2,24 120 | 4,114,91,2582,14,73,2,20 121 | 4,121,112,2868,15.5,73,2,19 122 | 8,318,150,3399,11,73,1,15 123 | 4,121,110,2660,14,73,2,24 124 | 6,156,122,2807,13.5,73,3,20 125 | 8,350,180,3664,11,73,1,11 126 | 6,198,95,3102,16.5,74,1,20 127 | 6,232,100,2901,16,74,1,19 128 | 6,250,100,3336,17,74,1,15 129 | 4,79,67,1950,19,74,3,31 130 | 4,122,80,2451,16.5,74,1,26 131 | 4,71,65,1836,21,74,3,32 132 | 4,140,75,2542,17,74,1,25 133 | 6,250,100,3781,17,74,1,16 134 | 6,258,110,3632,18,74,1,16 135 | 6,225,105,3613,16.5,74,1,18 136 | 8,302,140,4141,14,74,1,16 137 | 8,350,150,4699,14.5,74,1,13 138 | 8,318,150,4457,13.5,74,1,14 139 | 8,302,140,4638,16,74,1,14 140 | 8,304,150,4257,15.5,74,1,14 141 | 4,98,83,2219,16.5,74,2,29 142 | 4,79,67,1963,15.5,74,2,26 143 | 4,97,78,2300,14.5,74,2,26 144 | 4,76,52,1649,16.5,74,3,31 145 | 4,83,61,2003,19,74,3,32 146 | 4,90,75,2125,14.5,74,1,28 147 | 4,90,75,2108,15.5,74,2,24 148 | 4,116,75,2246,14,74,2,26 149 | 4,120,97,2489,15,74,3,24 150 | 4,108,93,2391,15.5,74,3,26 151 | 4,79,67,2000,16,74,2,31 152 | 6,225,95,3264,16,75,1,19 153 | 6,250,105,3459,16,75,1,18 154 | 6,250,72,3432,21,75,1,15 155 | 6,250,72,3158,19.5,75,1,15 156 | 8,400,170,4668,11.5,75,1,16 157 | 8,350,145,4440,14,75,1,15 158 | 8,318,150,4498,14.5,75,1,16 159 | 8,351,148,4657,13.5,75,1,14 160 | 6,231,110,3907,21,75,1,17 161 | 6,250,105,3897,18.5,75,1,16 162 | 6,258,110,3730,19,75,1,15 163 | 6,225,95,3785,19,75,1,18 164 | 6,231,110,3039,15,75,1,21 165 | 8,262,110,3221,13.5,75,1,20 166 | 8,302,129,3169,12,75,1,13 167 | 4,97,75,2171,16,75,3,29 168 | 4,140,83,2639,17,75,1,23 169 | 
6,232,100,2914,16,75,1,20 170 | 4,140,78,2592,18.5,75,1,23 171 | 4,134,96,2702,13.5,75,3,24 172 | 4,90,71,2223,16.5,75,2,25 173 | 4,119,97,2545,17,75,3,24 174 | 6,171,97,2984,14.5,75,1,18 175 | 4,90,70,1937,14,75,2,29 176 | 6,232,90,3211,17,75,1,19 177 | 4,115,95,2694,15,75,2,23 178 | 4,120,88,2957,17,75,2,23 179 | 4,121,98,2945,14.5,75,2,22 180 | 4,121,115,2671,13.5,75,2,25 181 | 4,91,53,1795,17.5,75,3,33 182 | 4,107,86,2464,15.5,76,2,28 183 | 4,116,81,2220,16.9,76,2,25 184 | 4,140,92,2572,14.9,76,1,25 185 | 4,98,79,2255,17.7,76,1,26 186 | 4,101,83,2202,15.3,76,2,27 187 | 8,305,140,4215,13,76,1,17.5 188 | 8,318,150,4190,13,76,1,16 189 | 8,304,120,3962,13.9,76,1,15.5 190 | 8,351,152,4215,12.8,76,1,14.5 191 | 6,225,100,3233,15.4,76,1,22 192 | 6,250,105,3353,14.5,76,1,22 193 | 6,200,81,3012,17.6,76,1,24 194 | 6,232,90,3085,17.6,76,1,22.5 195 | 4,85,52,2035,22.2,76,1,29 196 | 4,98,60,2164,22.1,76,1,24.5 197 | 4,90,70,1937,14.2,76,2,29 198 | 4,91,53,1795,17.4,76,3,33 199 | 6,225,100,3651,17.7,76,1,20 200 | 6,250,78,3574,21,76,1,18 201 | 6,250,110,3645,16.2,76,1,18.5 202 | 6,258,95,3193,17.8,76,1,17.5 203 | 4,97,71,1825,12.2,76,2,29.5 204 | 4,85,70,1990,17,76,3,32 205 | 4,97,75,2155,16.4,76,3,28 206 | 4,140,72,2565,13.6,76,1,26.5 207 | 4,130,102,3150,15.7,76,2,20 208 | 8,318,150,3940,13.2,76,1,13 209 | 4,120,88,3270,21.9,76,2,19 210 | 6,156,108,2930,15.5,76,3,19 211 | 6,168,120,3820,16.7,76,2,16.5 212 | 8,350,180,4380,12.1,76,1,16.5 213 | 8,350,145,4055,12,76,1,13 214 | 8,302,130,3870,15,76,1,13 215 | 8,318,150,3755,14,76,1,13 216 | 4,98,68,2045,18.5,77,3,31.5 217 | 4,111,80,2155,14.8,77,1,30 218 | 4,79,58,1825,18.6,77,2,36 219 | 4,122,96,2300,15.5,77,1,25.5 220 | 4,85,70,1945,16.8,77,3,33.5 221 | 8,305,145,3880,12.5,77,1,17.5 222 | 8,260,110,4060,19,77,1,17 223 | 8,318,145,4140,13.7,77,1,15.5 224 | 8,302,130,4295,14.9,77,1,15 225 | 6,250,110,3520,16.4,77,1,17.5 226 | 6,231,105,3425,16.9,77,1,20.5 227 | 6,225,100,3630,17.7,77,1,19 228 | 6,250,98,3525,19,77,1,18.5 229 | 
8,400,180,4220,11.1,77,1,16 230 | 8,350,170,4165,11.4,77,1,15.5 231 | 8,400,190,4325,12.2,77,1,15.5 232 | 8,351,149,4335,14.5,77,1,16 233 | 4,97,78,1940,14.5,77,2,29 234 | 4,151,88,2740,16,77,1,24.5 235 | 4,97,75,2265,18.2,77,3,26 236 | 4,140,89,2755,15.8,77,1,25.5 237 | 4,98,63,2051,17,77,1,30.5 238 | 4,98,83,2075,15.9,77,1,33.5 239 | 4,97,67,1985,16.4,77,3,30 240 | 4,97,78,2190,14.1,77,2,30.5 241 | 6,146,97,2815,14.5,77,3,22 242 | 4,121,110,2600,12.8,77,2,21.5 243 | 3,80,110,2720,13.5,77,3,21.5 244 | 4,90,48,1985,21.5,78,2,43.1 245 | 4,98,66,1800,14.4,78,1,36.1 246 | 4,78,52,1985,19.4,78,3,32.8 247 | 4,85,70,2070,18.6,78,3,39.4 248 | 4,91,60,1800,16.4,78,3,36.1 249 | 8,260,110,3365,15.5,78,1,19.9 250 | 8,318,140,3735,13.2,78,1,19.4 251 | 8,302,139,3570,12.8,78,1,20.2 252 | 6,231,105,3535,19.2,78,1,19.2 253 | 6,200,95,3155,18.2,78,1,20.5 254 | 6,200,85,2965,15.8,78,1,20.2 255 | 4,140,88,2720,15.4,78,1,25.1 256 | 6,225,100,3430,17.2,78,1,20.5 257 | 6,232,90,3210,17.2,78,1,19.4 258 | 6,231,105,3380,15.8,78,1,20.6 259 | 6,200,85,3070,16.7,78,1,20.8 260 | 6,225,110,3620,18.7,78,1,18.6 261 | 6,258,120,3410,15.1,78,1,18.1 262 | 8,305,145,3425,13.2,78,1,19.2 263 | 6,231,165,3445,13.4,78,1,17.7 264 | 8,302,139,3205,11.2,78,1,18.1 265 | 8,318,140,4080,13.7,78,1,17.5 266 | 4,98,68,2155,16.5,78,1,30 267 | 4,134,95,2560,14.2,78,3,27.5 268 | 4,119,97,2300,14.7,78,3,27.2 269 | 4,105,75,2230,14.5,78,1,30.9 270 | 4,134,95,2515,14.8,78,3,21.1 271 | 4,156,105,2745,16.7,78,1,23.2 272 | 4,151,85,2855,17.6,78,1,23.8 273 | 4,119,97,2405,14.9,78,3,23.9 274 | 5,131,103,2830,15.9,78,2,20.3 275 | 6,163,125,3140,13.6,78,2,17 276 | 4,121,115,2795,15.7,78,2,21.6 277 | 6,163,133,3410,15.8,78,2,16.2 278 | 4,89,71,1990,14.9,78,2,31.5 279 | 4,98,68,2135,16.6,78,3,29.5 280 | 6,231,115,3245,15.4,79,1,21.5 281 | 6,200,85,2990,18.2,79,1,19.8 282 | 4,140,88,2890,17.3,79,1,22.3 283 | 6,232,90,3265,18.2,79,1,20.2 284 | 6,225,110,3360,16.6,79,1,20.6 285 | 8,305,130,3840,15.4,79,1,17 286 | 
8,302,129,3725,13.4,79,1,17.6 287 | 8,351,138,3955,13.2,79,1,16.5 288 | 8,318,135,3830,15.2,79,1,18.2 289 | 8,350,155,4360,14.9,79,1,16.9 290 | 8,351,142,4054,14.3,79,1,15.5 291 | 8,267,125,3605,15,79,1,19.2 292 | 8,360,150,3940,13,79,1,18.5 293 | 4,89,71,1925,14,79,2,31.9 294 | 4,86,65,1975,15.2,79,3,34.1 295 | 4,98,80,1915,14.4,79,1,35.7 296 | 4,121,80,2670,15,79,1,27.4 297 | 5,183,77,3530,20.1,79,2,25.4 298 | 8,350,125,3900,17.4,79,1,23 299 | 4,141,71,3190,24.8,79,2,27.2 300 | 8,260,90,3420,22.2,79,1,23.9 301 | 4,105,70,2200,13.2,79,1,34.2 302 | 4,105,70,2150,14.9,79,1,34.5 303 | 4,85,65,2020,19.2,79,3,31.8 304 | 4,91,69,2130,14.7,79,2,37.3 305 | 4,151,90,2670,16,79,1,28.4 306 | 6,173,115,2595,11.3,79,1,28.8 307 | 6,173,115,2700,12.9,79,1,26.8 308 | 4,151,90,2556,13.2,79,1,33.5 309 | 4,98,76,2144,14.7,80,2,41.5 310 | 4,89,60,1968,18.8,80,3,38.1 311 | 4,98,70,2120,15.5,80,1,32.1 312 | 4,86,65,2019,16.4,80,3,37.2 313 | 4,151,90,2678,16.5,80,1,28 314 | 4,140,88,2870,18.1,80,1,26.4 315 | 4,151,90,3003,20.1,80,1,24.3 316 | 6,225,90,3381,18.7,80,1,19.1 317 | 4,97,78,2188,15.8,80,2,34.3 318 | 4,134,90,2711,15.5,80,3,29.8 319 | 4,120,75,2542,17.5,80,3,31.3 320 | 4,119,92,2434,15,80,3,37 321 | 4,108,75,2265,15.2,80,3,32.2 322 | 4,86,65,2110,17.9,80,3,46.6 323 | 4,156,105,2800,14.4,80,1,27.9 324 | 4,85,65,2110,19.2,80,3,40.8 325 | 4,90,48,2085,21.7,80,2,44.3 326 | 4,90,48,2335,23.7,80,2,43.4 327 | 5,121,67,2950,19.9,80,2,36.4 328 | 4,146,67,3250,21.8,80,2,30 329 | 4,91,67,1850,13.8,80,3,44.6 330 | 4,97,67,2145,18,80,3,33.8 331 | 4,89,62,1845,15.3,80,2,29.8 332 | 6,168,132,2910,11.4,80,3,32.7 333 | 3,70,100,2420,12.5,80,3,23.7 334 | 4,122,88,2500,15.1,80,2,35 335 | 4,107,72,2290,17,80,3,32.4 336 | 4,135,84,2490,15.7,81,1,27.2 337 | 4,151,84,2635,16.4,81,1,26.6 338 | 4,156,92,2620,14.4,81,1,25.8 339 | 6,173,110,2725,12.6,81,1,23.5 340 | 4,135,84,2385,12.9,81,1,30 341 | 4,79,58,1755,16.9,81,3,39.1 342 | 4,86,64,1875,16.4,81,1,39 343 | 4,81,60,1760,16.1,81,3,35.1 344 | 
4,97,67,2065,17.8,81,3,32.3 345 | 4,85,65,1975,19.4,81,3,37 346 | 4,89,62,2050,17.3,81,3,37.7 347 | 4,91,68,1985,16,81,3,34.1 348 | 4,105,63,2215,14.9,81,1,34.7 349 | 4,98,65,2045,16.2,81,1,34.4 350 | 4,98,65,2380,20.7,81,1,29.9 351 | 4,105,74,2190,14.2,81,2,33 352 | 4,107,75,2210,14.4,81,3,33.7 353 | 4,108,75,2350,16.8,81,3,32.4 354 | 4,119,100,2615,14.8,81,3,32.9 355 | 4,120,74,2635,18.3,81,3,31.6 356 | 4,141,80,3230,20.4,81,2,28.1 357 | 6,145,76,3160,19.6,81,2,30.7 358 | 6,168,116,2900,12.6,81,3,25.4 359 | 6,146,120,2930,13.8,81,3,24.2 360 | 6,231,110,3415,15.8,81,1,22.4 361 | 8,350,105,3725,19,81,1,26.6 362 | 6,200,88,3060,17.1,81,1,20.2 363 | 6,225,85,3465,16.6,81,1,17.6 364 | 4,112,88,2605,19.6,82,1,28 365 | 4,112,88,2640,18.6,82,1,27 366 | 4,112,88,2395,18,82,1,34 367 | 4,112,85,2575,16.2,82,1,31 368 | 4,135,84,2525,16,82,1,29 369 | 4,151,90,2735,18,82,1,27 370 | 4,140,92,2865,16.4,82,1,24 371 | 4,105,74,1980,15.3,82,2,36 372 | 4,91,68,2025,18.2,82,3,37 373 | 4,91,68,1970,17.6,82,3,31 374 | 4,105,63,2125,14.7,82,1,38 375 | 4,98,70,2125,17.3,82,1,36 376 | 4,120,88,2160,14.5,82,3,36 377 | 4,107,75,2205,14.5,82,3,36 378 | 4,108,70,2245,16.9,82,3,34 379 | 4,91,67,1965,15,82,3,38 380 | 4,91,67,1965,15.7,82,3,32 381 | 4,91,67,1995,16.2,82,3,38 382 | 6,181,110,2945,16.4,82,1,25 383 | 6,262,85,3015,17,82,1,38 384 | 4,156,92,2585,14.5,82,1,26 385 | 6,232,112,2835,14.7,82,1,22 386 | 4,144,96,2665,13.9,82,3,32 387 | 4,135,84,2370,13,82,1,36 388 | 4,151,90,2950,17.3,82,1,27 389 | 4,140,86,2790,15.6,82,1,27 390 | 4,97,52,2130,24.6,82,2,44 391 | 4,135,84,2295,11.6,82,1,32 392 | 4,120,79,2625,18.6,82,1,28 393 | 4,119,82,2720,19.4,82,1,31 394 | -------------------------------------------------------------------------------- /src/test/resources/csv/DNNClassificationIris.csv: -------------------------------------------------------------------------------- 1 | _target,probability(0),probability(1),probability(2) 2 | 
0,0.9989047050476074,0.0008360623614862561,0.0002592703967820853 3 | 0,0.9979177117347717,0.0016249773325398564,0.0004573028127197176 4 | 0,0.9982225298881531,0.0013437297893688083,0.0004337556893005967 5 | 0,0.9962031245231628,0.0031924087088555098,0.0006044614128768444 6 | 0,0.9988418221473694,0.0008937669917941093,0.00026450029690749943 7 | 0,0.9990516304969788,0.0007657141541130841,0.0001826960069593042 8 | 0,0.9980208873748779,0.0015651624416932464,0.0004139202937949449 9 | 0,0.9982287287712097,0.0014381676446646452,0.0003331175248604268 10 | 0,0.9956466555595398,0.0035738914739340544,0.0007794885314069688 11 | 0,0.9970353841781616,0.002444528741762042,0.0005200660671107471 12 | 0,0.9991681575775146,0.0006470356020145118,0.00018488752539269626 13 | 0,0.9969224333763123,0.0026435167528688908,0.00043406395707279444 14 | 0,0.9972326159477234,0.00220040837302804,0.0005669626407325268 15 | 0,0.9975994229316711,0.0017001996748149395,0.0007002817583270371 16 | 0,0.9998074173927307,0.00010991662566084415,8.27241747174412e-05 17 | 0,0.9997325539588928,0.00017794525774661452,8.952804637374356e-05 18 | 0,0.9996616840362549,0.0002095497475238517,0.00012878111738245934 19 | 0,0.9990488886833191,0.0007109538419172168,0.0002401395613560453 20 | 0,0.9992166757583618,0.000634459953289479,0.00014886280405335128 21 | 0,0.9989922642707825,0.0007808840600773692,0.00022690632613375783 22 | 0,0.9981746673583984,0.0015530585078522563,0.00027225862140767276 23 | 0,0.9990561604499817,0.0007163219852373004,0.0002275236911373213 24 | 0,0.9992896318435669,0.0004327555070631206,0.00027752271853387356 25 | 0,0.9978926777839661,0.0017603086307644844,0.00034696166403591633 26 | 0,0.9913307428359985,0.008043371140956879,0.0006258317152969539 27 | 0,0.9968171715736389,0.0026912239845842123,0.0004916128236800432 28 | 0,0.9982463121414185,0.0014370633289217949,0.0003166229580529034 29 | 0,0.9987409710884094,0.0010021397611126304,0.0002568713098298758 30 | 
0,0.9989625215530396,0.000781965209171176,0.0002555005776230246 31 | 0,0.9958993792533875,0.0035461997613310814,0.0005544034065678716 32 | 0,0.9961361289024353,0.0033173358533531427,0.0005464403657242656 33 | 0,0.99921715259552,0.0005878541269339621,0.00019503400835674256 34 | 0,0.9990599751472473,0.0007452968275174499,0.00019475595036055893 35 | 0,0.9996028542518616,0.0002765266108326614,0.00012063606845913455 36 | 0,0.997458279132843,0.0020820172503590584,0.00045964642777107656 37 | 0,0.9990630745887756,0.0006333969649858773,0.0003035677073057741 38 | 0,0.9994930028915405,0.00034153310116380453,0.00016550954023841769 39 | 0,0.9984629154205322,0.0012108131777495146,0.0003262842947151512 40 | 0,0.9969363212585449,0.0023982736747711897,0.0006653706077486277 41 | 0,0.9984490871429443,0.0012469399953261018,0.00030399279785342515 42 | 0,0.9991639852523804,0.0005930583574809134,0.00024304295948240906 43 | 0,0.9960720539093018,0.0030092080123722553,0.0009186401730403304 44 | 0,0.9973681569099426,0.0020611975342035294,0.0005706422962248325 45 | 0,0.9975513815879822,0.0019216471118852496,0.0005268990062177181 46 | 0,0.9930242300033569,0.006232454441487789,0.0007433429127559066 47 | 0,0.997961163520813,0.0015959955053403974,0.00044282712042331696 48 | 0,0.9984704852104187,0.0012704470427706838,0.0002590469957794994 49 | 0,0.9973421692848206,0.00214195204898715,0.0005158890271559358 50 | 0,0.9990496039390564,0.0007461805362254381,0.00020420258806552738 51 | 0,0.9985463619232178,0.0011223034234717488,0.0003313718771096319 52 | 1,0.0003461647720541805,0.9979695677757263,0.001684160204604268 53 | 1,0.0006451268563978374,0.9905067086219788,0.008848185651004314 54 | 1,0.0006193189765326679,0.9823654294013977,0.01701533980667591 55 | 1,0.0030198381282389164,0.9448526501655579,0.05212751775979996 56 | 1,0.0011560042621567845,0.9626206159591675,0.03622337058186531 57 | 1,0.0016463467618450522,0.9369542002677917,0.061399418860673904 58 | 
1,0.0008215613197535276,0.9529606103897095,0.046217817813158035 59 | 1,0.002778339199721813,0.9935497641563416,0.0036718735937029123 60 | 1,0.0006453708629123867,0.9948211908340454,0.004533397499471903 61 | 1,0.0024087727069854736,0.9488183259963989,0.048772942274808884 62 | 1,0.0033394417259842157,0.9898724555969238,0.006788043770939112 63 | 1,0.0011488626478239894,0.9799173474311829,0.01893375813961029 64 | 1,0.0013161293463781476,0.9969717264175415,0.001712133060209453 65 | 1,0.001299033872783184,0.9298685193061829,0.0688323900103569 66 | 1,0.0014014269690960646,0.9950200319290161,0.003578474745154381 67 | 1,0.0004372586263343692,0.9981188178062439,0.001443826942704618 68 | 1,0.0016704555600881577,0.826302170753479,0.17202739417552948 69 | 1,0.000965489016380161,0.9972748160362244,0.0017596861580386758 70 | 1,0.002768319332972169,0.7102066874504089,0.28702500462532043 71 | 1,0.0013546883128583431,0.9956467747688293,0.0029984505381435156 72 | 2,0.0008822132949717343,0.3759615421295166,0.6231562495231628 73 | 1,0.0008366480469703674,0.9971413016319275,0.002022022381424904 74 | 1,0.0016971061704680324,0.5640475153923035,0.4342553913593292 75 | 1,0.0011201782617717981,0.9762335419654846,0.022646168246865273 76 | 1,0.0006247040582820773,0.9973479509353638,0.002027334878221154 77 | 1,0.0005620552110485733,0.9967799782752991,0.002657989738509059 78 | 1,0.0008521045674569905,0.981255292892456,0.017892606556415558 79 | 1,0.0010623474372550845,0.759124219417572,0.23981352150440216 80 | 1,0.0014546927995979786,0.9232752323150635,0.07527007162570953 81 | 1,0.010719776153564453,0.9857681393623352,0.003512156894430518 82 | 1,0.0016219413373619318,0.9948831796646118,0.003494875505566597 83 | 1,0.0015995304565876722,0.9965848922729492,0.0018154870485886931 84 | 1,0.001080506481230259,0.9965248703956604,0.0023945486173033714 85 | 2,0.0006219488568603992,0.14923861622810364,0.8501394391059875 86 | 1,0.0017800258938223124,0.7180883884429932,0.28013163805007935 87 | 
1,0.0008225410128943622,0.9656193852424622,0.03355801850557327 88 | 1,0.0006506352801807225,0.9870443940162659,0.012304888106882572 89 | 1,0.0018836450763046741,0.968732476234436,0.02938387170433998 90 | 1,0.0010221664560958743,0.9920328259468079,0.006944986060261726 91 | 1,0.002256561303511262,0.9694513082504272,0.028292158618569374 92 | 1,0.0021356134675443172,0.933732271194458,0.06413212418556213 93 | 1,0.0010642328998073936,0.9666807651519775,0.03225507214665413 94 | 1,0.0012185521190986037,0.9948511719703674,0.0039302813820540905 95 | 1,0.004043115768581629,0.9922090768814087,0.0037478890735656023 96 | 1,0.0017410650616511703,0.9697855114936829,0.02847347781062126 97 | 1,0.0008970940252766013,0.994446873664856,0.004656018223613501 98 | 1,0.0011661611497402191,0.9876282215118408,0.011205561459064484 99 | 1,0.0007645345758646727,0.9954438209533691,0.0037915813736617565 100 | 1,0.02692677266895771,0.9635664224624634,0.009506762959063053 101 | 1,0.0012638334883376956,0.9891725778579712,0.009563514031469822 102 | 2,6.70488361720345e-07,6.782418495276943e-05,0.9999314546585083 103 | 2,8.786226680967957e-05,0.010090313851833344,0.9898217916488647 104 | 2,2.9988314054207876e-05,0.007728721015155315,0.9922412633895874 105 | 2,8.277977030957118e-05,0.017429351806640625,0.9824878573417664 106 | 2,7.1886474870552775e-06,0.0009111366234719753,0.9990816116333008 107 | 2,5.3449311963049695e-06,0.001313116867095232,0.9986814856529236 108 | 2,0.0003139247710350901,0.024134358391165733,0.9755516648292542 109 | 2,4.24602440034505e-05,0.014188208617269993,0.9857693314552307 110 | 2,4.6444692998193204e-05,0.006649806164205074,0.9933037161827087 111 | 2,4.636980065697571e-06,0.0014352883445098996,0.9985601305961609 112 | 2,0.00040848381468094885,0.17713280022144318,0.8224586844444275 113 | 2,0.00013909097469877452,0.025111092254519463,0.9747498035430908 114 | 2,7.725215255049989e-05,0.02016610838472843,0.9797566533088684 115 | 
2,4.417638047016226e-05,0.0030252181459218264,0.9969306588172913 116 | 2,6.308760475803865e-06,0.0003555802395567298,0.9996380805969238 117 | 2,3.3335738407913595e-05,0.006532712373882532,0.993433952331543 118 | 2,0.00020959824905730784,0.06866618990898132,0.931124210357666 119 | 2,9.03612635738682e-06,0.006755535956472158,0.9932354092597961 120 | 2,3.4310443197682616e-07,2.732526081672404e-05,0.9999723434448242 121 | 2,0.0007662875577807426,0.10765309631824493,0.8915805816650391 122 | 2,1.9564093236112967e-05,0.00464129913598299,0.9953391551971436 123 | 2,8.547856123186648e-05,0.009210662916302681,0.9907038807868958 124 | 2,5.648945716529852e-06,0.0012344352435320616,0.9987599849700928 125 | 2,0.0007866835221648216,0.2106846570968628,0.7885286808013916 126 | 2,4.572494435706176e-05,0.014789164066314697,0.9851651191711426 127 | 2,0.0001669866469455883,0.10296441614627838,0.8968686461448669 128 | 2,0.000983618083409965,0.301876425743103,0.6971399188041687 129 | 2,0.0007688715704716742,0.2701977491378784,0.729033350944519 130 | 2,1.642716051719617e-05,0.0019181155366823077,0.9980654120445251 131 | 2,0.0005448490264825523,0.40934574604034424,0.5901094079017639 132 | 2,5.643894473905675e-05,0.017497925087809563,0.9824455976486206 133 | 2,0.00011451374302851036,0.1901913285255432,0.8096941709518433 134 | 2,9.376764865010045e-06,0.0009166600648313761,0.9990739822387695 135 | 1,0.0012366522569209337,0.563909649848938,0.434853732585907 136 | 2,0.00030155113199725747,0.06550092250108719,0.9341975450515747 137 | 2,1.5726973288110457e-05,0.004466318525373936,0.9955180287361145 138 | 2,6.5268527578155044e-06,0.0011002128012478352,0.9988933205604553 139 | 2,0.0001988871954381466,0.06895013898611069,0.9308509230613708 140 | 2,0.0008838910725899041,0.303771436214447,0.6953446269035339 141 | 2,0.00016035721637308598,0.058819640427827835,0.9410200119018555 142 | 2,9.300411875301506e-06,0.0013922522775828838,0.9985983967781067 143 | 
2,0.0001729057403281331,0.05645694211125374,0.9433701634407043 144 | 2,8.786226680967957e-05,0.010090313851833344,0.9898217916488647 145 | 2,7.152757916628616e-06,0.0013067243853583932,0.9986861348152161 146 | 2,4.903526132693514e-06,0.0007792363758198917,0.999215841293335 147 | 2,6.762499106116593e-05,0.014191126450896263,0.9857412576675415 148 | 2,0.0002661229227669537,0.04012950137257576,0.9596043825149536 149 | 2,0.0002260944020235911,0.06593544036149979,0.933838427066803 150 | 2,2.0221576050971635e-05,0.004108963999897242,0.9958707690238953 151 | 2,0.0002888355520553887,0.07058849930763245,0.9291226267814636 152 | -------------------------------------------------------------------------------- /src/test/resources/csv/DNNRegressionAuto.csv: -------------------------------------------------------------------------------- 1 | _target 2 | 15.01767635345459 3 | 13.843924522399902 4 | 15.190770149230957 5 | 15.867392539978027 6 | 15.213959693908691 7 | 9.638413429260254 8 | 9.164948463439941 9 | 9.459259986877441 10 | 9.267571449279785 11 | 11.955185890197754 12 | 13.107895851135254 13 | 13.378771781921387 14 | 11.258580207824707 15 | 21.37030029296875 16 | 25.955530166625977 17 | 20.715490341186523 18 | 20.954431533813477 19 | 21.445667266845703 20 | 26.992570877075195 21 | 27.733427047729492 22 | 24.51389503479004 23 | 24.765304565429688 24 | 25.97898292541504 25 | 25.132768630981445 26 | 21.10872459411621 27 | 12.09019947052002 28 | 14.369145393371582 29 | 13.885125160217285 30 | 13.901936531066895 31 | 27.40926170349121 32 | 24.822734832763672 33 | 26.635480880737305 34 | 20.35489273071289 35 | 18.384361267089844 36 | 17.95117950439453 37 | 17.748802185058594 38 | 18.614864349365234 39 | 12.49507999420166 40 | 10.250367164611816 41 | 12.767239570617676 42 | 13.712067604064941 43 | 9.05346393585205 44 | 9.220158576965332 45 | 7.895920753479004 46 | 18.791942596435547 47 | 24.7735652923584 48 | 17.990825653076172 49 | 18.081926345825195 50 | 24.996854782104492 51 
| 25.92203712463379 52 | 28.120319366455078 53 | 26.71446418762207 54 | 30.085500717163086 55 | 30.479711532592773 56 | 28.105886459350586 57 | 28.186382293701172 58 | 27.26825714111328 59 | 27.1066837310791 60 | 28.02620506286621 61 | 25.776338577270508 62 | 26.058229446411133 63 | 12.672635078430176 64 | 11.091118812561035 65 | 14.118607521057129 66 | 13.142632484436035 67 | 15.687932014465332 68 | 9.91542911529541 69 | 11.982121467590332 70 | 12.277104377746582 71 | 11.465508460998535 72 | 27.853803634643555 73 | 15.145010948181152 74 | 14.19857120513916 75 | 14.404557228088379 76 | 14.465394020080566 77 | 23.902074813842773 78 | 25.483150482177734 79 | 24.466026306152344 80 | 27.204307556152344 81 | 25.302967071533203 82 | 28.011890411376953 83 | 26.01335334777832 84 | 26.41921043395996 85 | 28.469789505004883 86 | 14.247818946838379 87 | 16.104646682739258 88 | 13.90551471710205 89 | 15.273171424865723 90 | 15.586092948913574 91 | 9.040131568908691 92 | 10.58840274810791 93 | 12.82396411895752 94 | 14.426648139953613 95 | 9.819470405578613 96 | 8.848740577697754 97 | 14.45628833770752 98 | 20.654747009277344 99 | 19.63905143737793 100 | 20.843915939331055 101 | 19.88296127319336 102 | 21.837705612182617 103 | 28.6937198638916 104 | 9.16026782989502 105 | 9.522671699523926 106 | 11.7983980178833 107 | 12.771889686584473 108 | 21.151302337646484 109 | 28.894498825073242 110 | 25.7661075592041 111 | 27.697267532348633 112 | 28.85256004333496 113 | 26.674060821533203 114 | 24.420930862426758 115 | 27.145740509033203 116 | 13.559624671936035 117 | 12.618054389953613 118 | 29.500791549682617 119 | 26.649465560913086 120 | 25.148454666137695 121 | 24.824684143066406 122 | 16.57712173461914 123 | 25.13978385925293 124 | 24.58885955810547 125 | 15.444716453552246 126 | 21.65910530090332 127 | 21.422561645507812 128 | 19.57562255859375 129 | 30.506628036499023 130 | 25.91273307800293 131 | 31.637182235717773 132 | 25.07264518737793 133 | 17.937984466552734 134 | 
18.776613235473633 135 | 19.260887145996094 136 | 15.267655372619629 137 | 12.23153018951416 138 | 13.767060279846191 139 | 13.972039222717285 140 | 15.435334205627441 141 | 27.822429656982422 142 | 28.636730194091797 143 | 26.8936767578125 144 | 30.65583610534668 145 | 30.0467529296875 146 | 27.3649845123291 147 | 27.99091911315918 148 | 26.342296600341797 149 | 27.042680740356445 150 | 27.7779598236084 151 | 28.633909225463867 152 | 20.57707405090332 153 | 19.398727416992188 154 | 20.001663208007812 155 | 20.609952926635742 156 | 11.040692329406738 157 | 13.342263221740723 158 | 14.299607276916504 159 | 12.45729923248291 160 | 19.75057029724121 161 | 18.453643798828125 162 | 19.09935760498047 163 | 19.459854125976562 164 | 21.34469223022461 165 | 19.321794509887695 166 | 18.45132827758789 167 | 28.99836540222168 168 | 25.333581924438477 169 | 21.791412353515625 170 | 25.780838012695312 171 | 25.851482391357422 172 | 28.15052604675293 173 | 27.81520652770996 174 | 22.796205520629883 175 | 28.51112174987793 176 | 20.713672637939453 177 | 25.908510208129883 178 | 25.15557861328125 179 | 24.755977630615234 180 | 25.92513656616211 181 | 30.399757385253906 182 | 27.306547164916992 183 | 28.195714950561523 184 | 25.663129806518555 185 | 28.44609260559082 186 | 28.31287956237793 187 | 15.476578712463379 188 | 15.449719429016113 189 | 16.173154830932617 190 | 14.414521217346191 191 | 21.07360076904297 192 | 19.80551528930664 193 | 22.70804214477539 194 | 21.75406837463379 195 | 30.147197723388672 196 | 29.47661781311035 197 | 28.981184005737305 198 | 30.7898006439209 199 | 20.148767471313477 200 | 20.046720504760742 201 | 19.310047149658203 202 | 20.794986724853516 203 | 28.685686111450195 204 | 30.563989639282227 205 | 29.58062171936035 206 | 24.839216232299805 207 | 24.58248519897461 208 | 16.423070907592773 209 | 25.72722625732422 210 | 25.567638397216797 211 | 21.683713912963867 212 | 14.35326099395752 213 | 14.642354011535645 214 | 17.113584518432617 215 | 
17.317256927490234 216 | 30.757665634155273 217 | 28.112184524536133 218 | 30.995134353637695 219 | 27.854263305664062 220 | 31.092958450317383 221 | 17.118465423583984 222 | 18.591459274291992 223 | 16.111360549926758 224 | 15.939587593078613 225 | 20.24010467529297 226 | 21.138608932495117 227 | 20.642778396606445 228 | 20.613292694091797 229 | 13.66762638092041 230 | 15.123013496398926 231 | 13.826024055480957 232 | 14.767544746398926 233 | 29.468639373779297 234 | 25.340967178344727 235 | 30.072572708129883 236 | 25.570987701416016 237 | 29.02448081970215 238 | 29.145784378051758 239 | 30.421735763549805 240 | 28.441919326782227 241 | 26.149118423461914 242 | 26.707416534423828 243 | 28.544170379638672 244 | 31.03156280517578 245 | 29.746917724609375 246 | 31.80262565612793 247 | 31.52972984313965 248 | 31.514171600341797 249 | 20.632360458374023 250 | 17.759380340576172 251 | 18.69063949584961 252 | 21.763891220092773 253 | 23.52735137939453 254 | 23.33498191833496 255 | 25.984661102294922 256 | 21.662111282348633 257 | 22.020809173583984 258 | 21.427560806274414 259 | 23.188613891601562 260 | 21.61444854736328 261 | 20.738473892211914 262 | 19.396343231201172 263 | 22.0572509765625 264 | 19.607097625732422 265 | 16.62314796447754 266 | 29.050870895385742 267 | 27.785737991333008 268 | 29.353559494018555 269 | 28.21805763244629 270 | 28.111343383789062 271 | 26.211063385009766 272 | 25.685771942138672 273 | 29.02049446105957 274 | 26.608570098876953 275 | 24.461013793945312 276 | 27.305665969848633 277 | 24.255294799804688 278 | 29.859888076782227 279 | 30.336463928222656 280 | 22.485891342163086 281 | 24.299768447875977 282 | 26.282501220703125 283 | 22.501827239990234 284 | 22.427915573120117 285 | 18.495302200317383 286 | 18.445430755615234 287 | 16.376070022583008 288 | 18.234130859375 289 | 15.795002937316895 290 | 16.405675888061523 291 | 20.210346221923828 292 | 16.423307418823242 293 | 30.27579116821289 294 | 31.235036849975586 295 | 30.092504501342773 
296 | 26.81884002685547 297 | 23.43402862548828 298 | 17.400117874145508 299 | 27.01934814453125 300 | 22.130535125732422 301 | 28.272754669189453 302 | 28.91010856628418 303 | 32.16470718383789 304 | 29.600778579711914 305 | 26.482309341430664 306 | 25.436864852905273 307 | 25.4771728515625 308 | 26.155080795288086 309 | 29.94257354736328 310 | 32.42641067504883 311 | 29.79669189453125 312 | 31.80986976623535 313 | 27.00292205810547 314 | 26.98615264892578 315 | 26.767000198364258 316 | 22.824426651000977 317 | 30.152801513671875 318 | 28.28445053100586 319 | 29.46137237548828 320 | 29.64811134338379 321 | 30.20924186706543 322 | 31.87501335144043 323 | 26.22868537902832 324 | 32.250205993652344 325 | 31.550317764282227 326 | 31.163707733154297 327 | 27.44671058654785 328 | 26.172151565551758 329 | 31.646240234375 330 | 31.509798049926758 331 | 31.107263565063477 332 | 26.476343154907227 333 | 30.665056228637695 334 | 28.357206344604492 335 | 30.550350189208984 336 | 28.203081130981445 337 | 27.40032196044922 338 | 26.980871200561523 339 | 26.012845993041992 340 | 27.842737197875977 341 | 33.35484313964844 342 | 31.546018600463867 343 | 33.11640167236328 344 | 32.16755294799805 345 | 33.21706771850586 346 | 32.191619873046875 347 | 32.17801284790039 348 | 29.32831382751465 349 | 30.550352096557617 350 | 30.517683029174805 351 | 29.80692481994629 352 | 30.643491744995117 353 | 30.739858627319336 354 | 29.546594619750977 355 | 29.72406768798828 356 | 26.758499145507812 357 | 26.51824951171875 358 | 26.847524642944336 359 | 27.784530639648438 360 | 22.674644470214844 361 | 18.801280975341797 362 | 24.657682418823242 363 | 22.24619483947754 364 | 29.99256134033203 365 | 29.597064971923828 366 | 30.33865737915039 367 | 29.12074851989746 368 | 28.57099723815918 369 | 28.026649475097656 370 | 27.48520851135254 371 | 31.289819717407227 372 | 33.03423309326172 373 | 33.076637268066406 374 | 30.02290153503418 375 | 31.091768264770508 376 | 31.227453231811523 377 | 
31.105281829833984 378 | 31.443906784057617 379 | 32.37647247314453 380 | 32.56315612792969 381 | 32.586124420166016 382 | 26.40549087524414 383 | 23.371509552001953 384 | 27.553056716918945 385 | 24.954225540161133 386 | 28.726423263549805 387 | 28.341327667236328 388 | 27.04876136779785 389 | 27.396957397460938 390 | 32.89274978637695 391 | 28.243947982788086 392 | 29.19800567626953 393 | 29.16570281982422 394 | -------------------------------------------------------------------------------- /src/test/resources/csv/Iris.csv: -------------------------------------------------------------------------------- 1 | Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species 2 | 5.1,3.5,1.4,0.2,setosa 3 | 4.9,3,1.4,0.2,setosa 4 | 4.7,3.2,1.3,0.2,setosa 5 | 4.6,3.1,1.5,0.2,setosa 6 | 5,3.6,1.4,0.2,setosa 7 | 5.4,3.9,1.7,0.4,setosa 8 | 4.6,3.4,1.4,0.3,setosa 9 | 5,3.4,1.5,0.2,setosa 10 | 4.4,2.9,1.4,0.2,setosa 11 | 4.9,3.1,1.5,0.1,setosa 12 | 5.4,3.7,1.5,0.2,setosa 13 | 4.8,3.4,1.6,0.2,setosa 14 | 4.8,3,1.4,0.1,setosa 15 | 4.3,3,1.1,0.1,setosa 16 | 5.8,4,1.2,0.2,setosa 17 | 5.7,4.4,1.5,0.4,setosa 18 | 5.4,3.9,1.3,0.4,setosa 19 | 5.1,3.5,1.4,0.3,setosa 20 | 5.7,3.8,1.7,0.3,setosa 21 | 5.1,3.8,1.5,0.3,setosa 22 | 5.4,3.4,1.7,0.2,setosa 23 | 5.1,3.7,1.5,0.4,setosa 24 | 4.6,3.6,1,0.2,setosa 25 | 5.1,3.3,1.7,0.5,setosa 26 | 4.8,3.4,1.9,0.2,setosa 27 | 5,3,1.6,0.2,setosa 28 | 5,3.4,1.6,0.4,setosa 29 | 5.2,3.5,1.5,0.2,setosa 30 | 5.2,3.4,1.4,0.2,setosa 31 | 4.7,3.2,1.6,0.2,setosa 32 | 4.8,3.1,1.6,0.2,setosa 33 | 5.4,3.4,1.5,0.4,setosa 34 | 5.2,4.1,1.5,0.1,setosa 35 | 5.5,4.2,1.4,0.2,setosa 36 | 4.9,3.1,1.5,0.2,setosa 37 | 5,3.2,1.2,0.2,setosa 38 | 5.5,3.5,1.3,0.2,setosa 39 | 4.9,3.6,1.4,0.1,setosa 40 | 4.4,3,1.3,0.2,setosa 41 | 5.1,3.4,1.5,0.2,setosa 42 | 5,3.5,1.3,0.3,setosa 43 | 4.5,2.3,1.3,0.3,setosa 44 | 4.4,3.2,1.3,0.2,setosa 45 | 5,3.5,1.6,0.6,setosa 46 | 5.1,3.8,1.9,0.4,setosa 47 | 4.8,3,1.4,0.3,setosa 48 | 5.1,3.8,1.6,0.2,setosa 49 | 4.6,3.2,1.4,0.2,setosa 50 | 
5.3,3.7,1.5,0.2,setosa 51 | 5,3.3,1.4,0.2,setosa 52 | 7,3.2,4.7,1.4,versicolor 53 | 6.4,3.2,4.5,1.5,versicolor 54 | 6.9,3.1,4.9,1.5,versicolor 55 | 5.5,2.3,4,1.3,versicolor 56 | 6.5,2.8,4.6,1.5,versicolor 57 | 5.7,2.8,4.5,1.3,versicolor 58 | 6.3,3.3,4.7,1.6,versicolor 59 | 4.9,2.4,3.3,1,versicolor 60 | 6.6,2.9,4.6,1.3,versicolor 61 | 5.2,2.7,3.9,1.4,versicolor 62 | 5,2,3.5,1,versicolor 63 | 5.9,3,4.2,1.5,versicolor 64 | 6,2.2,4,1,versicolor 65 | 6.1,2.9,4.7,1.4,versicolor 66 | 5.6,2.9,3.6,1.3,versicolor 67 | 6.7,3.1,4.4,1.4,versicolor 68 | 5.6,3,4.5,1.5,versicolor 69 | 5.8,2.7,4.1,1,versicolor 70 | 6.2,2.2,4.5,1.5,versicolor 71 | 5.6,2.5,3.9,1.1,versicolor 72 | 5.9,3.2,4.8,1.8,versicolor 73 | 6.1,2.8,4,1.3,versicolor 74 | 6.3,2.5,4.9,1.5,versicolor 75 | 6.1,2.8,4.7,1.2,versicolor 76 | 6.4,2.9,4.3,1.3,versicolor 77 | 6.6,3,4.4,1.4,versicolor 78 | 6.8,2.8,4.8,1.4,versicolor 79 | 6.7,3,5,1.7,versicolor 80 | 6,2.9,4.5,1.5,versicolor 81 | 5.7,2.6,3.5,1,versicolor 82 | 5.5,2.4,3.8,1.1,versicolor 83 | 5.5,2.4,3.7,1,versicolor 84 | 5.8,2.7,3.9,1.2,versicolor 85 | 6,2.7,5.1,1.6,versicolor 86 | 5.4,3,4.5,1.5,versicolor 87 | 6,3.4,4.5,1.6,versicolor 88 | 6.7,3.1,4.7,1.5,versicolor 89 | 6.3,2.3,4.4,1.3,versicolor 90 | 5.6,3,4.1,1.3,versicolor 91 | 5.5,2.5,4,1.3,versicolor 92 | 5.5,2.6,4.4,1.2,versicolor 93 | 6.1,3,4.6,1.4,versicolor 94 | 5.8,2.6,4,1.2,versicolor 95 | 5,2.3,3.3,1,versicolor 96 | 5.6,2.7,4.2,1.3,versicolor 97 | 5.7,3,4.2,1.2,versicolor 98 | 5.7,2.9,4.2,1.3,versicolor 99 | 6.2,2.9,4.3,1.3,versicolor 100 | 5.1,2.5,3,1.1,versicolor 101 | 5.7,2.8,4.1,1.3,versicolor 102 | 6.3,3.3,6,2.5,virginica 103 | 5.8,2.7,5.1,1.9,virginica 104 | 7.1,3,5.9,2.1,virginica 105 | 6.3,2.9,5.6,1.8,virginica 106 | 6.5,3,5.8,2.2,virginica 107 | 7.6,3,6.6,2.1,virginica 108 | 4.9,2.5,4.5,1.7,virginica 109 | 7.3,2.9,6.3,1.8,virginica 110 | 6.7,2.5,5.8,1.8,virginica 111 | 7.2,3.6,6.1,2.5,virginica 112 | 6.5,3.2,5.1,2,virginica 113 | 6.4,2.7,5.3,1.9,virginica 114 | 6.8,3,5.5,2.1,virginica 115 
| 5.7,2.5,5,2,virginica 116 | 5.8,2.8,5.1,2.4,virginica 117 | 6.4,3.2,5.3,2.3,virginica 118 | 6.5,3,5.5,1.8,virginica 119 | 7.7,3.8,6.7,2.2,virginica 120 | 7.7,2.6,6.9,2.3,virginica 121 | 6,2.2,5,1.5,virginica 122 | 6.9,3.2,5.7,2.3,virginica 123 | 5.6,2.8,4.9,2,virginica 124 | 7.7,2.8,6.7,2,virginica 125 | 6.3,2.7,4.9,1.8,virginica 126 | 6.7,3.3,5.7,2.1,virginica 127 | 7.2,3.2,6,1.8,virginica 128 | 6.2,2.8,4.8,1.8,virginica 129 | 6.1,3,4.9,1.8,virginica 130 | 6.4,2.8,5.6,2.1,virginica 131 | 7.2,3,5.8,1.6,virginica 132 | 7.4,2.8,6.1,1.9,virginica 133 | 7.9,3.8,6.4,2,virginica 134 | 6.4,2.8,5.6,2.2,virginica 135 | 6.3,2.8,5.1,1.5,virginica 136 | 6.1,2.6,5.6,1.4,virginica 137 | 7.7,3,6.1,2.3,virginica 138 | 6.3,3.4,5.6,2.4,virginica 139 | 6.4,3.1,5.5,1.8,virginica 140 | 6,3,4.8,1.8,virginica 141 | 6.9,3.1,5.4,2.1,virginica 142 | 6.7,3.1,5.6,2.4,virginica 143 | 6.9,3.1,5.1,2.3,virginica 144 | 5.8,2.7,5.1,1.9,virginica 145 | 6.8,3.2,5.9,2.3,virginica 146 | 6.7,3.3,5.7,2.5,virginica 147 | 6.7,3,5.2,2.3,virginica 148 | 6.3,2.5,5,1.9,virginica 149 | 6.5,3,5.2,2,virginica 150 | 6.2,3.4,5.4,2.3,virginica 151 | 5.9,3,5.1,1.8,virginica 152 | -------------------------------------------------------------------------------- /src/test/resources/csv/LinearClassificationIris.csv: -------------------------------------------------------------------------------- 1 | _target,probability(0),probability(1),probability(2) 2 | 0,0.9907627701759338,0.009237217716872692,3.749123234086937e-09 3 | 0,0.956351637840271,0.04364838823676109,4.663083075229224e-08 4 | 0,0.9832416772842407,0.016758281737565994,1.4245249424504891e-08 5 | 0,0.96475750207901,0.035242341458797455,6.261941365437451e-08 6 | 0,0.9936339855194092,0.006365958601236343,2.5858908347942133e-09 7 | 0,0.9948780536651611,0.005121925380080938,4.344276938184066e-09 8 | 0,0.989938497543335,0.010061517357826233,1.5592462432323373e-08 9 | 0,0.9844430088996887,0.015556998550891876,1.0484257018106291e-08 10 | 
0,0.9506532549858093,0.04934665560722351,1.2117403969114093e-07 11 | 0,0.9586722254753113,0.04132779687643051,2.9145320823431575e-08 12 | 0,0.9932101368904114,0.006789925508201122,1.7355793469064906e-09 13 | 0,0.9819263815879822,0.018073581159114838,2.0241261111664244e-08 14 | 0,0.9562978148460388,0.043702177703380585,3.2291886498114764e-08 15 | 0,0.9819265007972717,0.01807350292801857,1.340281929884668e-08 16 | 0,0.9985393285751343,0.0014606263721361756,7.09817898902898e-11 17 | 0,0.9993151426315308,0.0006848605698905885,1.212019234975159e-10 18 | 0,0.9979902505874634,0.002009730786085129,6.787366779725801e-10 19 | 0,0.9911426305770874,0.008857413195073605,5.971869665444274e-09 20 | 0,0.9915858507156372,0.008414182811975479,3.251790836600321e-09 21 | 0,0.9958694577217102,0.004130557645112276,2.317200431534161e-09 22 | 0,0.9710810780525208,0.02891896851360798,1.7723605694186517e-08 23 | 0,0.9944779872894287,0.0055219619534909725,5.9075246916506785e-09 24 | 0,0.9978783130645752,0.002121762605383992,5.980346218237287e-10 25 | 0,0.9686342477798462,0.031365592032670975,1.5343113091148552e-07 26 | 0,0.9641212821006775,0.03587863966822624,8.016237984520558e-08 27 | 0,0.9293369650840759,0.07066291570663452,1.0412511386448386e-07 28 | 0,0.9819720387458801,0.018027935177087784,4.220543559085854e-08 29 | 0,0.9878655076026917,0.012134560383856297,5.396251889777659e-09 30 | 0,0.986613929271698,0.01338605210185051,5.428536287155339e-09 31 | 0,0.9666884541511536,0.033311571925878525,5.6490549127374834e-08 32 | 0,0.952247679233551,0.04775216802954674,8.091247849506544e-08 33 | 0,0.9831758737564087,0.01682407781481743,1.795434911855409e-08 34 | 0,0.9982764720916748,0.0017235465347766876,2.015663885801544e-10 35 | 0,0.9989411234855652,0.0010588397271931171,9.384583388172274e-11 36 | 0,0.9603186249732971,0.03968139365315437,4.6486725580052735e-08 37 | 0,0.9849821329116821,0.015017908066511154,6.68600907971495e-09 38 | 0,0.991384744644165,0.008615230210125446,1.5939382036478378e-09 
39 | 0,0.9936258792877197,0.00637411093339324,1.790820824965067e-09 40 | 0,0.971447229385376,0.028552835807204247,4.8515222061951135e-08 41 | 0,0.9838011860847473,0.016198858618736267,9.501575526371653e-09 42 | 0,0.9932628273963928,0.006737218704074621,4.145765952756619e-09 43 | 0,0.7666867971420288,0.2333117127418518,1.5061916656122776e-06 44 | 0,0.9851560592651367,0.014843973331153393,1.9137836204663472e-08 45 | 0,0.9880671501159668,0.011932728812098503,6.71504380989063e-08 46 | 0,0.989938497543335,0.010061578825116158,2.3548112793037035e-08 47 | 0,0.9597053527832031,0.04029455780982971,8.216127866944589e-08 48 | 0,0.9945580959320068,0.005441893823444843,2.3135002802376903e-09 49 | 0,0.9797348976135254,0.020265132188796997,2.4915539853509472e-08 50 | 0,0.9934816360473633,0.006518419366329908,1.914356229093528e-09 51 | 0,0.9828447103500366,0.0171553585678339,1.0543256045991711e-08 52 | 1,0.01235899142920971,0.9509155750274658,0.03672543913125992 53 | 1,0.02474919520318508,0.892164409160614,0.08308641612529755 54 | 1,0.005547469016164541,0.8768611550331116,0.11759127676486969 55 | 1,0.005196007899940014,0.8811468482017517,0.11365707963705063 56 | 1,0.004678367171436548,0.845477819442749,0.1498437374830246 57 | 1,0.007672473788261414,0.8592108488082886,0.13311661779880524 58 | 1,0.020712710916996002,0.7863435745239258,0.1929437816143036 59 | 1,0.045286133885383606,0.9440891742706299,0.010624716058373451 60 | 1,0.006456250790506601,0.947411835193634,0.04613187909126282 61 | 1,0.028324998915195465,0.8459406495094299,0.125734344124794 62 | 1,0.007314296904951334,0.9666246175765991,0.026061132550239563 63 | 1,0.030478034168481827,0.8631949424743652,0.10632701218128204 64 | 1,0.0029708181973546743,0.9811984300613403,0.015830641612410545 65 | 1,0.005702598951756954,0.8259149193763733,0.1683824211359024 66 | 1,0.0952998697757721,0.8873710632324219,0.017329050227999687 67 | 1,0.02017264999449253,0.9478764533996582,0.031950827687978745 68 | 
1,0.014186575077474117,0.7181479930877686,0.2676653265953064 69 | 1,0.013448399491608143,0.9734619855880737,0.013089646585285664 70 | 1,0.0007157018408179283,0.6712086200714111,0.328075647354126 71 | 1,0.012363800778985023,0.9639973044395447,0.02363891899585724 72 | 2,0.007023070007562637,0.367062509059906,0.6259143948554993 73 | 1,0.023268360644578934,0.9500163197517395,0.02671527862548828 74 | 1,0.0006387917674146593,0.5857351422309875,0.41362616419792175 75 | 1,0.0041693891398608685,0.9179723858833313,0.07785815000534058 76 | 1,0.014283771626651287,0.9549478888511658,0.03076826222240925 77 | 1,0.014973260462284088,0.9430962204933167,0.041930537670850754 78 | 1,0.0026314561255276203,0.8973376154899597,0.10003097355365753 79 | 1,0.0024750817101448774,0.5847184658050537,0.4128064215183258 80 | 1,0.009542977437376976,0.7952045202255249,0.19525246322155 81 | 1,0.040253568440675735,0.9554885029792786,0.004257922992110252 82 | 1,0.011658105067908764,0.9635844826698303,0.02475738897919655 83 | 1,0.01427911315113306,0.9737570285797119,0.011963841505348682 84 | 1,0.02295318990945816,0.9546934962272644,0.022353362292051315 85 | 2,0.0004998893709853292,0.3183695077896118,0.6811306476593018 86 | 1,0.01417017262428999,0.6607367396354675,0.3250930607318878 87 | 1,0.05234669893980026,0.7867223024368286,0.16093102097511292 88 | 1,0.009781667962670326,0.8907176852226257,0.0995006114244461 89 | 1,0.0014981826534494758,0.902254045009613,0.09624774754047394 90 | 1,0.042174533009529114,0.9089418649673462,0.048883602023124695 91 | 1,0.010368790477514267,0.9014078974723816,0.08822337538003922 92 | 1,0.005307402461767197,0.8811231851577759,0.11356951296329498 93 | 1,0.010573361068964005,0.8671032190322876,0.12232339382171631 94 | 1,0.01299689058214426,0.9546982645988464,0.03230488672852516 95 | 1,0.031566865742206573,0.9576645493507385,0.010768615640699863 96 | 1,0.0120697608217597,0.8961371183395386,0.09179309755563736 97 | 1,0.03157100826501846,0.9352532625198364,0.0331757552921772 98 
| 1,0.023072900250554085,0.9150280952453613,0.06189899146556854 99 | 1,0.015337142162024975,0.9444908499717712,0.040171992033720016 100 | 1,0.11454834043979645,0.8800100684165955,0.005441547837108374 101 | 1,0.021049922332167625,0.9220680594444275,0.05688204616308212 102 | 2,3.864954578602919e-06,0.002117301570251584,0.997878909111023 103 | 2,0.00013877416495233774,0.07169366627931595,0.9281675815582275 104 | 2,2.019292878685519e-05,0.03922269120812416,0.9607571363449097 105 | 2,8.818122296361253e-05,0.09672459214925766,0.9031872153282166 106 | 2,1.15319153337623e-05,0.013270167633891106,0.9867182374000549 107 | 2,1.301973838963022e-06,0.01605234481394291,0.9839463233947754 108 | 2,0.0008393985335715115,0.15562355518341064,0.8435370326042175 109 | 2,9.230624527845066e-06,0.07892212271690369,0.9210686683654785 110 | 2,8.07259766588686e-06,0.06349460780620575,0.9364972710609436 111 | 2,2.396378113189712e-05,0.008818681351840496,0.9911574125289917 112 | 2,0.0015819192631170154,0.19650980830192566,0.8019083142280579 113 | 2,9.537961886962876e-05,0.10081013292074203,0.8990944623947144 114 | 2,9.426080214325339e-05,0.06331287324428558,0.9365928769111633 115 | 2,5.07617587572895e-05,0.03721673786640167,0.9627324938774109 116 | 2,2.3259566660271958e-05,0.006961107719689608,0.9930156469345093 117 | 2,0.0001698750420473516,0.0285175908356905,0.9713125824928284 118 | 2,0.0002515354426577687,0.16960451006889343,0.830143928527832 119 | 2,2.4483155357302167e-05,0.026327962055802345,0.9736475944519043 120 | 2,2.1547108985942032e-08,0.0019564770627766848,0.9980435371398926 121 | 2,0.00011778096086345613,0.32891443371772766,0.6709678173065186 122 | 2,4.3393971282057464e-05,0.0228702612221241,0.9770863652229309 123 | 2,0.0002947582979686558,0.06020559370517731,0.9394996166229248 124 | 2,5.572689474320214e-07,0.01842220313847065,0.9815772175788879 125 | 2,0.00069845886901021,0.28913193941116333,0.7101695537567139 126 | 2,0.00014055469364393502,0.05317571014165878,0.9466837048530579 
127 | 2,0.00012332187907304615,0.18371708691120148,0.8161595463752747 128 | 2,0.001503189792856574,0.3381592929363251,0.660337507724762 129 | 2,0.0022681071422994137,0.31743723154067993,0.6802946329116821 130 | 2,1.6253045032499358e-05,0.022846922278404236,0.9771367907524109 131 | 2,0.00021609279792755842,0.42747440934181213,0.5723094940185547 132 | 2,1.0138602192455437e-05,0.07562297582626343,0.92436683177948 133 | 2,0.00026125286240130663,0.16419392824172974,0.8355448842048645 134 | 2,1.0301246220478788e-05,0.01387974712997675,0.9861099720001221 135 | 1,0.0010670357150956988,0.5742083787918091,0.42472463846206665 136 | 2,8.889858145266771e-05,0.2899028956890106,0.7100082039833069 137 | 2,5.758323368354468e-06,0.021023638546466827,0.9789705872535706 138 | 2,6.284495611907914e-05,0.01005922257900238,0.9898779392242432 139 | 2,0.000365751184290275,0.16946913301944733,0.8301650881767273 140 | 2,0.0031745396554470062,0.3372253179550171,0.6596000790596008 141 | 2,0.0002545989118516445,0.10089132189750671,0.8988540768623352 142 | 2,2.2509013433591463e-05,0.011568943969905376,0.988408625125885 143 | 2,0.0004166320723015815,0.07502555847167969,0.9245578050613403 144 | 2,0.00013877416495233774,0.07169366627931595,0.9281675815582275 145 | 2,1.5692636225139722e-05,0.012691783718764782,0.9872925281524658 146 | 2,2.2928788894205354e-05,0.0073222815990448,0.9926548004150391 147 | 2,0.00013925803068559617,0.040794309228658676,0.9590664505958557 148 | 2,0.00013149397273082286,0.128694087266922,0.8711743950843811 149 | 2,0.00041983998380601406,0.1286415457725525,0.8709385991096497 150 | 2,0.00022695286315865815,0.0227490346878767,0.9770240187644958 151 | 2,0.0008839622605592012,0.18220959603786469,0.8169064521789551 152 | -------------------------------------------------------------------------------- /src/test/resources/csv/LinearRegressionAuto.csv: -------------------------------------------------------------------------------- 1 | _target 2 | 14.524336814880371 3 | 
13.114449501037598 4 | 14.379570960998535 5 | 15.452152252197266 6 | 14.740394592285156 7 | 9.00944709777832 8 | 8.106887817382812 9 | 8.577186584472656 10 | 8.478606224060059 11 | 10.781034469604492 12 | 11.127243995666504 13 | 12.488423347473145 14 | 9.184306144714355 15 | 9.493946075439453 16 | 25.72986602783203 17 | 20.543155670166016 18 | 20.604040145874023 19 | 20.416790008544922 20 | 26.390485763549805 21 | 26.667198181152344 22 | 25.865636825561523 23 | 25.41106414794922 24 | 26.703521728515625 25 | 25.00238609313965 26 | 20.299753189086914 27 | 14.39854621887207 28 | 17.25547981262207 29 | 16.548337936401367 30 | 17.9620418548584 31 | 26.739355087280273 32 | 24.042133331298828 33 | 25.88494873046875 34 | 18.59935188293457 35 | 19.307435989379883 36 | 17.84322738647461 37 | 17.437973022460938 38 | 18.860883712768555 39 | 13.223536491394043 40 | 10.493000984191895 41 | 13.239171028137207 42 | 14.83393383026123 43 | 11.229829788208008 44 | 10.253267288208008 45 | 10.131891250610352 46 | 17.431861877441406 47 | 24.35477066040039 48 | 17.727378845214844 49 | 17.258594512939453 50 | 24.46568489074707 51 | 25.347957611083984 52 | 28.36988639831543 53 | 26.581398010253906 54 | 28.98252296447754 55 | 28.887969970703125 56 | 27.059043884277344 57 | 27.767436981201172 58 | 26.65037727355957 59 | 26.924732208251953 60 | 28.239978790283203 61 | 25.493730545043945 62 | 25.567453384399414 63 | 13.52312183380127 64 | 11.053267478942871 65 | 15.304723739624023 66 | 13.455512046813965 67 | 15.817177772521973 68 | 10.14352798461914 69 | 13.449926376342773 70 | 13.66219425201416 71 | 11.70886516571045 72 | 28.291288375854492 73 | 15.953335762023926 74 | 15.377619743347168 75 | 16.463775634765625 76 | 15.50019359588623 77 | 25.740571975708008 78 | 25.843799591064453 79 | 26.388456344604492 80 | 27.209854125976562 81 | 25.28781509399414 82 | 27.867778778076172 83 | 25.861801147460938 84 | 26.262556076049805 85 | 27.716936111450195 86 | 14.661688804626465 87 | 16.16604995727539 
88 | 13.682291030883789 89 | 16.442832946777344 90 | 15.622081756591797 91 | 10.047220230102539 92 | 10.455296516418457 93 | 13.80433177947998 94 | 15.879228591918945 95 | 10.060495376586914 96 | 9.429344177246094 97 | 13.719199180603027 98 | 20.549293518066406 99 | 19.337093353271484 100 | 19.9702091217041 101 | 18.651771545410156 102 | 21.687423706054688 103 | 27.778099060058594 104 | 10.657090187072754 105 | 10.874743461608887 106 | 13.516121864318848 107 | 14.385031700134277 108 | 19.7855224609375 109 | 28.68752098083496 110 | 25.209314346313477 111 | 27.464208602905273 112 | 28.54802131652832 113 | 26.423110961914062 114 | 24.19156837463379 115 | 27.379060745239258 116 | 13.61101245880127 117 | 12.677093505859375 118 | 29.082468032836914 119 | 25.94146728515625 120 | 25.842744827270508 121 | 26.441709518432617 122 | 15.454242706298828 123 | 26.07400894165039 124 | 24.87615966796875 125 | 14.563728332519531 126 | 22.037643432617188 127 | 20.35244369506836 128 | 19.339004516601562 129 | 29.527402877807617 130 | 25.8817195892334 131 | 30.587324142456055 132 | 24.800251007080078 133 | 19.001569747924805 134 | 19.333833694458008 135 | 20.525089263916016 136 | 16.67157745361328 137 | 14.123878479003906 138 | 15.758299827575684 139 | 16.90066909790039 140 | 17.283241271972656 141 | 27.817445755004883 142 | 28.182323455810547 143 | 27.027490615844727 144 | 28.630462646484375 145 | 29.055112838745117 146 | 27.099468231201172 147 | 27.7532901763916 148 | 25.769140243530273 149 | 26.723922729492188 150 | 27.465524673461914 151 | 28.305755615234375 152 | 20.63233757019043 153 | 19.469018936157227 154 | 19.833627700805664 155 | 19.586929321289062 156 | 11.556412696838379 157 | 14.3402681350708 158 | 16.37906265258789 159 | 14.07585620880127 160 | 21.862964630126953 161 | 19.894336700439453 162 | 19.911375045776367 163 | 21.146209716796875 164 | 20.703279495239258 165 | 18.69462013244629 166 | 16.76129150390625 167 | 28.097015380859375 168 | 25.359392166137695 169 | 
20.69145965576172 170 | 25.67211151123047 171 | 25.654006958007812 172 | 28.1760311126709 173 | 27.691099166870117 174 | 23.420839309692383 175 | 27.5999755859375 176 | 20.414453506469727 177 | 26.84564208984375 178 | 26.72978401184082 179 | 26.281402587890625 180 | 26.789310455322266 181 | 28.384946823120117 182 | 27.63958168029785 183 | 27.578105926513672 184 | 25.442110061645508 185 | 28.371604919433594 186 | 28.000080108642578 187 | 16.845802307128906 188 | 16.507017135620117 189 | 16.65558624267578 190 | 14.689711570739746 191 | 21.000316619873047 192 | 19.44379997253418 193 | 22.53060531616211 194 | 21.04065704345703 195 | 29.656457901000977 196 | 29.099641799926758 197 | 28.009441375732422 198 | 28.703519821166992 199 | 21.380207061767578 200 | 20.287689208984375 201 | 19.914833068847656 202 | 19.771705627441406 203 | 27.14022445678711 204 | 29.366436004638672 205 | 28.57921028137207 206 | 24.343994140625 207 | 26.48702621459961 208 | 16.75718116760254 209 | 28.32590675354004 210 | 25.938779830932617 211 | 25.120704650878906 212 | 15.4006929397583 213 | 14.375121116638184 214 | 17.523019790649414 215 | 17.139848709106445 216 | 29.34459686279297 217 | 27.240625381469727 218 | 29.953519821166992 219 | 27.307498931884766 220 | 29.6888370513916 221 | 17.474597930908203 222 | 20.532169342041016 223 | 16.92850112915039 224 | 17.519325256347656 225 | 20.419086456298828 226 | 21.50659942626953 227 | 21.74500274658203 228 | 20.77730941772461 229 | 12.8274507522583 230 | 15.345734596252441 231 | 13.435884475708008 232 | 15.356223106384277 233 | 28.347089767456055 234 | 25.252056121826172 235 | 29.390033721923828 236 | 25.818466186523438 237 | 28.095441818237305 238 | 28.453516006469727 239 | 28.77317237854004 240 | 28.036327362060547 241 | 26.22971534729004 242 | 27.151418685913086 243 | 29.653034210205078 244 | 30.102027893066406 245 | 27.95332908630371 246 | 30.540189743041992 247 | 30.488285064697266 248 | 29.342836380004883 249 | 20.34762191772461 250 | 
17.255598068237305 251 | 18.100994110107422 252 | 22.468910217285156 253 | 23.798391342163086 254 | 22.860538482666016 255 | 26.037208557128906 256 | 22.09404182434082 257 | 21.522424697875977 258 | 21.556316375732422 259 | 23.053600311279297 260 | 22.75921058654785 261 | 20.373798370361328 262 | 18.380571365356445 263 | 22.908533096313477 264 | 17.893001556396484 265 | 17.14548110961914 266 | 28.39134979248047 267 | 26.984907150268555 268 | 28.226642608642578 269 | 27.593198776245117 270 | 27.200817108154297 271 | 26.138294219970703 272 | 25.892059326171875 273 | 28.207618713378906 274 | 27.610136032104492 275 | 25.846281051635742 276 | 28.4084529876709 277 | 26.591917037963867 278 | 28.96936798095703 279 | 29.049564361572266 280 | 22.241140365600586 281 | 23.917604446411133 282 | 26.83283233642578 283 | 22.132570266723633 284 | 22.668981552124023 285 | 18.549148559570312 286 | 18.159343719482422 287 | 15.557988166809082 288 | 17.96100616455078 289 | 16.423877716064453 290 | 15.95810604095459 291 | 20.51151466369629 292 | 15.441217422485352 293 | 29.0948486328125 294 | 29.646869659423828 295 | 28.711685180664062 296 | 27.06035614013672 297 | 24.9282283782959 298 | 16.465810775756836 299 | 28.557714462280273 300 | 21.975196838378906 301 | 27.3935604095459 302 | 27.9465389251709 303 | 30.879470825195312 304 | 28.97091293334961 305 | 26.073835372924805 306 | 24.670745849609375 307 | 25.075891494750977 308 | 25.311939239501953 309 | 29.173847198486328 310 | 30.749956130981445 311 | 28.883609771728516 312 | 30.325950622558594 313 | 26.568130493164062 314 | 27.439252853393555 315 | 27.41241455078125 316 | 22.928632736206055 317 | 29.599523544311523 318 | 27.784637451171875 319 | 28.753910064697266 320 | 28.736284255981445 321 | 28.92481231689453 322 | 30.711416244506836 323 | 26.097482681274414 324 | 31.160097122192383 325 | 30.7845401763916 326 | 31.200927734375 327 | 28.699703216552734 328 | 27.536195755004883 329 | 29.46326446533203 330 | 30.1832275390625 331 | 
29.578956604003906 332 | 26.30097007751465 333 | 30.817472457885742 334 | 28.135398864746094 335 | 29.399593353271484 336 | 27.481258392333984 337 | 26.706445693969727 338 | 26.12163543701172 339 | 25.48639678955078 340 | 26.712535858154297 341 | 31.161819458007812 342 | 30.135786056518555 343 | 30.876981735229492 344 | 30.53216552734375 345 | 31.671932220458984 346 | 30.65313720703125 347 | 30.41179847717285 348 | 28.346649169921875 349 | 29.32404899597168 350 | 30.433427810668945 351 | 28.88172721862793 352 | 29.1278133392334 353 | 29.693994522094727 354 | 29.171131134033203 355 | 29.239166259765625 356 | 28.211318969726562 357 | 27.943769454956055 358 | 26.453359603881836 359 | 28.141897201538086 360 | 22.753780364990234 361 | 17.07146644592285 362 | 24.33542251586914 363 | 22.400165557861328 364 | 30.32706642150879 365 | 29.997549057006836 366 | 30.001541137695312 367 | 29.213254928588867 368 | 27.89448356628418 369 | 27.677122116088867 370 | 27.767635345458984 371 | 29.72311782836914 372 | 31.396892547607422 373 | 31.256813049316406 374 | 28.70317268371582 375 | 30.1229248046875 376 | 29.293588638305664 377 | 29.51077651977539 378 | 29.975399017333984 379 | 30.437381744384766 380 | 30.64946746826172 381 | 30.778207778930664 382 | 26.38129234313965 383 | 21.183517456054688 384 | 26.527347564697266 385 | 23.225326538085938 386 | 27.697267532348633 387 | 27.10308265686035 388 | 27.302005767822266 389 | 27.3692569732666 390 | 32.085044860839844 391 | 26.735782623291016 392 | 29.25115203857422 393 | 29.582740783691406 394 | -------------------------------------------------------------------------------- /src/test/resources/main.py: -------------------------------------------------------------------------------- 1 | from pandas import DataFrame 2 | from tensorflow.contrib.learn import DNNClassifier, DNNRegressor, LinearClassifier, LinearRegressor, RunConfig 3 | from tensorflow.contrib.layers import one_hot_column, real_valued_column, sparse_column_with_keys 4 | from 
tensorflow.contrib.layers.python.layers.feature_column import _OneHotColumn, _RealValuedColumn, _SparseColumnKeys 5 | from tensorflow.contrib.learn.python.learn.utils.input_fn_utils import InputFnOps 6 | 7 | import numpy 8 | import os 9 | import pandas 10 | import shutil 11 | import tempfile 12 | import tensorflow as tf 13 | 14 | tf.logging.set_verbosity(tf.logging.INFO) 15 | 16 | #tf.reset_default_graph() 17 | 18 | estimator_conf = RunConfig(num_cores = 1, tf_random_seed = 42) 19 | 20 | def load_csv(name): 21 | return pandas.read_csv("csv/" + name) 22 | 23 | def store_csv(df, name): 24 | df.to_csv("csv/" + name, index = False) 25 | 26 | def store_savedmodel(estimator, serving_input_fn, name): 27 | savemodel_dir = estimator.export_savedmodel(tempfile.mkdtemp(), serving_input_fn = serving_input_fn, as_text = True) 28 | savemodel_dir = savemodel_dir.decode("UTF-8") 29 | 30 | if(os.path.isdir("savedmodel/" + name)): 31 | shutil.rmtree("savedmodel/" + name) 32 | shutil.move(savemodel_dir, "savedmodel/" + name) 33 | 34 | def _dnn_feature_columns(feature_columns): 35 | return list(map(lambda x: one_hot_column(x) if isinstance(x, _SparseColumnKeys) else x, feature_columns)) 36 | 37 | def _input_fn(df, cont_feature_columns, cat_feature_columns, label_column): 38 | cont_features = {column : tf.constant(df[column].values, dtype = tf.float64, shape = [df[column].size, 1]) for column in cont_feature_columns} 39 | cat_features = {column : tf.constant(df[column].values, dtype = tf.string, shape = [df[column].size, 1]) for column in cat_feature_columns} 40 | features = dict(list(cont_features.items()) + list(cat_features.items())) 41 | label = tf.constant(df[label_column].values, shape = [df[label_column].size, 1]) 42 | return features, label 43 | 44 | def _serving_input_fn(cont_feature_columns, cat_feature_columns): 45 | cont_feature_placeholders = {column : tf.placeholder(dtype = tf.float64, shape = [None, 1], name = column) for column in cont_feature_columns} 46 | 
cat_feature_placeholders = {column : tf.placeholder(dtype = tf.string, shape = [None, 1], name = column) for column in cat_feature_columns} 47 | feature_placeholders = dict(list(cont_feature_placeholders.items()) + list(cat_feature_placeholders.items())) 48 | features = {column : tensor for column, tensor in feature_placeholders.items()} 49 | label = None 50 | return InputFnOps(features, label, feature_placeholders) 51 | 52 | # 53 | # Binary classification 54 | # 55 | 56 | audit_df = load_csv("Audit.csv") 57 | audit_df["Adjusted"] = audit_df["Adjusted"].astype(int) 58 | 59 | audit_cont_columns = ["Age", "Income", "Deductions", "Hours"] 60 | audit_cat_columns = ["Employment", "Education", "Marital", "Occupation", "Gender"] 61 | 62 | audit_feature_columns = [real_valued_column(column, dtype = tf.float64) for column in audit_cont_columns] + [sparse_column_with_keys(column, dtype = tf.string, keys = sorted(audit_df[column].unique())) for column in audit_cat_columns] 63 | 64 | def audit_input_fn(): 65 | return _input_fn(audit_df, audit_cont_columns, audit_cat_columns, "Adjusted") 66 | 67 | def audit_serving_input_fn(): 68 | return _serving_input_fn(audit_cont_columns, audit_cat_columns) 69 | 70 | def build_audit(classifier, max_steps, name, with_proba = True): 71 | classifier.fit(input_fn = audit_input_fn, max_steps = max_steps) 72 | 73 | adjusted = DataFrame(classifier.predict(input_fn = audit_input_fn, as_iterable = False), columns = ["_target"]) 74 | if(with_proba): 75 | adjusted_proba = DataFrame(classifier.predict_proba(input_fn = audit_input_fn, as_iterable = False), columns = ["probability(0)", "probability(1)"]) 76 | adjusted = pandas.concat((adjusted, adjusted_proba), axis = 1) 77 | store_csv(adjusted, name + ".csv") 78 | 79 | store_savedmodel(classifier, audit_serving_input_fn, name) 80 | 81 | build_audit(DNNClassifier(hidden_units = [2 * 49], feature_columns = _dnn_feature_columns(audit_feature_columns), optimizer = tf.train.AdamOptimizer(learning_rate = 
0.00001), config = estimator_conf), 2000, "DNNClassificationAudit") 82 | build_audit(LinearClassifier(feature_columns = audit_feature_columns, optimizer = tf.train.AdamOptimizer(learning_rate = 0.00025), config = estimator_conf), 5000, "LinearClassificationAudit") 83 | 84 | # 85 | # Multi-class classification 86 | # 87 | 88 | iris_df = load_csv("Iris.csv") 89 | iris_df["Species"] = iris_df["Species"].replace("setosa", "0").replace("versicolor", "1").replace("virginica", "2").astype(int) 90 | 91 | iris_cont_columns = ["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"] 92 | 93 | iris_feature_columns = [real_valued_column(column, dtype = tf.float64) for column in iris_cont_columns] 94 | 95 | def iris_input_fn(): 96 | return _input_fn(iris_df, iris_cont_columns, [], "Species") 97 | 98 | def iris_serving_input_fn(): 99 | return _serving_input_fn(iris_cont_columns, []) 100 | 101 | def build_iris(classifier, max_steps, name, with_proba = True): 102 | classifier.fit(input_fn = iris_input_fn, max_steps = max_steps) 103 | 104 | species = DataFrame(classifier.predict(input_fn = iris_input_fn, as_iterable = False), columns = ["_target"]) 105 | if(with_proba): 106 | species_proba = DataFrame(classifier.predict_proba(input_fn = iris_input_fn, as_iterable = False), columns = ["probability(0)", "probability(1)", "probability(2)"]) 107 | species = pandas.concat((species, species_proba), axis = 1) 108 | store_csv(species, name + ".csv") 109 | 110 | store_savedmodel(classifier, iris_serving_input_fn, name) 111 | 112 | build_iris(DNNClassifier(hidden_units = [4 * 3, 2 * 3], feature_columns = _dnn_feature_columns(iris_feature_columns), n_classes = 3, optimizer = tf.train.AdamOptimizer, config = estimator_conf), 2000, "DNNClassificationIris") 113 | build_iris(LinearClassifier(feature_columns = iris_feature_columns, n_classes = 3, optimizer = tf.train.AdamOptimizer, config = estimator_conf), 5000, "LinearClassificationIris") 114 | 115 | # 116 | # Regression 117 | # 118 | 119 | 
auto_df = load_csv("Auto.csv") 120 | auto_df["origin"] = auto_df["origin"].astype(str) 121 | 122 | auto_cont_columns = ["cylinders", "displacement", "horsepower", "weight", "acceleration", "model_year"] 123 | auto_cat_columns = ["origin"] 124 | 125 | auto_feature_columns = [real_valued_column(column, dtype = tf.float64) for column in auto_cont_columns] + [sparse_column_with_keys(column, dtype = tf.string, keys = sorted(auto_df[column].unique())) for column in auto_cat_columns] 126 | 127 | def auto_input_fn(): 128 | return _input_fn(auto_df, auto_cont_columns, auto_cat_columns, "mpg") 129 | 130 | def auto_serving_input_fn(): 131 | return _serving_input_fn(auto_cont_columns, auto_cat_columns) 132 | 133 | def build_auto(regressor, max_steps, name): 134 | regressor.fit(input_fn = auto_input_fn, max_steps = max_steps) 135 | 136 | mpg = DataFrame(regressor.predict(input_fn = auto_input_fn, as_iterable = False), columns = ["_target"]) 137 | store_csv(mpg, name + ".csv") 138 | 139 | store_savedmodel(regressor, auto_serving_input_fn, name) 140 | 141 | build_auto(DNNRegressor(hidden_units = [2 * 9, 9, 3], feature_columns = _dnn_feature_columns(auto_feature_columns), optimizer = tf.train.AdamOptimizer(learning_rate = 0.001), config = estimator_conf), 2000, "DNNRegressionAuto") 142 | build_auto(LinearRegressor(feature_columns = auto_feature_columns, optimizer = tf.train.AdamOptimizer, config = estimator_conf), 1000, "LinearRegressionAuto") 143 | -------------------------------------------------------------------------------- /src/test/resources/savedmodel/DNNClassificationAudit/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/DNNClassificationAudit/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- 
/src/test/resources/savedmodel/DNNClassificationAudit/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/DNNClassificationAudit/variables/variables.index -------------------------------------------------------------------------------- /src/test/resources/savedmodel/DNNClassificationIris/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/DNNClassificationIris/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /src/test/resources/savedmodel/DNNClassificationIris/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/DNNClassificationIris/variables/variables.index -------------------------------------------------------------------------------- /src/test/resources/savedmodel/DNNRegressionAuto/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/DNNRegressionAuto/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /src/test/resources/savedmodel/DNNRegressionAuto/variables/variables.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/DNNRegressionAuto/variables/variables.index -------------------------------------------------------------------------------- /src/test/resources/savedmodel/LinearClassificationAudit/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/LinearClassificationAudit/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /src/test/resources/savedmodel/LinearClassificationAudit/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/LinearClassificationAudit/variables/variables.index -------------------------------------------------------------------------------- /src/test/resources/savedmodel/LinearClassificationIris/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/LinearClassificationIris/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /src/test/resources/savedmodel/LinearClassificationIris/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/LinearClassificationIris/variables/variables.index 
-------------------------------------------------------------------------------- /src/test/resources/savedmodel/LinearRegressionAuto/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/LinearRegressionAuto/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /src/test/resources/savedmodel/LinearRegressionAuto/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-tensorflow/81a93e3df2af037dcb2d8af067cfe973f744b35c/src/test/resources/savedmodel/LinearRegressionAuto/variables/variables.index --------------------------------------------------------------------------------