├── LICENSE.txt
├── NOTICE.txt
├── README.md
├── pom.xml
└── src
├── main
├── java
│ └── org
│ │ └── jpmml
│ │ └── tensorflow
│ │ ├── DNNClassifier.java
│ │ ├── DNNEstimator.java
│ │ ├── DNNRegressor.java
│ │ ├── Estimator.java
│ │ ├── EstimatorFactory.java
│ │ ├── LinearClassifier.java
│ │ ├── LinearEstimator.java
│ │ ├── LinearRegressor.java
│ │ ├── Main.java
│ │ ├── SavedModel.java
│ │ ├── ShapeUtil.java
│ │ ├── TensorFlowEncoder.java
│ │ ├── TensorUtil.java
│ │ ├── Trail.java
│ │ └── TypeUtil.java
└── proto
│ └── tensorflow
│ └── core
│ ├── framework
│ ├── attr_value.proto
│ ├── function.proto
│ ├── graph.proto
│ ├── node_def.proto
│ ├── op_def.proto
│ ├── resource_handle.proto
│ ├── tensor.proto
│ ├── tensor_shape.proto
│ ├── types.proto
│ └── versions.proto
│ └── protobuf
│ ├── meta_graph.proto
│ └── saver.proto
└── test
├── java
└── org
│ └── jpmml
│ └── tensorflow
│ ├── DNNClassifierTest.java
│ ├── DNNRegressorTest.java
│ ├── EstimatorTest.java
│ ├── LinearClassifierTest.java
│ └── LinearRegressorTest.java
└── resources
├── csv
├── Audit.csv
├── Auto.csv
├── DNNClassificationAudit.csv
├── DNNClassificationIris.csv
├── DNNRegressionAuto.csv
├── Iris.csv
├── LinearClassificationAudit.csv
├── LinearClassificationIris.csv
└── LinearRegressionAuto.csv
├── main.py
└── savedmodel
├── DNNClassificationAudit
├── saved_model.pbtxt
└── variables
│ ├── variables.data-00000-of-00001
│ └── variables.index
├── DNNClassificationIris
├── saved_model.pbtxt
└── variables
│ ├── variables.data-00000-of-00001
│ └── variables.index
├── DNNRegressionAuto
├── saved_model.pbtxt
└── variables
│ ├── variables.data-00000-of-00001
│ └── variables.index
├── LinearClassificationAudit
├── saved_model.pbtxt
└── variables
│ ├── variables.data-00000-of-00001
│ └── variables.index
├── LinearClassificationIris
├── saved_model.pbtxt
└── variables
│ ├── variables.data-00000-of-00001
│ └── variables.index
└── LinearRegressionAuto
├── saved_model.pbtxt
└── variables
├── variables.data-00000-of-00001
└── variables.index
/NOTICE.txt:
--------------------------------------------------------------------------------
1 | JPMML-TensorFlow includes third-party dependencies that are released under the Apache License, Version 2.0:
2 | * Guava - https://github.com/google/guava
3 | * JCommander - http://jcommander.org
4 | * Protocol Buffers - https://github.com/google/protobuf
5 | * TensorFlow - https://github.com/tensorflow/tensorflow
6 |
7 | Apache License
8 | Version 2.0, January 2004
9 | http://www.apache.org/licenses/
10 |
11 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
12 |
13 | 1. Definitions.
14 |
15 | "License" shall mean the terms and conditions for use, reproduction,
16 | and distribution as defined by Sections 1 through 9 of this document.
17 |
18 | "Licensor" shall mean the copyright owner or entity authorized by
19 | the copyright owner that is granting the License.
20 |
21 | "Legal Entity" shall mean the union of the acting entity and all
22 | other entities that control, are controlled by, or are under common
23 | control with that entity. For the purposes of this definition,
24 | "control" means (i) the power, direct or indirect, to cause the
25 | direction or management of such entity, whether by contract or
26 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
27 | outstanding shares, or (iii) beneficial ownership of such entity.
28 |
29 | "You" (or "Your") shall mean an individual or Legal Entity
30 | exercising permissions granted by this License.
31 |
32 | "Source" form shall mean the preferred form for making modifications,
33 | including but not limited to software source code, documentation
34 | source, and configuration files.
35 |
36 | "Object" form shall mean any form resulting from mechanical
37 | transformation or translation of a Source form, including but
38 | not limited to compiled object code, generated documentation,
39 | and conversions to other media types.
40 |
41 | "Work" shall mean the work of authorship, whether in Source or
42 | Object form, made available under the License, as indicated by a
43 | copyright notice that is included in or attached to the work
44 | (an example is provided in the Appendix below).
45 |
46 | "Derivative Works" shall mean any work, whether in Source or Object
47 | form, that is based on (or derived from) the Work and for which the
48 | editorial revisions, annotations, elaborations, or other modifications
49 | represent, as a whole, an original work of authorship. For the purposes
50 | of this License, Derivative Works shall not include works that remain
51 | separable from, or merely link (or bind by name) to the interfaces of,
52 | the Work and Derivative Works thereof.
53 |
54 | "Contribution" shall mean any work of authorship, including
55 | the original version of the Work and any modifications or additions
56 | to that Work or Derivative Works thereof, that is intentionally
57 | submitted to Licensor for inclusion in the Work by the copyright owner
58 | or by an individual or Legal Entity authorized to submit on behalf of
59 | the copyright owner. For the purposes of this definition, "submitted"
60 | means any form of electronic, verbal, or written communication sent
61 | to the Licensor or its representatives, including but not limited to
62 | communication on electronic mailing lists, source code control systems,
63 | and issue tracking systems that are managed by, or on behalf of, the
64 | Licensor for the purpose of discussing and improving the Work, but
65 | excluding communication that is conspicuously marked or otherwise
66 | designated in writing by the copyright owner as "Not a Contribution."
67 |
68 | "Contributor" shall mean Licensor and any individual or Legal Entity
69 | on behalf of whom a Contribution has been received by Licensor and
70 | subsequently incorporated within the Work.
71 |
72 | 2. Grant of Copyright License. Subject to the terms and conditions of
73 | this License, each Contributor hereby grants to You a perpetual,
74 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
75 | copyright license to reproduce, prepare Derivative Works of,
76 | publicly display, publicly perform, sublicense, and distribute the
77 | Work and such Derivative Works in Source or Object form.
78 |
79 | 3. Grant of Patent License. Subject to the terms and conditions of
80 | this License, each Contributor hereby grants to You a perpetual,
81 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
82 | (except as stated in this section) patent license to make, have made,
83 | use, offer to sell, sell, import, and otherwise transfer the Work,
84 | where such license applies only to those patent claims licensable
85 | by such Contributor that are necessarily infringed by their
86 | Contribution(s) alone or by combination of their Contribution(s)
87 | with the Work to which such Contribution(s) was submitted. If You
88 | institute patent litigation against any entity (including a
89 | cross-claim or counterclaim in a lawsuit) alleging that the Work
90 | or a Contribution incorporated within the Work constitutes direct
91 | or contributory patent infringement, then any patent licenses
92 | granted to You under this License for that Work shall terminate
93 | as of the date such litigation is filed.
94 |
95 | 4. Redistribution. You may reproduce and distribute copies of the
96 | Work or Derivative Works thereof in any medium, with or without
97 | modifications, and in Source or Object form, provided that You
98 | meet the following conditions:
99 |
100 | (a) You must give any other recipients of the Work or
101 | Derivative Works a copy of this License; and
102 |
103 | (b) You must cause any modified files to carry prominent notices
104 | stating that You changed the files; and
105 |
106 | (c) You must retain, in the Source form of any Derivative Works
107 | that You distribute, all copyright, patent, trademark, and
108 | attribution notices from the Source form of the Work,
109 | excluding those notices that do not pertain to any part of
110 | the Derivative Works; and
111 |
112 | (d) If the Work includes a "NOTICE" text file as part of its
113 | distribution, then any Derivative Works that You distribute must
114 | include a readable copy of the attribution notices contained
115 | within such NOTICE file, excluding those notices that do not
116 | pertain to any part of the Derivative Works, in at least one
117 | of the following places: within a NOTICE text file distributed
118 | as part of the Derivative Works; within the Source form or
119 | documentation, if provided along with the Derivative Works; or,
120 | within a display generated by the Derivative Works, if and
121 | wherever such third-party notices normally appear. The contents
122 | of the NOTICE file are for informational purposes only and
123 | do not modify the License. You may add Your own attribution
124 | notices within Derivative Works that You distribute, alongside
125 | or as an addendum to the NOTICE text from the Work, provided
126 | that such additional attribution notices cannot be construed
127 | as modifying the License.
128 |
129 | You may add Your own copyright statement to Your modifications and
130 | may provide additional or different license terms and conditions
131 | for use, reproduction, or distribution of Your modifications, or
132 | for any such Derivative Works as a whole, provided Your use,
133 | reproduction, and distribution of the Work otherwise complies with
134 | the conditions stated in this License.
135 |
136 | 5. Submission of Contributions. Unless You explicitly state otherwise,
137 | any Contribution intentionally submitted for inclusion in the Work
138 | by You to the Licensor shall be under the terms and conditions of
139 | this License, without any additional terms or conditions.
140 | Notwithstanding the above, nothing herein shall supersede or modify
141 | the terms of any separate license agreement you may have executed
142 | with Licensor regarding such Contributions.
143 |
144 | 6. Trademarks. This License does not grant permission to use the trade
145 | names, trademarks, service marks, or product names of the Licensor,
146 | except as required for reasonable and customary use in describing the
147 | origin of the Work and reproducing the content of the NOTICE file.
148 |
149 | 7. Disclaimer of Warranty. Unless required by applicable law or
150 | agreed to in writing, Licensor provides the Work (and each
151 | Contributor provides its Contributions) on an "AS IS" BASIS,
152 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
153 | implied, including, without limitation, any warranties or conditions
154 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
155 | PARTICULAR PURPOSE. You are solely responsible for determining the
156 | appropriateness of using or redistributing the Work and assume any
157 | risks associated with Your exercise of permissions under this License.
158 |
159 | 8. Limitation of Liability. In no event and under no legal theory,
160 | whether in tort (including negligence), contract, or otherwise,
161 | unless required by applicable law (such as deliberate and grossly
162 | negligent acts) or agreed to in writing, shall any Contributor be
163 | liable to You for damages, including any direct, indirect, special,
164 | incidental, or consequential damages of any character arising as a
165 | result of this License or out of the use or inability to use the
166 | Work (including but not limited to damages for loss of goodwill,
167 | work stoppage, computer failure or malfunction, or any and all
168 | other commercial damages or losses), even if such Contributor
169 | has been advised of the possibility of such damages.
170 |
171 | 9. Accepting Warranty or Additional Liability. While redistributing
172 | the Work or Derivative Works thereof, You may choose to offer,
173 | and charge a fee for, acceptance of support, warranty, indemnity,
174 | or other liability obligations and/or rights consistent with this
175 | License. However, in accepting such obligations, You may act only
176 | on Your own behalf and on Your sole responsibility, not on behalf
177 | of any other Contributor, and only if You agree to indemnify,
178 | defend, and hold each Contributor harmless for any liability
179 | incurred by, or claims asserted against, such Contributor by reason
180 | of your accepting any such warranty or additional liability.
181 |
182 | END OF TERMS AND CONDITIONS
183 |
184 | APPENDIX: How to apply the Apache License to your work.
185 |
186 | To apply the Apache License to your work, attach the following
187 | boilerplate notice, with the fields enclosed by brackets "[]"
188 | replaced with your own identifying information. (Don't include
189 | the brackets!) The text should be enclosed in the appropriate
190 | comment syntax for the file format. We also recommend that a
191 | file or class name and description of purpose be included on the
192 | same "printed page" as the copyright notice for easier
193 | identification within third-party archives.
194 |
195 | Copyright [yyyy] [name of copyright owner]
196 |
197 | Licensed under the Apache License, Version 2.0 (the "License");
198 | you may not use this file except in compliance with the License.
199 | You may obtain a copy of the License at
200 |
201 | http://www.apache.org/licenses/LICENSE-2.0
202 |
203 | Unless required by applicable law or agreed to in writing, software
204 | distributed under the License is distributed on an "AS IS" BASIS,
205 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
206 | See the License for the specific language governing permissions and
207 | limitations under the License.
208 |
209 | --------------------------------------------------------------------------------
210 |
211 | Additionally, JPMML-TensorFlow includes third-party dependencies that are released under the MIT License:
212 | * Simple Logging Facade for Java (SLF4J) - http://www.slf4j.org/
213 |
214 | Copyright (c) 2004-2017 QOS.ch
215 | All rights reserved.
216 |
217 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
218 |
219 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
220 |
221 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
222 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | JPMML-TensorFlow
2 | ================
3 |
4 | Java library and command-line application for converting [TensorFlow](http://tensorflow.org) models to PMML.
5 |
6 | # Features #
7 |
8 | * Supported Estimator types:
9 | * [`learn.DNNClassifier`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNClassifier)
10 | * [`learn.DNNRegressor`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNRegressor)
11 | * [`learn.LinearClassifier`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/LinearClassifier)
12 | * [`learn.LinearRegressor`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/LinearRegressor)
13 | * Supported Feature column types:
14 | * [`layers.one_hot_column`](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/one_hot_column)
15 | * [`layers.real_valued_column`](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/real_valued_column)
16 | * [`layers.sparse_column_with_keys`](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/sparse_column_with_keys)
17 | * Production quality:
18 | * Complete test coverage.
19 | * Fully compliant with the [JPMML-Evaluator](https://github.com/jpmml/jpmml-evaluator) library.
20 |
21 | # Prerequisites #
22 |
23 | ### The TensorFlow side of operations
24 |
25 | * Protocol Buffers 3.2.0 or newer
26 | * TensorFlow 1.1.0 or newer
27 |
28 | ### The Java side of operations
29 |
30 | * Java 1.8 or newer
31 |
32 | # Installation #
33 |
34 | Enter the project root directory and build using [Apache Maven](http://maven.apache.org/); use the `protoc.exe` system property to specify the location of the Protocol Buffers compiler:
35 | ```
36 | mvn -Dprotoc.exe=/usr/local/bin/protoc clean install
37 | ```
38 |
39 | The build produces an executable uber-JAR file `target/converter-executable-1.0-SNAPSHOT.jar`.
40 |
41 | # Usage #
42 |
43 | A typical workflow can be summarized as follows:
44 |
45 | 1. Use TensorFlow to train an estimator.
46 | 2. Export the estimator in `SavedModel` data format to a directory in a local filesystem.
47 | 3. Use the JPMML-TensorFlow command-line converter application to convert the SavedModel directory into a PMML file.
48 |
49 | ### The TensorFlow side of operations
50 |
51 | Please see the test script file [main.py](https://github.com/jpmml/jpmml-tensorflow/blob/master/src/test/resources/main.py) for sample workflows.
52 |
53 | ### The Java side of operations
54 |
55 | Converting the estimator SavedModel directory `estimator/` to a PMML file `estimator.pmml`:
56 | ```
57 | java -jar target/converter-executable-1.0-SNAPSHOT.jar --tf-savedmodel-input estimator/ --pmml-output estimator.pmml
58 | ```
59 |
60 | Getting help:
61 | ```
62 | java -jar target/converter-executable-1.0-SNAPSHOT.jar --help
63 | ```
64 |
65 | # License #
66 |
67 | JPMML-TensorFlow is licensed under the [GNU Affero General Public License (AGPL) version 3.0](http://www.gnu.org/licenses/agpl-3.0.html). Other licenses are available on request.
68 |
69 | # Additional information #
70 |
71 | Please contact [info@openscoring.io](mailto:info@openscoring.io)
72 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | 4.0.0
4 |
5 | org.jpmml
6 | jpmml-tensorflow
7 | 1.0-SNAPSHOT
8 |
9 | JPMML-TensorFlow
10 | Java library and command-line application for converting TensorFlow models to PMML
11 | https://github.com/jpmml/jpmml-tensorflow
12 |
13 |
14 |
15 | GNU Affero General Public License (AGPL) version 3.0
16 | http://www.gnu.org/licenses/agpl-3.0.html
17 | repo
18 |
19 |
20 |
21 |
22 |
23 | villu.ruusmann
24 | Villu Ruusmann
25 |
26 |
27 |
28 |
29 | scm:git:git@github.com:jpmml/jpmml-tensorflow.git
30 | scm:git:git@github.com:jpmml/jpmml-tensorflow.git
31 | git://github.com/jpmml/jpmml-tensorflow.git
32 | HEAD
33 |
34 |
35 | GitHub
36 | https://github.com/jpmml/jpmml-tensorflow/issues
37 |
38 |
39 |
40 | protoc
41 |
42 |
43 |
44 |
45 | com.beust
46 | jcommander
47 | 1.48
48 |
49 |
50 |
51 | org.jpmml
52 | jpmml-converter
53 | 1.2.5
54 |
55 |
56 | com.sun.xml.fastinfoset
57 | FastInfoset
58 |
59 |
60 | javax.xml.bind
61 | jaxb-api
62 |
63 |
64 | org.glassfish.jaxb
65 | txw2
66 |
67 |
68 | org.jvnet.staxex
69 | stax-ex
70 |
71 |
72 |
73 |
74 |
75 | org.slf4j
76 | slf4j-api
77 | 1.7.25
78 |
79 |
80 | org.slf4j
81 | slf4j-jdk14
82 | 1.7.25
83 |
84 |
85 |
86 | org.tensorflow
87 | proto
88 | [1.3.0, )
89 |
90 |
91 | org.tensorflow
92 | tensorflow
93 | [1.1.0, )
94 |
95 |
96 |
97 | junit
98 | junit
99 | 4.12
100 | test
101 |
102 |
103 |
104 | org.jpmml
105 | pmml-evaluator
106 | 1.3.8
107 | test
108 |
109 |
110 | org.jpmml
111 | pmml-evaluator-test
112 | 1.3.8
113 | test
114 |
115 |
116 |
117 |
118 |
119 |
120 | org.apache.maven.plugins
121 | maven-compiler-plugin
122 | 3.5.1
123 |
124 | 1.8
125 | 1.8
126 |
127 |
128 |
129 | org.apache.maven.plugins
130 | maven-enforcer-plugin
131 | 1.4.1
132 |
133 |
134 |
135 | enforce
136 |
137 |
138 |
139 |
140 | 1.8
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 | org.apache.maven.plugins
149 | maven-jar-plugin
150 | 3.0.2
151 |
152 |
153 |
154 | true
155 |
156 |
157 |
158 |
159 |
160 | org.apache.maven.plugins
161 | maven-shade-plugin
162 | 2.4.3
163 |
164 |
165 | package
166 |
167 | shade
168 |
169 |
170 | converter-executable-${project.version}
171 |
172 |
173 | org.jpmml.tensorflow.Main
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 | org.apache.maven.plugins
182 | maven-source-plugin
183 | 3.0.1
184 |
185 |
186 | attach-sources
187 |
188 | jar
189 |
190 |
191 |
192 |
193 |
194 | org.apache.maven.plugins
195 | maven-surefire-plugin
196 | 2.19.1
197 |
198 | ${jacoco.agent}
199 | false
200 |
201 |
202 |
203 | org.jacoco
204 | jacoco-maven-plugin
205 | 0.7.9
206 |
207 |
208 | pre-unit-test
209 |
210 | prepare-agent
211 |
212 |
213 | jacoco.agent
214 |
215 |
216 |
217 | post-unit-test
218 | prepare-package
219 |
220 | report
221 |
222 |
223 |
224 |
225 |
226 | org.xolstice.maven.plugins
227 | protobuf-maven-plugin
228 | 0.5.0
229 |
230 | ${protoc.exe}
231 |
232 |
233 |
234 |
235 | compile
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/DNNClassifier.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import java.util.ArrayList;
22 | import java.util.Arrays;
23 | import java.util.List;
24 |
25 | import com.google.common.collect.Iterables;
26 | import org.dmg.pmml.DataField;
27 | import org.dmg.pmml.DataType;
28 | import org.dmg.pmml.FieldName;
29 | import org.dmg.pmml.MiningFunction;
30 | import org.dmg.pmml.OpType;
31 | import org.dmg.pmml.neural_network.Connection;
32 | import org.dmg.pmml.neural_network.NeuralLayer;
33 | import org.dmg.pmml.neural_network.NeuralNetwork;
34 | import org.dmg.pmml.neural_network.Neuron;
35 | import org.jpmml.converter.CategoricalLabel;
36 | import org.jpmml.converter.ModelUtil;
37 | import org.jpmml.converter.ValueUtil;
38 | import org.jpmml.converter.neural_network.NeuralNetworkUtil;
39 |
40 | public class DNNClassifier extends DNNEstimator {
41 |
42 | public DNNClassifier(SavedModel savedModel, String head){
43 | super(savedModel, head);
44 | }
45 |
46 | @Override
47 | public NeuralNetwork encodeModel(TensorFlowEncoder encoder){
48 | DataField dataField = encoder.createDataField(FieldName.create("_target"), OpType.CATEGORICAL, DataType.INTEGER);
49 |
50 | NeuralNetwork neuralNetwork = encodeNeuralNetwork(encoder);
51 |
52 | List neuralLayers = neuralNetwork.getNeuralLayers();
53 |
54 | NeuralLayer neuralLayer = Iterables.getLast(neuralLayers);
55 |
56 | List neurons = neuralLayer.getNeurons();
57 |
58 | List categories;
59 |
60 | if(neurons.size() == 1){
61 | neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);
62 |
63 | Neuron neuron = Iterables.getOnlyElement(neurons);
64 |
65 | neuralLayer = new NeuralLayer()
66 | .setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);
67 |
68 | categories = Arrays.asList("0", "1");
69 |
70 | // p(no event) = 1 - p(event)
71 | Neuron passiveNeuron = new Neuron()
72 | .setId(String.valueOf(neuralLayers.size() + 1) + "/" + categories.get(0))
73 | .setBias(ValueUtil.floatToDouble(1f))
74 | .addConnections(new Connection(neuron.getId(), -1f));
75 |
76 | // p(event)
77 | Neuron activeNeuron = new Neuron()
78 | .setId(String.valueOf(neuralLayers.size() + 1) + "/" + categories.get(1))
79 | .setBias(null)
80 | .addConnections(new Connection(neuron.getId(), 1f));
81 |
82 | neuralLayer.addNeurons(passiveNeuron, activeNeuron);
83 |
84 | neuralNetwork.addNeuralLayers(neuralLayer);
85 |
86 | neurons = neuralLayer.getNeurons();
87 | } else
88 |
89 | if(neurons.size() > 2){
90 | neuralLayer
91 | .setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY)
92 | .setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
93 |
94 | categories = new ArrayList<>();
95 |
96 | for(int i = 0; i < neurons.size(); i++){
97 | String category = String.valueOf(i);
98 |
99 | categories.add(category);
100 | }
101 | } else
102 |
103 | {
104 | throw new IllegalArgumentException();
105 | }
106 |
107 | dataField = encoder.toCategorical(dataField.getName(), categories);
108 |
109 | CategoricalLabel categoricalLabel = new CategoricalLabel(dataField);
110 |
111 | neuralNetwork
112 | .setMiningFunction(MiningFunction.CLASSIFICATION)
113 | .setMiningSchema(ModelUtil.createMiningSchema(categoricalLabel))
114 | .setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(neurons, categoricalLabel))
115 | .setOutput(ModelUtil.createProbabilityOutput(DataType.FLOAT, categoricalLabel));
116 |
117 | return neuralNetwork;
118 | }
119 |
120 | public static final String BINARY_LOGISTIC_HEAD = "dnn/binary_logistic_head/predictions/probabilities";
121 | public static final String MULTI_CLASS_HEAD = "dnn/multi_class_head/predictions/probabilities";
122 | }
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/DNNEstimator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import java.util.ArrayList;
22 | import java.util.List;
23 | import java.util.Map;
24 |
25 | import com.google.common.collect.Lists;
26 | import com.google.common.primitives.Floats;
27 | import org.dmg.pmml.DataType;
28 | import org.dmg.pmml.Entity;
29 | import org.dmg.pmml.MathContext;
30 | import org.dmg.pmml.neural_network.NeuralInputs;
31 | import org.dmg.pmml.neural_network.NeuralLayer;
32 | import org.dmg.pmml.neural_network.NeuralNetwork;
33 | import org.dmg.pmml.neural_network.Neuron;
34 | import org.jpmml.converter.BinaryFeature;
35 | import org.jpmml.converter.CMatrixUtil;
36 | import org.jpmml.converter.Feature;
37 | import org.jpmml.converter.ValueUtil;
38 | import org.jpmml.converter.neural_network.NeuralNetworkUtil;
39 | import org.tensorflow.Operation;
40 | import org.tensorflow.Output;
41 | import org.tensorflow.Tensor;
42 | import org.tensorflow.framework.NodeDef;
43 |
44 | abstract
45 | public class DNNEstimator extends Estimator {
46 |
47 | public DNNEstimator(SavedModel savedModel, String head){
48 | super(savedModel, head);
49 | }
50 |
51 | protected NeuralNetwork encodeNeuralNetwork(TensorFlowEncoder encoder){
52 | SavedModel savedModel = getSavedModel();
53 |
54 | NeuralNetwork neuralNetwork = new NeuralNetwork()
55 | .setActivationFunction(NeuralNetwork.ActivationFunction.RECTIFIER)
56 | .setMathContext(MathContext.FLOAT);
57 |
58 | List biasAdds = Lists.newArrayList(savedModel.getInputs(getHead(), "BiasAdd"));
59 |
60 | biasAdds = Lists.reverse(biasAdds);
61 |
62 | List extends Entity> entities;
63 |
64 | {
65 | NodeDef biasAdd = biasAdds.get(0);
66 |
67 | NodeDef matMul = savedModel.getNodeDef(biasAdd.getInput(0));
68 | if(!("MatMul").equals(matMul.getOp())){
69 | throw new IllegalArgumentException();
70 | }
71 |
72 | NodeDef concat = savedModel.getNodeDef(matMul.getInput(0));
73 | if(!("ConcatV2").equals(concat.getOp())){
74 | throw new IllegalArgumentException();
75 | }
76 |
77 | List features = new ArrayList<>();
78 |
79 | List inputNames = concat.getInputList();
80 | for(int i = 0; i < inputNames.size() - 1; i++){
81 | String inputName = inputNames.get(i);
82 |
83 | NodeDef term = savedModel.getNodeDef(inputName);
84 |
85 | // "real_valued_column"
86 | if(("Cast").equals(term.getOp()) || ("Placeholder").equals(term.getOp())){
87 | NodeDef placeholder = term;
88 |
89 | Feature feature = encoder.createContinuousFeature(savedModel, placeholder);
90 |
91 | features.add(feature);
92 | } else
93 |
94 | // "one_hot_column(sparse_column_with_keys)"
95 | if(("Sum").equals(term.getOp())){
96 | NodeDef oneHot = savedModel.getOnlyInput(term.getInput(0), "OneHot");
97 |
98 | NodeDef placeholder = savedModel.getOnlyInput(oneHot.getInput(0), "Placeholder");
99 | NodeDef findTable = savedModel.getOnlyInput(oneHot.getInput(0), "LookupTableFind");
100 |
101 | Map, ?> table = savedModel.getTable(findTable.getInput(0));
102 |
103 | List categories = (List)new ArrayList<>(table.keySet());
104 |
105 | List binaryFeatures = encoder.createBinaryFeatures(savedModel, placeholder, categories);
106 |
107 | features.addAll(binaryFeatures);
108 | } else
109 |
110 | {
111 | throw new IllegalArgumentException(term.getName());
112 | }
113 | }
114 |
115 | NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.FLOAT);
116 |
117 | neuralNetwork.setNeuralInputs(neuralInputs);
118 |
119 | entities = neuralInputs.getNeuralInputs();
120 | }
121 |
122 | for(int i = 0; i < biasAdds.size(); i++){
123 | NodeDef biasAdd = biasAdds.get(i);
124 |
125 | NodeDef matMul = savedModel.getNodeDef(biasAdd.getInput(0));
126 | if(!("MatMul").equals(matMul.getOp())){
127 | throw new IllegalArgumentException();
128 | }
129 |
130 | int count;
131 |
132 | {
133 | Operation operation = savedModel.getOperation(matMul.getName());
134 |
135 | Output output = operation.output(0);
136 |
137 | long[] shape = ShapeUtil.toArray(output.shape());
138 | if(shape.length != 2 || shape[0] != -1){
139 | throw new IllegalArgumentException();
140 | }
141 |
142 | count = (int)shape[1];
143 | }
144 |
145 | NodeDef weights = savedModel.getOnlyInput(matMul.getInput(1), "VariableV2");
146 |
147 | float[] weightValues;
148 |
149 | try(Tensor tensor = savedModel.run(weights.getName())){
150 | weightValues = TensorUtil.toFloatArray(tensor);
151 | }
152 |
153 | NodeDef bias = savedModel.getOnlyInput(biasAdd.getInput(1), "VariableV2");
154 |
155 | float[] biasValues;
156 |
157 | try(Tensor tensor = savedModel.run(bias.getName())){
158 | biasValues = TensorUtil.toFloatArray(tensor);
159 | }
160 |
161 | NeuralLayer neuralLayer = new NeuralLayer();
162 |
163 | for(int j = 0; j < count; j++){
164 | List entityWeights = CMatrixUtil.getColumn(Floats.asList(weightValues), entities.size(), count, j);
165 |
166 | Neuron neuron = NeuralNetworkUtil.createNeuron(entities, ValueUtil.floatsToDoubles(entityWeights), ValueUtil.floatToDouble(biasValues[j]))
167 | .setId(String.valueOf(i + 1) + "/" + String.valueOf(j + 1));
168 |
169 | neuralLayer.addNeurons(neuron);
170 | }
171 |
172 | neuralNetwork.addNeuralLayers(neuralLayer);
173 |
174 | entities = neuralLayer.getNeurons();
175 | }
176 |
177 | return neuralNetwork;
178 | }
179 | }
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/DNNRegressor.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import java.util.List;
22 |
23 | import com.google.common.collect.Iterables;
24 | import org.dmg.pmml.DataField;
25 | import org.dmg.pmml.DataType;
26 | import org.dmg.pmml.FieldName;
27 | import org.dmg.pmml.MiningFunction;
28 | import org.dmg.pmml.OpType;
29 | import org.dmg.pmml.neural_network.NeuralLayer;
30 | import org.dmg.pmml.neural_network.NeuralNetwork;
31 | import org.dmg.pmml.neural_network.Neuron;
32 | import org.jpmml.converter.ContinuousLabel;
33 | import org.jpmml.converter.ModelUtil;
34 | import org.jpmml.converter.neural_network.NeuralNetworkUtil;
35 |
36 | /**
37 |  * Estimator for SavedModels that contain the {@link #REGRESSION_HEAD} node.
38 |  *
39 |  * <p>The graph is encoded as a PMML {@code NeuralNetwork} with a regression mining function.</p>
40 |  */
41 | public class DNNRegressor extends DNNEstimator {
42 |
43 |     public DNNRegressor(SavedModel savedModel, String head){
44 |         super(savedModel, head);
45 |     }
46 |
47 |     /**
48 |      * Encodes the DNN as a regression {@code NeuralNetwork}.
49 |      *
50 |      * <p>The continuous {@code _target} field is created before the network layers are encoded.
51 |      * The last layer of the network gets an identity activation, so that its neurons can be
52 |      * used directly as the regression outputs.</p>
53 |      *
54 |      * @param encoder the encoder that collects the data fields
55 |      * @return the configured neural network
56 |      */
57 |     @Override
58 |     public NeuralNetwork encodeModel(TensorFlowEncoder encoder){
59 |         DataField dataField = encoder.createDataField(FieldName.create("_target"), OpType.CONTINUOUS, DataType.FLOAT);
60 |
61 |         NeuralNetwork neuralNetwork = encodeNeuralNetwork(encoder);
62 |
63 |         List<NeuralLayer> neuralLayers = neuralNetwork.getNeuralLayers();
64 |
65 |         NeuralLayer neuralLayer = Iterables.getLast(neuralLayers);
66 |
67 |         // Linear output: the regression scores are not squashed by any activation
68 |         neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);
69 |
70 |         List<Neuron> neurons = neuralLayer.getNeurons();
71 |
72 |         ContinuousLabel continuousLabel = new ContinuousLabel(dataField);
73 |
74 |         neuralNetwork
75 |             .setMiningFunction(MiningFunction.REGRESSION)
76 |             .setMiningSchema(ModelUtil.createMiningSchema(continuousLabel))
77 |             .setNeuralOutputs(NeuralNetworkUtil.createRegressionNeuralOutputs(neurons, continuousLabel));
78 |
79 |         return neuralNetwork;
80 |     }
81 |
82 |     /** Name of the graph node whose presence selects this estimator. */
83 |     public static final String REGRESSION_HEAD = "dnn/regression_head/predictions/scores";
84 | }
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/Estimator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import org.dmg.pmml.Model;
22 | import org.dmg.pmml.PMML;
23 |
24 | /**
25 |  * Base class for converters from a TensorFlow SavedModel to a PMML model.
26 |  *
27 |  * <p>An instance holds the {@link SavedModel} being converted and the name of its head node.</p>
28 |  */
29 | abstract
30 | public class Estimator {
31 |
32 |     private SavedModel savedModel;
33 |
34 |     private String head;
35 |
36 |
37 |     public Estimator(SavedModel savedModel, String head){
38 |         setHead(head);
39 |         setSavedModel(savedModel);
40 |     }
41 |
42 |     /**
43 |      * Encodes this estimator as a PMML model element.
44 |      *
45 |      * @param encoder the encoder that collects the fields created during encoding
46 |      * @return the encoded model
47 |      */
48 |     abstract
49 |     public Model encodeModel(TensorFlowEncoder encoder);
50 |
51 |     /**
52 |      * Encodes this estimator using a new {@link TensorFlowEncoder}, and wraps the
53 |      * resulting model into a complete PMML document.
54 |      *
55 |      * @return the PMML document
56 |      */
57 |     public PMML encodePMML(){
58 |         TensorFlowEncoder pmmlEncoder = new TensorFlowEncoder();
59 |
60 |         return pmmlEncoder.encodePMML(encodeModel(pmmlEncoder));
61 |     }
62 |
63 |     public SavedModel getSavedModel(){
64 |         return savedModel;
65 |     }
66 |
67 |     private void setSavedModel(SavedModel value){
68 |         savedModel = value;
69 |     }
70 |
71 |     public String getHead(){
72 |         return head;
73 |     }
74 |
75 |     private void setHead(String value){
76 |         head = value;
77 |     }
78 | }
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/EstimatorFactory.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import java.util.Map;
22 |
23 | import org.tensorflow.framework.NodeDef;
24 |
25 | public class EstimatorFactory {
26 |
27 | protected EstimatorFactory(){
28 | }
29 |
30 | public Estimator newEstimator(SavedModel savedModel){
31 | Map nodeMap = savedModel.getNodeMap();
32 |
33 | if(nodeMap.containsKey(DNNClassifier.BINARY_LOGISTIC_HEAD)){
34 | return new DNNClassifier(savedModel, DNNClassifier.BINARY_LOGISTIC_HEAD);
35 | } else
36 |
37 | if(nodeMap.containsKey(DNNClassifier.MULTI_CLASS_HEAD)){
38 | return new DNNClassifier(savedModel, DNNClassifier.MULTI_CLASS_HEAD);
39 | } else
40 |
41 | if(nodeMap.containsKey(DNNRegressor.REGRESSION_HEAD)){
42 | return new DNNRegressor(savedModel, DNNRegressor.REGRESSION_HEAD);
43 | } else
44 |
45 | if(nodeMap.containsKey(LinearClassifier.BINARY_LOGISTIC_HEAD)){
46 | return new LinearClassifier(savedModel, LinearClassifier.BINARY_LOGISTIC_HEAD);
47 | } else
48 |
49 | if(nodeMap.containsKey(LinearClassifier.MULTI_CLASS_HEAD)){
50 | return new LinearClassifier(savedModel, LinearClassifier.MULTI_CLASS_HEAD);
51 | } else
52 |
53 | if(nodeMap.containsKey(LinearRegressor.REGRESSION_HEAD)){
54 | return new LinearRegressor(savedModel, LinearRegressor.REGRESSION_HEAD);
55 | } else
56 |
57 | {
58 | throw new IllegalArgumentException();
59 | }
60 | }
61 |
62 | static
63 | public EstimatorFactory newInstance(){
64 | return new EstimatorFactory();
65 | }
66 | }
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/LinearClassifier.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see .
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import java.util.ArrayList;
22 | import java.util.Arrays;
23 | import java.util.List;
24 |
25 | import org.dmg.pmml.DataField;
26 | import org.dmg.pmml.DataType;
27 | import org.dmg.pmml.FieldName;
28 | import org.dmg.pmml.MiningFunction;
29 | import org.dmg.pmml.OpType;
30 | import org.dmg.pmml.regression.RegressionModel;
31 | import org.dmg.pmml.regression.RegressionTable;
32 | import org.jpmml.converter.CategoricalLabel;
33 | import org.jpmml.converter.ModelUtil;
34 |
35 | /**
36 |  * Linear classification estimator.
37 |  *
38 |  * <p>The linear part of the graph is encoded as a PMML {@code RegressionModel} with one
39 |  * regression table per class and softmax normalization.</p>
40 |  */
41 | public class LinearClassifier extends LinearEstimator {
42 |
43 |     public LinearClassifier(SavedModel savedModel, String head){
44 |         super(savedModel, head);
45 |     }
46 |
47 |     /**
48 |      * Encodes a classification {@code RegressionModel}.
49 |      *
50 |      * <p>Target categories are the class indices {@code "0"}, {@code "1"}, ... in table order.
51 |      * When the graph yields a single logit, a constant zero table is added for class {@code "0"},
52 |      * because softmax over {@code [logit, 0]} equals the logistic function of the logit.</p>
53 |      *
54 |      * @param encoder the encoder that collects the data fields
55 |      * @return the configured regression model
56 |      * @throws IllegalArgumentException if no regression tables were produced
57 |      */
58 |     @Override
59 |     public RegressionModel encodeModel(TensorFlowEncoder encoder){
60 |         DataField dataField = encoder.createDataField(FieldName.create("_target"), OpType.CATEGORICAL, DataType.INTEGER);
61 |
62 |         RegressionModel regressionModel = encodeRegressionModel(encoder);
63 |
64 |         List<RegressionTable> regressionTables = regressionModel.getRegressionTables();
65 |
66 |         List<String> categories;
67 |
68 |         if(regressionTables.size() == 1){
69 |             categories = Arrays.asList("0", "1");
70 |
71 |             // The only table holds the logit of the positive class
72 |             regressionTables.get(0)
73 |                 .setTargetCategory(categories.get(1));
74 |
75 |             RegressionTable passiveRegressionTable = new RegressionTable(0)
76 |                 .setTargetCategory(categories.get(0));
77 |
78 |             regressionModel.addRegressionTables(passiveRegressionTable);
79 |         } else
80 |
81 |         // Softmax over two or more logits is exactly the per-class probability
82 |         if(regressionTables.size() >= 2){
83 |             categories = new ArrayList<>();
84 |
85 |             for(int i = 0; i < regressionTables.size(); i++){
86 |                 RegressionTable regressionTable = regressionTables.get(i);
87 |                 String category = String.valueOf(i);
88 |
89 |                 regressionTable.setTargetCategory(category);
90 |
91 |                 categories.add(category);
92 |             }
93 |         } else
94 |
95 |         {
96 |             throw new IllegalArgumentException("Expected at least one regression table, got " + regressionTables.size());
97 |         }
98 |
99 |         dataField = encoder.toCategorical(dataField.getName(), categories);
100 |
101 |         CategoricalLabel categoricalLabel = new CategoricalLabel(dataField);
102 |
103 |         regressionModel
104 |             .setMiningFunction(MiningFunction.CLASSIFICATION)
105 |             .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX)
106 |             .setMiningSchema(ModelUtil.createMiningSchema(categoricalLabel))
107 |             .setOutput(ModelUtil.createProbabilityOutput(DataType.FLOAT, categoricalLabel));
108 |
109 |         return regressionModel;
110 |     }
111 |
112 |     public static final String BINARY_LOGISTIC_HEAD = "linear/binary_logistic_head/predictions/probabilities";
113 |     public static final String MULTI_CLASS_HEAD = "linear/multi_class_head/predictions/probabilities";
114 | }
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/LinearEstimator.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see .
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import java.util.ArrayList;
22 | import java.util.List;
23 | import java.util.Map;
24 |
25 | import com.google.common.primitives.Floats;
26 | import org.dmg.pmml.MathContext;
27 | import org.dmg.pmml.regression.RegressionModel;
28 | import org.dmg.pmml.regression.RegressionTable;
29 | import org.jpmml.converter.CMatrixUtil;
30 | import org.jpmml.converter.Feature;
31 | import org.jpmml.converter.ValueUtil;
32 | import org.jpmml.converter.regression.RegressionModelUtil;
33 | import org.tensorflow.Operation;
34 | import org.tensorflow.Output;
35 | import org.tensorflow.Tensor;
36 | import org.tensorflow.framework.NodeDef;
37 |
38 | abstract
39 | public class LinearEstimator extends Estimator {
40 |
41 | public LinearEstimator(SavedModel savedModel, String head){
42 | super(savedModel, head);
43 | }
44 |
45 | public RegressionModel encodeRegressionModel(TensorFlowEncoder encoder){
46 | SavedModel savedModel = getSavedModel();
47 |
48 | NodeDef biasAdd = savedModel.getOnlyInput(getHead(), "BiasAdd");
49 |
50 | int count;
51 |
52 | {
53 | Operation operation = savedModel.getOperation(biasAdd.getName());
54 |
55 | Output output = operation.output(0);
56 |
57 | long[] shape = ShapeUtil.toArray(output.shape());
58 | if((shape.length != 2) || (shape[0] != -1)){
59 | throw new IllegalArgumentException();
60 | }
61 |
62 | count = (int)shape[1];
63 | }
64 |
65 | List equations = new ArrayList<>();
66 |
67 | for(int i = 0; i < count; i++){
68 | Equation equation = new Equation();
69 |
70 | equations.add(equation);
71 | }
72 |
73 | NodeDef addN = savedModel.getOnlyInput(biasAdd.getInput(0), "AddN");
74 |
75 | List inputNames = addN.getInputList();
76 | for(String inputName : inputNames){
77 | NodeDef term = savedModel.getOnlyInput(inputName, "MatMul", "Select");
78 |
79 | // "real_valued_column"
80 | if(("MatMul").equals(term.getOp())){
81 | NodeDef placeholder = savedModel.getNodeDef(term.getInput(0));
82 | NodeDef multiplier = savedModel.getOnlyInput(term.getInput(1), "VariableV2");
83 |
84 | Feature feature = encoder.createContinuousFeature(savedModel, placeholder);
85 |
86 | try(Tensor tensor = savedModel.run(multiplier.getName())){
87 | float[] values = TensorUtil.toFloatArray(tensor);
88 |
89 | for(int i = 0; i < count; i++){
90 | Equation equation = equations.get(i);
91 |
92 | equation.addTerm(feature, ValueUtil.floatToDouble(values[i]));
93 | }
94 | }
95 | } else
96 |
97 | // "sparse_column_with_keys"
98 | if(("Select").equals(term.getOp())){
99 | NodeDef placeholder = savedModel.getOnlyInput(term.getInput(0), "Placeholder");
100 | NodeDef findTable = savedModel.getOnlyInput(term.getInput(1), "LookupTableFind");
101 | NodeDef multiplier = savedModel.getOnlyInput(term.getInput(2), "VariableV2");
102 |
103 | Map, ?> table = savedModel.getTable(findTable.getInput(0));
104 |
105 | List categories = (List)new ArrayList<>(table.keySet());
106 |
107 | List extends Feature> features = encoder.createBinaryFeatures(savedModel, placeholder, categories);
108 |
109 | float[] values;
110 |
111 | try(Tensor tensor = savedModel.run(multiplier.getName())){
112 | values = TensorUtil.toFloatArray(tensor);
113 | }
114 |
115 | for(int i = 0; i < equations.size(); i++){
116 | Equation equation = equations.get(i);
117 |
118 | List categoryValues = CMatrixUtil.getColumn(Floats.asList(values), features.size(), equations.size(), i);
119 |
120 | for(int j = 0; j < features.size(); j++){
121 | Feature feature = features.get(j);
122 |
123 | int index = ValueUtil.asInt((Number)table.get(categories.get(j)));
124 |
125 | equation.addTerm(feature, ValueUtil.floatToDouble(categoryValues.get(index)));
126 | }
127 | }
128 | } else
129 |
130 | {
131 | throw new IllegalArgumentException(term.getName());
132 | }
133 | }
134 |
135 | NodeDef bias = savedModel.getOnlyInput(biasAdd.getInput(1), "VariableV2");
136 |
137 | try(Tensor tensor = savedModel.run(bias.getName())){
138 | float[] values = TensorUtil.toFloatArray(tensor);
139 |
140 | for(int i = 0; i < count; i++){
141 | Equation equation = equations.get(i);
142 |
143 | equation.setIntercept(ValueUtil.floatToDouble(values[i]));
144 | }
145 | }
146 |
147 | RegressionModel regressionModel = new RegressionModel()
148 | .setMathContext(MathContext.FLOAT);
149 |
150 | for(Equation equation : equations){
151 | RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(equation.getFeatures(), equation.getCoefficients(), equation.getIntercept());
152 |
153 | regressionModel.addRegressionTables(regressionTable);
154 | }
155 |
156 | return regressionModel;
157 | }
158 |
159 | static
160 | private class Equation {
161 |
162 | private List features = new ArrayList<>();
163 |
164 | private List coefficients = new ArrayList<>();
165 |
166 | private Double intercept = null;
167 |
168 |
169 | private Equation(){
170 | }
171 |
172 | public void addTerm(Feature feature, Double coefficient){
173 | this.features.add(feature);
174 | this.coefficients.add(coefficient);
175 | }
176 |
177 | public List getFeatures(){
178 | return this.features;
179 | }
180 |
181 | public List getCoefficients(){
182 | return this.coefficients;
183 | }
184 |
185 | public Double getIntercept(){
186 | return this.intercept;
187 | }
188 |
189 | public void setIntercept(Double intercept){
190 | this.intercept = intercept;
191 | }
192 | }
193 | }
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/LinearRegressor.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see .
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import org.dmg.pmml.DataField;
22 | import org.dmg.pmml.DataType;
23 | import org.dmg.pmml.FieldName;
24 | import org.dmg.pmml.MiningFunction;
25 | import org.dmg.pmml.OpType;
26 | import org.dmg.pmml.regression.RegressionModel;
27 | import org.jpmml.converter.ContinuousLabel;
28 | import org.jpmml.converter.Label;
29 | import org.jpmml.converter.ModelUtil;
30 |
31 | /**
32 |  * Linear regression estimator: the linear part of the graph becomes a PMML
33 |  * {@code RegressionModel} with a continuous {@code _target} field.
34 |  */
35 | public class LinearRegressor extends LinearEstimator {
36 |
37 |     public LinearRegressor(SavedModel savedModel, String head){
38 |         super(savedModel, head);
39 |     }
40 |
41 |     /**
42 |      * Encodes a regression {@code RegressionModel}. The target field is created before
43 |      * the linear part of the graph is encoded.
44 |      *
45 |      * @param encoder the encoder that collects the data fields
46 |      * @return the configured regression model
47 |      */
48 |     @Override
49 |     public RegressionModel encodeModel(TensorFlowEncoder encoder){
50 |         DataField targetField = encoder.createDataField(FieldName.create("_target"), OpType.CONTINUOUS, DataType.FLOAT);
51 |
52 |         Label targetLabel = new ContinuousLabel(targetField);
53 |
54 |         RegressionModel result = encodeRegressionModel(encoder);
55 |         result.setMiningFunction(MiningFunction.REGRESSION);
56 |         result.setMiningSchema(ModelUtil.createMiningSchema(targetLabel));
57 |
58 |         return result;
59 |     }
60 |
61 |     /** Name of the graph node whose presence selects this estimator. */
62 |     public static final String REGRESSION_HEAD = "linear/regression_head/predictions/scores";
63 | }
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/Main.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see .
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import java.io.File;
22 | import java.io.FileOutputStream;
23 | import java.io.OutputStream;
24 |
25 | import com.beust.jcommander.JCommander;
26 | import com.beust.jcommander.Parameter;
27 | import com.beust.jcommander.ParameterException;
28 | import org.dmg.pmml.PMML;
29 | import org.jpmml.model.MetroJAXBUtil;
30 | import org.slf4j.Logger;
31 | import org.slf4j.LoggerFactory;
32 | import org.tensorflow.SavedModelBundle;
33 |
34 | /**
35 |  * Command-line application that converts a TensorFlow SavedModel directory into a PMML file.
36 |  */
37 | public class Main {
38 |
39 |     @Parameter (
40 |         names = "--help",
41 |         description = "Show the list of configuration options and exit",
42 |         help = true
43 |     )
44 |     private boolean help = false;
45 |
46 |     @Parameter (
47 |         names = {"--tf-input", "--tf-savedmodel-input"},
48 |         description = "TF SavedModel input directory",
49 |         required = true
50 |     )
51 |     private File input = null;
52 |
53 |     @Parameter (
54 |         names = "--pmml-output",
55 |         description = "PMML output file",
56 |         required = true
57 |     )
58 |     private File output = null;
59 |
60 |
61 |     /**
62 |      * Parses the command line and runs the conversion.
63 |      *
64 |      * <p>On invalid arguments, prints the error and the usage to standard error and exits with
65 |      * status -1. With {@code --help}, prints the usage to standard output and exits with status 0.</p>
66 |      */
67 |     static
68 |     public void main(String[] args) throws Exception {
69 |         Main main = new Main();
70 |
71 |         JCommander commander = new JCommander(main);
72 |         commander.setProgramName(Main.class.getName());
73 |
74 |         try {
75 |             commander.parse(args);
76 |         } catch(ParameterException pe){
77 |             StringBuilder sb = new StringBuilder();
78 |
79 |             sb.append(pe.toString());
80 |             sb.append("\n");
81 |
82 |             commander.usage(sb);
83 |
84 |             System.err.println(sb.toString());
85 |
86 |             System.exit(-1);
87 |         }
88 |
89 |         if(main.help){
90 |             StringBuilder sb = new StringBuilder();
91 |
92 |             commander.usage(sb);
93 |
94 |             System.out.println(sb.toString());
95 |
96 |             System.exit(0);
97 |         }
98 |
99 |         main.run();
100 |     }
101 |
102 |     /**
103 |      * Loads the SavedModel (tag set "serve"), converts it to PMML and writes the PMML file.
104 |      *
105 |      * <p>Each stage logs its duration; a failing stage is logged and its exception is rethrown.</p>
106 |      */
107 |     private void run() throws Exception {
108 |         SavedModelBundle bundle;
109 |
110 |         try {
111 |             logger.info("Parsing SavedModel..");
112 |
113 |             long begin = System.currentTimeMillis();
114 |             bundle = SavedModelBundle.load(this.input.getAbsolutePath(), "serve");
115 |             long end = System.currentTimeMillis();
116 |
117 |             logger.info("Parsed SavedModel in {} ms.", (end - begin));
118 |         } catch(Exception e){
119 |             logger.error("Failed to parse SavedModel", e);
120 |
121 |             throw e;
122 |         }
123 |
124 |         PMML pmml;
125 |
126 |         // NOTE(review): if the SavedModel constructor throws, the bundle is not closed here; confirm whether that matters
127 |         try(SavedModel savedModel = new SavedModel(bundle)){
128 |             logger.info("Converting..");
129 |
130 |             EstimatorFactory estimatorFactory = EstimatorFactory.newInstance();
131 |
132 |             Estimator estimator = estimatorFactory.newEstimator(savedModel);
133 |
134 |             long begin = System.currentTimeMillis();
135 |             pmml = estimator.encodePMML();
136 |             long end = System.currentTimeMillis();
137 |
138 |             logger.info("Converted in {} ms.", (end - begin));
139 |         } catch(Exception e){
140 |             logger.error("Failed to convert", e);
141 |
142 |             throw e;
143 |         }
144 |
145 |         try(OutputStream os = new FileOutputStream(this.output)){
146 |             logger.info("Marshalling PMML..");
147 |
148 |             long begin = System.currentTimeMillis();
149 |             MetroJAXBUtil.marshalPMML(pmml, os);
150 |             long end = System.currentTimeMillis();
151 |
152 |             logger.info("Marshalled PMML in {} ms.", (end - begin));
153 |         } catch(Exception e){
154 |             logger.error("Failed to marshal PMML", e);
155 |
156 |             throw e;
157 |         }
158 |     }
159 |
160 |     private static final Logger logger = LoggerFactory.getLogger(Main.class);
161 | }
--------------------------------------------------------------------------------
/src/main/java/org/jpmml/tensorflow/SavedModel.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2017 Villu Ruusmann
3 | *
4 | * This file is part of JPMML-TensorFlow
5 | *
6 | * JPMML-TensorFlow is free software: you can redistribute it and/or modify
7 | * it under the terms of the GNU Affero General Public License as published by
8 | * the Free Software Foundation, either version 3 of the License, or
9 | * (at your option) any later version.
10 | *
11 | * JPMML-TensorFlow is distributed in the hope that it will be useful,
12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | * GNU Affero General Public License for more details.
15 | *
16 | * You should have received a copy of the GNU Affero General Public License
17 | * along with JPMML-TensorFlow. If not, see <http://www.gnu.org/licenses/>.
18 | */
19 | package org.jpmml.tensorflow;
20 |
21 | import java.util.ArrayDeque;
22 | import java.util.Arrays;
23 | import java.util.Collection;
24 | import java.util.Collections;
25 | import java.util.Deque;
26 | import java.util.HashSet;
27 | import java.util.LinkedHashMap;
28 | import java.util.LinkedHashSet;
29 | import java.util.List;
30 | import java.util.Map;
31 | import java.util.Set;
32 |
33 | import com.google.common.base.Function;
34 | import com.google.common.collect.Iterables;
35 | import com.google.protobuf.InvalidProtocolBufferException;
36 | import org.tensorflow.Graph;
37 | import org.tensorflow.Operation;
38 | import org.tensorflow.SavedModelBundle;
39 | import org.tensorflow.Session;
40 | import org.tensorflow.Session.Runner;
41 | import org.tensorflow.Tensor;
42 | import org.tensorflow.framework.CollectionDef;
43 | import org.tensorflow.framework.GraphDef;
44 | import org.tensorflow.framework.MetaGraphDef;
45 | import org.tensorflow.framework.NodeDef;
46 |
47 | public class SavedModel implements AutoCloseable {
48 |
49 | private SavedModelBundle bundle = null;
50 |
51 | private MetaGraphDef metaGraphDef = null;
52 |
53 | private Map nodeMap = null;
54 |
55 | private Map> tableMap = new LinkedHashMap<>();
56 |
57 |
58 | public SavedModel(SavedModelBundle bundle) throws InvalidProtocolBufferException {
59 | setBundle(bundle);
60 |
61 | byte[] metaGraphDefBytes = bundle.metaGraphDef();
62 |
63 | MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(metaGraphDefBytes);
64 |
65 | setMetaGraphDef(metaGraphDef);
66 |
67 | GraphDef graphDef = metaGraphDef.getGraphDef();
68 |
69 | Map nodeMap = new LinkedHashMap<>();
70 |
71 | List nodeDefs = graphDef.getNodeList();
72 | for(NodeDef nodeDef : nodeDefs){
73 | nodeMap.put(nodeDef.getName(), nodeDef);
74 | }
75 |
76 | setNodeMap(nodeMap);
77 |
78 | initializeTables();
79 | }
80 |
81 | private void initializeTables(){
82 | Collection tableInitializerNames = Collections.emptyList();
83 |
84 | try {
85 | CollectionDef collectionDef = getCollectionDef("table_initializer");
86 |
87 | CollectionDef.NodeList nodeList = collectionDef.getNodeList();
88 |
89 | tableInitializerNames = nodeList.getValueList();
90 | } catch(IllegalArgumentException iae){
91 | // Ignored
92 | }
93 |
94 | for(String tableInitializerName : tableInitializerNames){
95 | NodeDef tableInitializer = getNodeDef(tableInitializerName);
96 |
97 | String name = tableInitializer.getInput(0);
98 |
99 | List> keys;
100 | List> values;
101 |
102 | try(Tensor tensor = run(tableInitializer.getInput(1))){
103 | keys = TensorUtil.getValues(tensor);
104 | } // End try
105 |
106 | try(Tensor tensor = run(tableInitializer.getInput(2))){
107 | values = TensorUtil.getValues(tensor);
108 | }
109 |
110 | Map