├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── build.sbt
├── pom.xml
├── project
│   ├── build.properties
│   └── plugins.sbt
└── src
    ├── main
    │   ├── resources
    │   │   └── META-INF
    │   │       └── services
    │   │           └── org.apache.spark.sql.sources.DataSourceRegister
    │   └── scala
    │       └── org
    │           └── trustedanalytics
    │               └── spark
    │                   └── datasources
    │                       └── tensorflow
    │                           ├── DataTypesConvertor.scala
    │                           ├── DefaultSource.scala
    │                           ├── TensorflowInferSchema.scala
    │                           ├── TensorflowRelation.scala
    │                           └── serde
    │                               ├── DefaultTfRecordRowDecoder.scala
    │                               ├── DefaultTfRecordRowEncoder.scala
    │                               ├── FeatureDecoder.scala
    │                               └── FeatureEncoder.scala
    └── test
        └── scala
            └── org
                └── trustedanalytics
                    └── spark
                        └── datasources
                            └── tensorflow
                                ├── SharedSparkSessionSuite.scala
                                ├── TensorflowSuite.scala
                                └── serde
                                    ├── FeatureDecoderTest.scala
                                    └── FeatureEncoderTest.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.iml
3 | *.jar
4 | *.log
5 | target
6 | tf-sandbox
7 | spark-warehouse/
8 | metastore_db/
9 | project/project/
10 | test-output.tfr
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: scala
2 |
3 | # Cache settings here are based on latest SBT documentation.
4 | cache:
5 | directories:
6 | - $HOME/.ivy2/cache
7 | - $HOME/.sbt/boot/
8 |
9 | before_cache:
10 | # Tricks to avoid unnecessary cache updates
11 | - find $HOME/.ivy2 -name "ivydata-*.properties" -delete
12 | - find $HOME/.sbt -name "*.lock" -delete
13 |
14 | scala:
15 | - 2.11.8
16 |
17 | jdk:
18 | - oraclejdk8
19 |
20 | script:
21 | - sbt ++$TRAVIS_SCALA_VERSION clean publish-local
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Build Status](https://travis-ci.org/tapanalyticstoolkit/spark-tensorflow-connector.svg?branch=master)](https://travis-ci.org/tapanalyticstoolkit/spark-tensorflow-connector)
2 |
3 | # spark-tensorflow-connector
4 |
5 | __NOTE: This repo has been contributed to the TensorFlow ecosystem, and is no longer maintained here. Please go to [spark-tensorflow-connector](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector) in the TensorFlow ecosystem for the latest version.__
6 |
7 |
8 | This repo contains a library for loading and storing TensorFlow records with [Apache Spark](http://spark.apache.org/).
9 | The library implements data import from the standard TensorFlow record format
10 | ([TFRecords](https://www.tensorflow.org/how_tos/reading_data/)) into Spark SQL DataFrames, and data export from DataFrames to TensorFlow records.
11 |
12 | ## What's new
13 |
14 | This is the initial release of the `spark-tensorflow-connector` repo.
15 |
16 | ## Known issues
17 |
18 | None.
19 |
20 | ## Prerequisites
21 |
22 | 1. [Apache Spark 2.0 (or later)](http://spark.apache.org/)
23 |
24 | 2. [Apache Maven](https://maven.apache.org/)
25 |
26 | ## Building the library
27 | You can build the library with either Maven or SBT.
28 |
29 | #### Maven
30 | Build the library using Maven (3.3) as shown below:
31 |
32 | ```sh
33 | mvn clean install
34 | ```
35 |
36 | #### SBT
37 | Build the library using SBT (0.13.13) as shown below:
38 | ```sh
39 | sbt clean assembly
40 | ```
41 |
42 | ## Using Spark Shell
43 | Run this library in Spark using the `--jars` command line option in `spark-shell` or `spark-submit`. For example:
44 |
45 | Maven Jars
46 | ```sh
47 | $SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar
48 | ```
49 |
50 | SBT Jars
51 | ```sh
52 | $SPARK_HOME/bin/spark-shell --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar
53 | ```
54 |
55 | The following code snippet demonstrates usage.
56 |
57 | ```scala
58 | import org.apache.commons.io.FileUtils
59 | import org.apache.spark.sql.{ DataFrame, Row }
60 | import org.apache.spark.sql.catalyst.expressions.GenericRow
61 | import org.apache.spark.sql.types._
62 |
63 | val path = "test-output.tfr"
64 | val testRows: Array[Row] = Array(
65 | new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
66 | new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
67 | val schema = StructType(List(StructField("id", IntegerType),
68 | StructField("IntegerTypelabel", IntegerType),
69 | StructField("LongTypelabel", LongType),
70 | StructField("FloatTypelabel", FloatType),
71 | StructField("DoubleTypelabel", DoubleType),
72 | StructField("vectorlabel", ArrayType(DoubleType, true)),
73 | StructField("name", StringType)))
74 |
75 | val rdd = spark.sparkContext.parallelize(testRows)
76 |
77 | //Save DataFrame as TFRecords
78 | val df: DataFrame = spark.createDataFrame(rdd, schema)
79 | df.write.format("tensorflow").save(path)
80 |
81 | //Read TFRecords into DataFrame.
82 | //The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
83 | val importedDf1: DataFrame = spark.read.format("tensorflow").load(path)
84 | importedDf1.show()
85 |
86 | //Read TFRecords into DataFrame using custom schema
87 | val importedDf2: DataFrame = spark.read.format("tensorflow").schema(schema).load(path)
88 | importedDf2.show()
89 |
90 | ```
91 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | name := "spark-tensorflow-connector"
2 |
3 | organization := "org.trustedanalytics"
4 |
5 | scalaVersion in Global := "2.11.8"
6 |
7 | spName := "tapanalyticstoolkit/spark-tensorflow-connector"
8 |
9 | sparkVersion := "2.1.0"
10 |
11 | sparkComponents ++= Seq("sql", "mllib")
12 |
13 | version := "1.0.0"
14 |
15 | def ProjectName(name: String, path: String): Project = Project(name, file(path))
16 |
17 | resolvers in Global ++= Seq("https://tap.jfrog.io/tap/public" at "https://tap.jfrog.io/tap/public" ,
18 | "https://tap.jfrog.io/tap/public-snapshots" at "https://tap.jfrog.io/tap/public-snapshots" ,
19 | "https://repo.maven.apache.org/maven2" at "https://repo.maven.apache.org/maven2" )
20 |
21 | val `junit_junit` = "junit" % "junit" % "4.12"
22 |
23 | val `org.apache.hadoop_hadoop-yarn-api` = "org.apache.hadoop" % "hadoop-yarn-api" % "2.7.3"
24 |
25 | val `org.apache.spark_spark-core_2.11` = "org.apache.spark" % "spark-core_2.11" % "2.1.0"
26 |
27 | val `org.apache.spark_spark-sql_2.11` = "org.apache.spark" % "spark-sql_2.11" % "2.1.0"
28 |
29 | val `org.apache.spark_spark-mllib_2.11` = "org.apache.spark" % "spark-mllib_2.11" % "2.1.0"
30 |
31 | val `org.scalatest_scalatest_2.11` = "org.scalatest" % "scalatest_2.11" % "2.2.6"
32 |
33 | val `org.tensorflow_tensorflow-hadoop` = "org.tensorflow" % "tensorflow-hadoop" % "1.0-01232017-SNAPSHOT"
34 |
35 | libraryDependencies in Global ++= Seq(`org.tensorflow_tensorflow-hadoop` classifier "shaded-protobuf",
36 | `org.scalatest_scalatest_2.11` % "test" ,
37 | `org.apache.spark_spark-sql_2.11` % "provided" ,
38 | `org.apache.spark_spark-mllib_2.11` % "test" classifier "tests",
39 | `org.apache.spark_spark-core_2.11` % "provided" ,
40 | `org.apache.hadoop_hadoop-yarn-api` % "provided" ,
41 | `junit_junit` % "test" )
42 |
43 | assemblyExcludedJars in assembly := {
44 | val cp = (fullClasspath in assembly).value
45 | cp filterNot {x => List("spark-tensorflow-connector-1.0-SNAPSHOT.jar",
46 | "tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar").contains(x.data.getName)}
47 | }
48 |
49 | /********************
50 | * Release settings *
51 | ********************/
52 |
53 | spIgnoreProvided := true
54 |
55 | spAppendScalaVersion := true
56 |
57 | // If you published your package to Maven Central for this release (must be done prior to spPublish)
58 | spIncludeMaven := false
59 |
60 | publishMavenStyle := true
61 |
62 | licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))
63 |
64 | pomExtra :=
65 |   <url>https://github.com/tapanalyticstoolkit/spark-tensorflow-connector</url>
66 |   <scm>
67 |     <url>git@github.com:tapanalyticstoolkit/spark-tensorflow-connector.git</url>
68 |     <connection>scm:git:git@github.com:tapanalyticstoolkit/spark-tensorflow-connector.git</connection>
69 |   </scm>
70 |   <developers>
71 |     <developer>
72 |       <id>karthikvadla</id>
73 |       <name>Karthik Vadla</name>
74 |       <url>https://github.com/karthikvadla</url>
75 |     </developer>
76 |     <developer>
77 |       <id>skavulya</id>
78 |       <name>Soila Kavulya</name>
79 |       <url>https://github.com/skavulya</url>
80 |     </developer>
81 |     <developer>
82 |       <id>joyeshmishra</id>
83 |       <name>Joyesh Mishra</name>
84 |       <url>https://github.com/joyeshmishra</url>
85 |     </developer>
86 |   </developers>
87 |
88 | credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") // A file containing credentials
89 |
90 | // Add assembly jar to Spark package
91 | test in assembly := {}
92 |
93 | spShade := true
94 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |   <modelVersion>4.0.0</modelVersion>
6 |
7 |   <groupId>org.trustedanalytics</groupId>
8 |   <artifactId>spark-tensorflow-connector</artifactId>
9 |   <packaging>jar</packaging>
10 |   <version>1.0-SNAPSHOT</version>
11 |
12 |   <repositories>
13 |     <repository>
14 |       <id>central1</id>
15 |       <url>http://central1.maven.org/maven2</url>
16 |       <releases>
17 |         <enabled>true</enabled>
18 |       </releases>
19 |       <snapshots>
20 |         <enabled>false</enabled>
21 |       </snapshots>
22 |     </repository>
23 |     <repository>
24 |       <id>tap</id>
25 |       <url>https://tap.jfrog.io/tap/public</url>
26 |       <snapshots>
27 |         <enabled>false</enabled>
28 |       </snapshots>
29 |       <releases>
30 |         <enabled>true</enabled>
31 |       </releases>
32 |     </repository>
33 |     <repository>
34 |       <id>tap-snapshots</id>
35 |       <url>https://tap.jfrog.io/tap/public-snapshots</url>
36 |       <snapshots>
37 |         <enabled>true</enabled>
38 |       </snapshots>
39 |       <releases>
40 |         <enabled>false</enabled>
41 |       </releases>
42 |     </repository>
43 |   </repositories>
44 |
45 |   <profiles>
46 |     <profile>
47 |       <id>compile</id>
48 |       <activation>
49 |         <activeByDefault>true</activeByDefault>
50 |         <property>
51 |           <name>!NEVERSETME</name>
52 |         </property>
53 |       </activation>
54 |       <build>
55 |         <plugins>
56 |           <plugin>
57 |             <inherited>true</inherited>
58 |             <groupId>net.alchim31.maven</groupId>
59 |             <artifactId>scala-maven-plugin</artifactId>
60 |             <version>3.1.6</version>
61 |             <executions>
62 |               <execution>
63 |                 <id>compile</id>
64 |                 <goals>
65 |                   <goal>add-source</goal>
66 |                   <goal>compile</goal>
67 |                 </goals>
68 |                 <configuration>
69 |                   <jvmArgs>
70 |                     <jvmArg>-Xms256m</jvmArg>
71 |                     <jvmArg>-Xmx512m</jvmArg>
72 |                   </jvmArgs>
73 |                   <args>
74 |                     <arg>-g:vars</arg>
75 |                     <arg>-deprecation</arg>
76 |                     <arg>-feature</arg>
77 |                     <arg>-unchecked</arg>
78 |                     <arg>-Xfatal-warnings</arg>
79 |                     <arg>-language:implicitConversions</arg>
80 |                     <arg>-language:existentials</arg>
81 |                   </args>
82 |                 </configuration>
83 |               </execution>
84 |               <execution>
85 |                 <id>test</id>
86 |                 <goals>
87 |                   <goal>add-source</goal>
88 |                   <goal>testCompile</goal>
89 |                 </goals>
90 |               </execution>
91 |             </executions>
92 |             <configuration>
93 |               <recompileMode>incremental</recompileMode>
94 |               <useZincServer>true</useZincServer>
95 |               <scalaCompatVersion>2.11</scalaCompatVersion>
96 |               <fork>false</fork>
97 |             </configuration>
98 |           </plugin>
99 |           <plugin>
100 |             <groupId>org.apache.maven.plugins</groupId>
101 |             <artifactId>maven-dependency-plugin</artifactId>
102 |             <executions>
103 |               <execution>
104 |                 <id>copy-dependencies</id>
105 |                 <phase>process-resources</phase>
106 |                 <goals>
107 |                   <goal>copy-dependencies</goal>
108 |                 </goals>
109 |                 <configuration>
110 |                   <excludeScope>provided</excludeScope>
111 |                   <overWriteIfNewer>true</overWriteIfNewer>
112 |                   <excludeGroupIds>org.apache.spark,junit,org.scalatest</excludeGroupIds>
113 |                   <outputDirectory>${project.build.directory}/lib</outputDirectory>
114 |                 </configuration>
115 |               </execution>
116 |             </executions>
117 |           </plugin>
118 |           <plugin>
119 |             <groupId>org.codehaus.mojo</groupId>
120 |             <artifactId>properties-maven-plugin</artifactId>
121 |             <version>1.0.0</version>
122 |             <executions>
123 |               <execution>
124 |                 <phase>generate-resources</phase>
125 |                 <goals>
126 |                   <goal>write-project-properties</goal>
127 |                 </goals>
128 |                 <configuration>
129 |                   <outputFile>${project.build.outputDirectory}/maven.properties</outputFile>
130 |                 </configuration>
131 |               </execution>
132 |             </executions>
133 |           </plugin>
134 |         </plugins>
135 |         <pluginManagement>
136 |           <plugins>
137 |             <plugin>
138 |               <groupId>net.alchim31.maven</groupId>
139 |               <artifactId>scala-maven-plugin</artifactId>
140 |             </plugin>
141 |           </plugins>
142 |         </pluginManagement>
143 |       </build>
144 |     </profile>
145 |     <profile>
146 |       <id>test</id>
147 |       <activation>
148 |         <activeByDefault>true</activeByDefault>
149 |         <property>
150 |           <name>!NEVERSETME</name>
151 |         </property>
152 |       </activation>
153 |       <build>
154 |         <plugins>
155 |           <plugin>
156 |             <inherited>true</inherited>
157 |             <groupId>net.alchim31.maven</groupId>
158 |             <artifactId>scala-maven-plugin</artifactId>
159 |             <version>3.2.2</version>
160 |             <executions>
161 |               <execution>
162 |                 <id>compile</id>
163 |               </execution>
164 |             </executions>
165 |           </plugin>
166 |           <plugin>
167 |             <inherited>true</inherited>
168 |             <groupId>org.scalatest</groupId>
169 |             <artifactId>scalatest-maven-plugin</artifactId>
170 |             <version>1.0</version>
171 |             <configuration>
172 |               <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
173 |               <junitxml>.</junitxml>
174 |               <filereports>WDF TestSuite.txt</filereports>
175 |               <parallel>false</parallel>
176 |               <tagsToExclude>FTD</tagsToExclude>
177 |               <argLine>-Xmx1024m -XX:PermSize=256m -XX:MaxDirectMemorySize=1000m</argLine>
178 |             </configuration>
179 |             <executions>
180 |               <execution>
181 |                 <id>scalaTest</id>
182 |                 <phase>test</phase>
183 |                 <goals>
184 |                   <goal>test</goal>
185 |                 </goals>
186 |               </execution>
187 |             </executions>
188 |           </plugin>
189 |         </plugins>
190 |         <pluginManagement>
191 |           <plugins>
192 |             <plugin>
193 |               <groupId>net.alchim31.maven</groupId>
194 |               <artifactId>scala-maven-plugin</artifactId>
195 |             </plugin>
196 |           </plugins>
197 |         </pluginManagement>
198 |       </build>
199 |       <dependencyManagement>
200 |         <dependencies>
201 |           <dependency>
202 |             <groupId>org.scalatest</groupId>
203 |             <artifactId>scalatest_2.11</artifactId>
204 |             <version>2.2.6</version>
205 |             <scope>test</scope>
206 |           </dependency>
207 |         </dependencies>
208 |       </dependencyManagement>
209 |       <dependencies>
210 |         <dependency>
211 |           <groupId>org.scalatest</groupId>
212 |           <artifactId>scalatest_2.11</artifactId>
213 |           <scope>test</scope>
214 |         </dependency>
215 |       </dependencies>
216 |     </profile>
217 |   </profiles>
218 |
219 |   <build>
220 |     <plugins>
221 |       <plugin>
222 |         <groupId>net.alchim31.maven</groupId>
223 |         <artifactId>scala-maven-plugin</artifactId>
224 |       </plugin>
225 |       <plugin>
226 |         <groupId>org.apache.maven.plugins</groupId>
227 |         <artifactId>maven-dependency-plugin</artifactId>
228 |       </plugin>
229 |       <plugin>
230 |         <groupId>org.scalatest</groupId>
231 |         <artifactId>scalatest-maven-plugin</artifactId>
232 |       </plugin>
233 |       <plugin>
234 |         <groupId>org.apache.maven.plugins</groupId>
235 |         <artifactId>maven-compiler-plugin</artifactId>
236 |         <version>3.0</version>
237 |         <configuration>
238 |           <source>1.8</source>
239 |           <target>1.8</target>
240 |         </configuration>
241 |       </plugin>
242 |     </plugins>
243 |     <resources>
244 |       <resource>
245 |         <directory>core/src/main/resources</directory>
246 |         <includes>
247 |           <include>reference.conf</include>
248 |         </includes>
249 |       </resource>
250 |     </resources>
251 |     <testResources>
252 |       <testResource>
253 |         <directory>core/src/test/resources</directory>
254 |       </testResource>
255 |     </testResources>
256 |   </build>
257 |
258 |   <dependencies>
259 |     <dependency>
260 |       <groupId>org.tensorflow</groupId>
261 |       <artifactId>tensorflow-hadoop</artifactId>
262 |       <version>1.0-01232017-SNAPSHOT</version>
263 |       <classifier>shaded-protobuf</classifier>
264 |     </dependency>
265 |     <dependency>
266 |       <groupId>org.apache.spark</groupId>
267 |       <artifactId>spark-core_2.11</artifactId>
268 |       <version>2.1.0</version>
269 |       <scope>provided</scope>
270 |     </dependency>
271 |     <dependency>
272 |       <groupId>org.apache.spark</groupId>
273 |       <artifactId>spark-sql_2.11</artifactId>
274 |       <version>2.1.0</version>
275 |       <scope>provided</scope>
276 |     </dependency>
277 |     <dependency>
278 |       <groupId>org.apache.hadoop</groupId>
279 |       <artifactId>hadoop-yarn-api</artifactId>
280 |       <version>2.7.3</version>
281 |       <scope>provided</scope>
282 |     </dependency>
283 |     <dependency>
284 |       <groupId>org.apache.spark</groupId>
285 |       <artifactId>spark-mllib_2.11</artifactId>
286 |       <version>2.1.0</version>
287 |       <type>test-jar</type>
288 |       <scope>test</scope>
289 |     </dependency>
290 |     <dependency>
291 |       <groupId>junit</groupId>
292 |       <artifactId>junit</artifactId>
293 |       <version>4.12</version>
294 |       <scope>test</scope>
295 |     </dependency>
296 |   </dependencies>
297 | </project>
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=0.13.13
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/"
2 |
3 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3")
4 |
5 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.5")
6 |
--------------------------------------------------------------------------------
/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister:
--------------------------------------------------------------------------------
1 | org.trustedanalytics.spark.datasources.tensorflow.DefaultSource
--------------------------------------------------------------------------------
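This service file registers `DefaultSource` with Spark's data source discovery mechanism, which is what makes the short name `tensorflow` resolvable. A minimal sketch of what it enables, assuming a live `SparkSession` named `spark` and an existing TFRecords path:

```scala
// Resolved via META-INF/services to
// org.trustedanalytics.spark.datasources.tensorflow.DefaultSource:
val df = spark.read.format("tensorflow").load("test-output.tfr")

// The fully qualified class name works as well:
spark.read
  .format("org.trustedanalytics.spark.datasources.tensorflow.DefaultSource")
  .load("test-output.tfr")
```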
/src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/DataTypesConvertor.scala:
--------------------------------------------------------------------------------
1 | package org.trustedanalytics.spark.datasources.tensorflow
2 |
3 | /**
4 |  * Converts values of supported data types to Long or Float
5 | */
6 | object DataTypesConvertor {
7 |
8 | def toLong(value: Any): Long = {
9 | value match {
10 | case null => throw new IllegalArgumentException("null cannot be converted to Long")
11 | case i: Int => i.toLong
12 | case l: Long => l
13 | case f: Float => f.toLong
14 | case d: Double => d.toLong
15 | case bd: BigDecimal => bd.toLong
16 | case s: String => s.trim().toLong
17 | case _ => throw new RuntimeException(s"${value.getClass.getName} toLong is not implemented")
18 | }
19 | }
20 |
21 | def toFloat(value: Any): Float = {
22 | value match {
23 | case null => throw new IllegalArgumentException("null cannot be converted to Float")
24 | case i: Int => i.toFloat
25 | case l: Long => l.toFloat
26 | case f: Float => f
27 | case d: Double => d.toFloat
28 | case bd: BigDecimal => bd.toFloat
29 | case s: String => s.trim().toFloat
30 | case _ => throw new RuntimeException(s"${value.getClass.getName} toFloat is not implemented")
31 | }
32 | }
33 | }
34 |
35 |
--------------------------------------------------------------------------------
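A quick sketch of how these converters behave at a REPL (hypothetical session, not part of the repo):

```scala
import org.trustedanalytics.spark.datasources.tensorflow.DataTypesConvertor

DataTypesConvertor.toLong("42")   // 42L -- strings are trimmed and parsed
DataTypesConvertor.toLong(3.7)    // 3L  -- doubles are truncated, not rounded
DataTypesConvertor.toFloat(10)    // 10.0f
DataTypesConvertor.toFloat(null)  // throws IllegalArgumentException
DataTypesConvertor.toLong(true)   // throws RuntimeException (Boolean is not handled)
```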
/src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/DefaultSource.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.trustedanalytics.spark.datasources.tensorflow
17 |
18 | import org.apache.hadoop.io.{BytesWritable, NullWritable}
19 | import org.apache.spark.sql._
20 | import org.apache.spark.sql.sources._
21 | import org.apache.spark.sql.types.StructType
22 | import org.tensorflow.hadoop.io.TFRecordFileOutputFormat
23 | import org.trustedanalytics.spark.datasources.tensorflow.serde.DefaultTfRecordRowEncoder
24 |
25 | /**
26 | * Provides access to TensorFlow record source
27 | */
28 | class DefaultSource extends DataSourceRegister
29 | with CreatableRelationProvider
30 | with RelationProvider
31 | with SchemaRelationProvider{
32 |
33 | /**
34 | * Short alias for spark-tensorflow data source.
35 | */
36 | override def shortName(): String = "tensorflow"
37 |
38 | // Writes DataFrame as TensorFlow Records
39 | override def createRelation(
40 | sqlContext: SQLContext,
41 | mode: SaveMode,
42 | parameters: Map[String, String],
43 | data: DataFrame): BaseRelation = {
44 |
45 | val path = parameters("path")
46 |
47 | //Export DataFrame as TFRecords
48 | val features = data.rdd.map(row => {
49 | val example = DefaultTfRecordRowEncoder.encodeTfRecord(row)
50 | (new BytesWritable(example.toByteArray), NullWritable.get())
51 | })
52 | features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](path)
53 |
54 | TensorflowRelation(parameters)(sqlContext.sparkSession)
55 | }
56 |
57 | override def createRelation(sqlContext: SQLContext,
58 | parameters: Map[String, String],
59 | schema: StructType): BaseRelation = {
60 | TensorflowRelation(parameters, Some(schema))(sqlContext.sparkSession)
61 | }
62 |
63 | // Reads TensorFlow Records into DataFrame
64 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): TensorflowRelation = {
65 | TensorflowRelation(parameters)(sqlContext.sparkSession)
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/TensorflowInferSchema.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.trustedanalytics.spark.datasources.tensorflow
17 |
18 | import org.apache.spark.rdd.RDD
19 | import org.apache.spark.sql.types._
20 | import org.tensorflow.example.{Example, Feature}
21 | import scala.collection.mutable.Map
22 | import scala.util.control.Exception._
23 | import scala.collection.JavaConverters._
24 |
25 | object TensorflowInferSchema {
26 |
27 | /**
28 | * Similar to the JSON schema inference.
29 | * [[org.apache.spark.sql.execution.datasources.json.InferSchema]]
30 | * 1. Infer type of each row
31 | * 2. Merge row types to find common type
32 | * 3. Replace any null types with string type
33 | */
34 | def apply(exampleRdd: RDD[Example]): StructType = {
35 | val startType: Map[String, DataType] = Map.empty[String, DataType]
36 | val rootTypes: Map[String, DataType] = exampleRdd.aggregate(startType)(inferRowType, mergeFieldTypes)
37 | val columnsList = rootTypes.map {
38 | case (featureName, featureType) =>
39 | if (featureType == null) {
40 | StructField(featureName, StringType)
41 | }
42 | else {
43 | StructField(featureName, featureType)
44 | }
45 | }
46 | StructType(columnsList.toSeq)
47 | }
48 |
49 | private def inferRowType(schemaSoFar: Map[String, DataType], next: Example): Map[String, DataType] = {
50 | next.getFeatures.getFeatureMap.asScala.foreach {
51 | case (featureName, feature) => {
52 | val currentType = inferField(feature)
53 | if (schemaSoFar.contains(featureName)) {
54 | val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType)
55 | schemaSoFar(featureName) = updatedType.getOrElse(null)
56 | }
57 | else {
58 | schemaSoFar += (featureName -> currentType)
59 | }
60 | }
61 | }
62 | schemaSoFar
63 | }
64 |
65 | private def mergeFieldTypes(first: Map[String, DataType], second: Map[String, DataType]): Map[String, DataType] = {
66 | //Merge two maps and do the comparison.
67 | val mutMap = collection.mutable.Map[String, DataType]((first.keySet ++ second.keySet)
68 | .map(key => (key, findTightestCommonType(first.getOrElse(key, null), second.getOrElse(key, null)).get))
69 | .toSeq: _*)
70 | mutMap
71 | }
72 |
73 | /**
74 | * Infer Feature datatype based on field number
75 | */
76 | private def inferField(feature: Feature): DataType = {
77 | feature.getKindCase.getNumber match {
78 | case Feature.BYTES_LIST_FIELD_NUMBER => {
79 | StringType
80 | }
81 | case Feature.INT64_LIST_FIELD_NUMBER => {
82 | parseInt64List(feature)
83 | }
84 | case Feature.FLOAT_LIST_FIELD_NUMBER => {
85 | parseFloatList(feature)
86 | }
87 | case _ => throw new RuntimeException(s"Unsupported feature type: ${feature.getKindCase}")
88 | }
89 | }
90 |
91 | private def parseInt64List(feature: Feature): DataType = {
92 | val int64List = feature.getInt64List.getValueList.asScala.toArray
93 | val length = int64List.size
94 | if (length == 0) {
95 | null
96 | }
97 | else if (length > 1) {
98 | ArrayType(LongType)
99 | }
100 | else {
101 | val fieldValue = int64List(0).toString
102 | parseInteger(fieldValue)
103 | }
104 | }
105 |
106 | private def parseFloatList(feature: Feature): DataType = {
107 | val floatList = feature.getFloatList.getValueList.asScala.toArray
108 | val length = floatList.size
109 | if (length == 0) {
110 | null
111 | }
112 | else if (length > 1) {
113 | ArrayType(DoubleType)
114 | }
115 | else {
116 | val fieldValue = floatList(0).toString
117 | parseFloat(fieldValue)
118 | }
119 | }
120 |
121 | private def parseInteger(field: String): DataType = if (allCatch.opt(field.toInt).isDefined) {
122 | IntegerType
123 | }
124 | else {
125 | parseLong(field)
126 | }
127 |
128 | private def parseLong(field: String): DataType = if (allCatch.opt(field.toLong).isDefined) {
129 | LongType
130 | }
131 | else {
132 | throw new RuntimeException("Unable to parse field datatype to int64...")
133 | }
134 |
135 | private def parseFloat(field: String): DataType = {
136 | if ((allCatch opt field.toFloat).isDefined) {
137 | FloatType
138 | }
139 | else {
140 | parseDouble(field)
141 | }
142 | }
143 |
144 | private def parseDouble(field: String): DataType = if (allCatch.opt(field.toDouble).isDefined) {
145 | DoubleType
146 | }
147 | else {
148 | throw new RuntimeException("Unable to parse field datatype to float64...")
149 | }
150 | /**
151 | * Copied from internal Spark api
152 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
153 | */
154 | private val numericPrecedence: IndexedSeq[DataType] =
155 | IndexedSeq[DataType](IntegerType,
156 | LongType,
157 | FloatType,
158 | DoubleType,
159 | StringType)
160 |
161 | private def getNumericPrecedence(dataType: DataType): Int = {
162 | dataType match {
163 | case x if x.equals(IntegerType) => 0
164 | case x if x.equals(LongType) => 1
165 | case x if x.equals(FloatType) => 2
166 | case x if x.equals(DoubleType) => 3
167 | case x if x.equals(ArrayType(LongType)) => 4
168 | case x if x.equals(ArrayType(DoubleType)) => 5
169 | case x if x.equals(StringType) => 6
170 | case _ => throw new RuntimeException("Unable to get the precedence for given datatype...")
171 | }
172 | }
173 |
174 | /**
175 | * Copied from internal Spark api
176 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
177 | */
178 | private val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
179 | case (t1, t2) if t1 == t2 => Some(t1)
180 | case (null, t2) => Some(t2)
181 | case (t1, null) => Some(t1)
182 | case (t1, t2) if t1.equals(ArrayType(LongType)) && t2.equals(ArrayType(DoubleType)) => Some(ArrayType(DoubleType))
183 | case (t1, t2) if t1.equals(ArrayType(DoubleType)) && t2.equals(ArrayType(LongType)) => Some(ArrayType(DoubleType))
184 | case (StringType, t2) => Some(StringType)
185 | case (t1, StringType) => Some(StringType)
186 |
187 | // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
188 | case (t1, t2) =>
189 | val t1Precedence = getNumericPrecedence(t1)
190 | val t2Precedence = getNumericPrecedence(t2)
191 | val newType = if (t1Precedence > t2Precedence) t1 else t2
192 | Some(newType)
193 | case _ => None
194 | }
195 | }
196 |
197 |
--------------------------------------------------------------------------------
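The widening rules above are easiest to see on a tiny RDD. A minimal sketch, assuming a live `SparkSession` named `spark` (the proto builders come from the shaded `org.tensorflow.example` classes the project already depends on):

```scala
import org.tensorflow.example._

// Wrap a single feature named "x" in an Example.
def exampleWith(f: Feature): Example =
  Example.newBuilder()
    .setFeatures(Features.newBuilder().putFeature("x", f))
    .build()

// Row 1 carries "x" as a one-element Int64List, row 2 as a FloatList.
val intExample = exampleWith(
  Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(1L)).build())
val floatExample = exampleWith(
  Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(2.5f)).build())

val rdd = spark.sparkContext.parallelize(Seq(intExample, floatExample))

// IntegerType merges with FloatType into FloatType per the precedence list.
TensorflowInferSchema(rdd).printTreeString()  // root |-- x: float (nullable = true)
```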
/src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/TensorflowRelation.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.trustedanalytics.spark.datasources.tensorflow
17 |
18 | import org.apache.hadoop.io.{BytesWritable, NullWritable}
19 | import org.apache.spark.rdd.RDD
20 | import org.apache.spark.sql.sources.{BaseRelation, TableScan}
21 | import org.apache.spark.sql.types.StructType
22 | import org.apache.spark.sql.{Row, SQLContext, SparkSession}
23 | import org.tensorflow.example.Example
24 | import org.tensorflow.hadoop.io.TFRecordFileInputFormat
25 | import org.trustedanalytics.spark.datasources.tensorflow.serde.DefaultTfRecordRowDecoder
26 |
27 |
28 | case class TensorflowRelation(options: Map[String, String], customSchema: Option[StructType]=None)(@transient val session: SparkSession) extends BaseRelation with TableScan {
29 |
30 | //Import TFRecords as DataFrame happens here
31 | lazy val (tf_rdd, tf_schema) = {
32 | val rdd = session.sparkContext.newAPIHadoopFile(options("path"), classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable])
33 |
34 | val exampleRdd = rdd.map {
35 | case (bytesWritable, nullWritable) => Example.parseFrom(bytesWritable.copyBytes())
36 | }
37 |
38 | val finalSchema = customSchema.getOrElse(TensorflowInferSchema(exampleRdd))
39 |
40 | (exampleRdd.map(example => DefaultTfRecordRowDecoder.decodeTfRecord(example, finalSchema)), finalSchema)
41 | }
42 |
43 | override def sqlContext: SQLContext = session.sqlContext
44 |
45 | override def schema: StructType = tf_schema
46 |
47 | override def buildScan(): RDD[Row] = tf_rdd
48 | }
49 |
50 |
--------------------------------------------------------------------------------
/src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/DefaultTfRecordRowDecoder.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.trustedanalytics.spark.datasources.tensorflow.serde
17 |
18 | import org.apache.spark.sql.types._
19 | import org.apache.spark.sql.Row
20 | import org.tensorflow.example._
21 | import scala.collection.JavaConverters._
22 |
23 | trait TfRecordRowDecoder {
24 | /**
25 | * Decodes each TensorFlow "Example" as DataFrame "Row"
26 | *
27 | * Maps each feature in Example to element in Row with DataType based on custom schema or
28 | * default mapping of Int64List, FloatList, BytesList to column data type
29 | *
30 | * @param example TensorFlow Example to decode
31 | * @param schema Decode Example using specified schema
32 | * @return a DataFrame row
33 | */
34 | def decodeTfRecord(example: Example, schema: StructType): Row
35 | }
36 |
37 | object DefaultTfRecordRowDecoder extends TfRecordRowDecoder {
38 |
39 | /**
40 | * Decodes each TensorFlow "Example" as DataFrame "Row"
41 | *
42 | * Maps each feature in Example to element in Row with DataType based on custom schema
43 | *
44 | * @param example TensorFlow Example to decode
45 | * @param schema Decode Example using specified schema
46 | * @return a DataFrame row
47 | */
48 | def decodeTfRecord(example: Example, schema: StructType): Row = {
49 | val row = Array.fill[Any](schema.length)(null)
50 | example.getFeatures.getFeatureMap.asScala.foreach {
51 | case (featureName, feature) =>
52 | val index = schema.fieldIndex(featureName)
53 | val colDataType = schema.fields(index).dataType
54 | row(index) = colDataType match {
55 | case IntegerType => IntFeatureDecoder.decode(feature)
56 | case LongType => LongFeatureDecoder.decode(feature)
57 | case FloatType => FloatFeatureDecoder.decode(feature)
58 | case DoubleType => DoubleFeatureDecoder.decode(feature)
59 | case ArrayType(IntegerType, _) => IntListFeatureDecoder.decode(feature)
60 | case ArrayType(LongType, _) => LongListFeatureDecoder.decode(feature)
61 | case ArrayType(FloatType, _) => FloatListFeatureDecoder.decode(feature)
62 | case ArrayType(DoubleType, _) => DoubleListFeatureDecoder.decode(feature)
63 | case StringType => StringFeatureDecoder.decode(feature)
64 | case _ => throw new RuntimeException(s"Cannot convert feature to unsupported data type ${colDataType}")
65 | }
66 | }
67 | Row.fromSeq(row)
68 | }
69 | }
70 |
71 |
--------------------------------------------------------------------------------
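A minimal decoding sketch with hypothetical values (the `ByteString` import is the shaded protobuf class the tests also use):

```scala
import org.apache.spark.sql.types._
import org.tensorflow.example._
import org.tensorflow.hadoop.shaded.protobuf.ByteString
import org.trustedanalytics.spark.datasources.tensorflow.serde.DefaultTfRecordRowDecoder

val schema = StructType(Seq(
  StructField("label", IntegerType),
  StructField("name", StringType)))

val example = Example.newBuilder()
  .setFeatures(Features.newBuilder()
    .putFeature("label",
      Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(1L)).build())
    .putFeature("name",
      Feature.newBuilder().setBytesList(
        BytesList.newBuilder().addValue(ByteString.copyFromUtf8("cat"))).build()))
  .build()

// Each feature is routed to the decoder matching its column's data type.
val row = DefaultTfRecordRowDecoder.decodeTfRecord(example, schema)
// row == Row(1, "cat")
```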
/src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/DefaultTfRecordRowEncoder.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.trustedanalytics.spark.datasources.tensorflow.serde
17 |
18 | import org.apache.spark.sql.Row
19 | import org.apache.spark.sql.types._
20 | import org.tensorflow.example._
21 |
22 | trait TfRecordRowEncoder {
23 | /**
24 | * Encodes each Row as TensorFlow "Example"
25 | *
26 | * Maps each column in Row to one of Int64List, FloatList, BytesList based on the column data type
27 | *
28 | * @param row a DataFrame row
29 | * @return TensorFlow Example
30 | */
31 | def encodeTfRecord(row: Row): Example
32 | }
33 |
34 | object DefaultTfRecordRowEncoder extends TfRecordRowEncoder {
35 |
36 | /**
37 | * Encodes each Row as TensorFlow "Example"
38 | *
39 | * Maps each column in Row to one of Int64List, FloatList, BytesList based on the column data type
40 | *
41 | * @param row a DataFrame row
42 | * @return TensorFlow Example
43 | */
44 | def encodeTfRecord(row: Row): Example = {
45 | val features = Features.newBuilder()
46 | val example = Example.newBuilder()
47 |
48 | row.schema.zipWithIndex.foreach {
49 | case (structField, index) =>
50 | val value = row.get(index)
51 | val feature = structField.dataType match {
52 | case IntegerType | LongType => Int64ListFeatureEncoder.encode(value)
53 | case FloatType | DoubleType => FloatListFeatureEncoder.encode(value)
54 | case ArrayType(IntegerType, _) | ArrayType(LongType, _) => Int64ListFeatureEncoder.encode(value)
55 | case ArrayType(FloatType, _) | ArrayType(DoubleType, _) => FloatListFeatureEncoder.encode(value)
56 | case _ => BytesListFeatureEncoder.encode(value)
57 | }
58 | features.putFeature(structField.name, feature)
59 | }
60 |
61 | example.setFeatures(features.build())
63 | example.build()
64 | }
65 | }
66 |
67 |
--------------------------------------------------------------------------------
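And the matching encoding direction, again as a hypothetical sketch; `GenericRowWithSchema` is used because `encodeTfRecord` reads `row.schema`:

```scala
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.trustedanalytics.spark.datasources.tensorflow.serde.DefaultTfRecordRowEncoder

val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("vec", ArrayType(DoubleType)),
  StructField("name", StringType)))

val row = new GenericRowWithSchema(Array[Any](1, Seq(1.0, 2.0), "cat"), schema)

// "id" -> Int64List, "vec" -> FloatList, "name" -> BytesList
val example = DefaultTfRecordRowEncoder.encodeTfRecord(row)
example.getFeatures.getFeatureMap.get("id").getInt64List.getValue(0)  // 1L
```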
/src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/FeatureDecoder.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.trustedanalytics.spark.datasources.tensorflow.serde
17 |
18 | import org.tensorflow.example.Feature
19 | import scala.collection.JavaConverters._
20 |
21 | trait FeatureDecoder[T] {
22 | /**
23 | * Decodes each TensorFlow "Feature" to desired Scala type
24 | *
25 | * @param feature TensorFlow Feature
26 | * @return Decoded feature
27 | */
28 | def decode(feature: Feature): T
29 | }
30 |
31 | /**
32 | * Decode TensorFlow "Feature" to Integer
33 | */
34 | object IntFeatureDecoder extends FeatureDecoder[Int] {
35 | override def decode(feature: Feature): Int = {
36 | require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List")
37 | try {
38 | val int64List = feature.getInt64List.getValueList
39 | require(int64List.size() == 1, "Length of Int64List must equal 1")
40 | int64List.get(0).intValue()
41 | }
42 | catch {
43 | case ex: Exception =>
44 | throw new RuntimeException(s"Cannot convert feature to Int.", ex)
45 | }
46 | }
47 | }
48 |
49 | /**
50 | * Decode TensorFlow "Feature" to Seq[Int]
51 | */
52 | object IntListFeatureDecoder extends FeatureDecoder[Seq[Int]] {
53 | override def decode(feature: Feature): Seq[Int] = {
54 | require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List")
55 | try {
56 | val array = feature.getInt64List.getValueList.asScala.toArray
57 | array.map(_.toInt)
58 | }
59 | catch {
60 | case ex: Exception =>
61 | throw new RuntimeException(s"Cannot convert feature to Seq[Int].", ex)
62 | }
63 | }
64 | }
65 |
66 | /**
67 | * Decode TensorFlow "Feature" to Long
68 | */
69 | object LongFeatureDecoder extends FeatureDecoder[Long] {
70 | override def decode(feature: Feature): Long = {
71 | require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List")
72 | try {
73 | val int64List = feature.getInt64List.getValueList
74 | require(int64List.size() == 1, "Length of Int64List must equal 1")
75 | int64List.get(0).longValue()
76 | }
77 | catch {
78 | case ex: Exception =>
79 | throw new RuntimeException(s"Cannot convert feature to Long.", ex)
80 | }
81 | }
82 | }
83 |
84 | /**
85 | * Decode TensorFlow "Feature" to Seq[Long]
86 | */
87 | object LongListFeatureDecoder extends FeatureDecoder[Seq[Long]] {
88 | override def decode(feature: Feature): Seq[Long] = {
89 | require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List")
90 | try {
91 | val array = feature.getInt64List.getValueList.asScala.toArray
92 | array.map(_.toLong)
93 | }
94 | catch {
95 | case ex: Exception =>
96 | throw new RuntimeException(s"Cannot convert feature to Seq[Long].", ex)
97 | }
98 | }
99 | }
100 |
101 | /**
102 | * Decode TensorFlow "Feature" to Float
103 | */
104 | object FloatFeatureDecoder extends FeatureDecoder[Float] {
105 | override def decode(feature: Feature): Float = {
106 | require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList")
107 | try {
108 | val floatList = feature.getFloatList.getValueList
109 | require(floatList.size() == 1, "Length of FloatList must equal 1")
110 | floatList.get(0).floatValue()
111 | }
112 | catch {
113 | case ex: Exception =>
114 | throw new RuntimeException(s"Cannot convert feature to Float.", ex)
115 | }
116 | }
117 | }
118 |
119 | /**
120 | * Decode TensorFlow "Feature" to Seq[Float]
121 | */
122 | object FloatListFeatureDecoder extends FeatureDecoder[Seq[Float]] {
123 | override def decode(feature: Feature): Seq[Float] = {
124 | require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList")
125 | try {
126 | val array = feature.getFloatList.getValueList.asScala.toArray
127 | array.map(_.toFloat)
128 | }
129 | catch {
130 | case ex: Exception =>
131 | throw new RuntimeException(s"Cannot convert feature to Seq[Float].", ex)
132 | }
133 | }
134 | }
135 |
136 | /**
137 | * Decode TensorFlow "Feature" to Double
138 | */
139 | object DoubleFeatureDecoder extends FeatureDecoder[Double] {
140 | override def decode(feature: Feature): Double = {
141 | require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList")
142 | try {
143 | val floatList = feature.getFloatList.getValueList
144 | require(floatList.size() == 1, "Length of FloatList must equal 1")
145 | floatList.get(0).doubleValue()
146 | }
147 | catch {
148 | case ex: Exception =>
149 | throw new RuntimeException(s"Cannot convert feature to Double.", ex)
150 | }
151 | }
152 | }
153 |
154 | /**
155 | * Decode TensorFlow "Feature" to Seq[Double]
156 | */
157 | object DoubleListFeatureDecoder extends FeatureDecoder[Seq[Double]] {
158 | override def decode(feature: Feature): Seq[Double] = {
159 | require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList")
160 | try {
161 | val array = feature.getFloatList.getValueList.asScala.toArray
162 | array.map(_.toDouble)
163 | }
164 | catch {
165 | case ex: Exception =>
166 | throw new RuntimeException(s"Cannot convert feature to Seq[Double].", ex)
167 | }
168 | }
169 | }
170 |
171 | /**
172 | * Decode TensorFlow "Feature" to String
173 | */
174 | object StringFeatureDecoder extends FeatureDecoder[String] {
175 | override def decode(feature: Feature): String = {
176 | require(feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type BytesList")
177 | try {
178 | feature.getBytesList.toByteString.toStringUtf8.trim
179 | }
180 | catch {
181 | case ex: Exception =>
182 | throw new RuntimeException(s"Cannot convert feature to String.", ex)
183 | }
184 | }
185 | }
186 |
187 |
--------------------------------------------------------------------------------
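The individual decoders can also be exercised directly; a small sketch with hypothetical values:

```scala
import org.tensorflow.example.{Feature, Int64List}
import org.trustedanalytics.spark.datasources.tensorflow.serde._

val int64Feature = Feature.newBuilder()
  .setInt64List(Int64List.newBuilder().addValue(5L).addValue(6L))
  .build()

IntListFeatureDecoder.decode(int64Feature)   // Seq(5, 6)
LongListFeatureDecoder.decode(int64Feature)  // Seq(5L, 6L)

// Single-value decoders insist on exactly one element, so this throws a
// RuntimeException wrapping "Length of Int64List must equal 1":
IntFeatureDecoder.decode(int64Feature)
```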
/src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/FeatureEncoder.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.trustedanalytics.spark.datasources.tensorflow.serde
17 |
18 | import org.tensorflow.example.{BytesList, Feature, FloatList, Int64List}
19 | import org.tensorflow.hadoop.shaded.protobuf.ByteString
20 | import org.trustedanalytics.spark.datasources.tensorflow.DataTypesConvertor
21 |
22 | trait FeatureEncoder {
23 | /**
24 | * Encodes input value as TensorFlow "Feature"
25 | *
26 | * Maps input value to one of Int64List, FloatList, BytesList
27 | *
28 | * @param value Input value
29 | * @return TensorFlow Feature
30 | */
31 | def encode(value: Any): Feature
32 | }
33 |
34 | /**
35 | * Encode input value to Int64List
36 | */
37 | object Int64ListFeatureEncoder extends FeatureEncoder {
38 | override def encode(value: Any): Feature = {
39 | try {
40 | val int64List = value match {
41 | case i: Int => Int64List.newBuilder().addValue(i.toLong).build()
42 | case l: Long => Int64List.newBuilder().addValue(l).build()
43 | case arr: scala.collection.mutable.WrappedArray[_] => toInt64List(arr.toArray[Any])
44 | case arr: Array[_] => toInt64List(arr)
45 | case seq: Seq[_] => toInt64List(seq.toArray[Any])
46 | case _ => throw new RuntimeException(s"Cannot convert object $value to Int64List")
47 | }
48 | Feature.newBuilder().setInt64List(int64List).build()
49 | }
50 | catch {
51 | case ex: Exception =>
52 | throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to Int64List feature.", ex)
53 | }
54 | }
55 |
56 | private def toInt64List[T](arr: Array[T]): Int64List = {
57 | val intListBuilder = Int64List.newBuilder()
58 | arr.foreach(x => {
59 | require(x != null, "Int64List with null values is not supported")
60 | val longValue = DataTypesConvertor.toLong(x)
61 | intListBuilder.addValue(longValue)
62 | })
63 | intListBuilder.build()
64 | }
65 | }
66 |
67 | /**
68 | * Encode input value to FloatList
69 | */
70 | object FloatListFeatureEncoder extends FeatureEncoder {
71 | override def encode(value: Any): Feature = {
72 | try {
73 | val floatList = value match {
74 | case i: Int => FloatList.newBuilder().addValue(i.toFloat).build()
75 | case l: Long => FloatList.newBuilder().addValue(l.toFloat).build()
76 | case f: Float => FloatList.newBuilder().addValue(f).build()
77 | case d: Double => FloatList.newBuilder().addValue(d.toFloat).build()
78 | case arr: scala.collection.mutable.WrappedArray[_] => toFloatList(arr.toArray[Any])
79 | case arr: Array[_] => toFloatList(arr)
80 | case seq: Seq[_] => toFloatList(seq.toArray[Any])
81 | case _ => throw new RuntimeException(s"Cannot convert object $value to FloatList")
82 | }
83 | Feature.newBuilder().setFloatList(floatList).build()
84 | }
85 | catch {
86 | case ex: Exception =>
87 | throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to FloatList feature.", ex)
88 | }
89 | }
90 |
91 | private def toFloatList[T](arr: Array[T]): FloatList = {
92 | val floatListBuilder = FloatList.newBuilder()
93 | arr.foreach(x => {
94 | require(x != null, "FloatList with null values is not supported")
95 | val longValue = DataTypesConvertor.toFloat(x)
96 | floatListBuilder.addValue(longValue)
97 | })
98 | floatListBuilder.build()
99 | }
100 | }
101 |
102 | /**
103 |  * Encode input value to BytesList
104 | */
105 | object BytesListFeatureEncoder extends FeatureEncoder {
106 | override def encode(value: Any): Feature = {
107 | try {
108 | val byteList = BytesList.newBuilder().addValue(ByteString.copyFrom(value.toString.getBytes)).build()
109 | Feature.newBuilder().setBytesList(byteList).build()
110 | }
111 | catch {
112 | case ex: Exception =>
113 | throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to BytesList feature.", ex)
114 | }
115 | }
116 | }
117 |
118 |
119 |
--------------------------------------------------------------------------------
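A sketch of the input shapes each encoder accepts (hypothetical values):

```scala
import org.trustedanalytics.spark.datasources.tensorflow.serde._

// Scalars and collections both map onto the repeated proto fields:
Int64ListFeatureEncoder.encode(7)                // Int64List: [7]
Int64ListFeatureEncoder.encode(Seq(1, 2, 3))     // Int64List: [1, 2, 3]
FloatListFeatureEncoder.encode(Array(1.0, 2.0))  // FloatList: [1.0, 2.0]

// Anything else is stringified into a BytesList:
BytesListFeatureEncoder.encode("cat")            // BytesList: ["cat"]
```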
/src/test/scala/org/trustedanalytics/spark/datasources/tensorflow/SharedSparkSessionSuite.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.trustedanalytics.spark.datasources.tensorflow
18 |
19 | import java.io.File
20 |
21 | import org.apache.commons.io.FileUtils
22 | import org.apache.spark.SharedSparkSession
23 | import org.junit.{After, Before}
24 | import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}
25 |
26 |
27 | trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll
28 |
29 | class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite {
30 | val TF_SANDBOX_DIR = "tf-sandbox"
31 | val file = new File(TF_SANDBOX_DIR)
32 |
33 | @Before
34 | override def beforeAll() = {
35 | super.setUp()
36 | FileUtils.deleteQuietly(file)
37 | file.mkdirs()
38 | }
39 |
40 | @After
41 | override def afterAll() = {
42 | FileUtils.deleteQuietly(file)
43 | super.tearDown()
44 | }
45 | }
46 |
47 |
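Suites that need a SparkSession and the tf-sandbox scratch directory extend this class, as TensorflowSuite does below. A minimal sketch (MySuite is hypothetical):

    class MySuite extends SharedSparkSessionSuite {
      "my feature" should {
        "see the shared session" in {
          spark.sparkContext.parallelize(Seq(1, 2, 3)).count() should equal(3L)
        }
      }
    }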
--------------------------------------------------------------------------------
/src/test/scala/org/trustedanalytics/spark/datasources/tensorflow/TensorflowSuite.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.trustedanalytics.spark.datasources.tensorflow
18 |
19 | import org.apache.spark.rdd.RDD
20 | import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
21 | import org.apache.spark.sql.types._
22 | import org.apache.spark.sql.{DataFrame, Row}
23 | import org.tensorflow.example._
24 | import org.tensorflow.hadoop.shaded.protobuf.ByteString
25 | import org.trustedanalytics.spark.datasources.tensorflow.serde.{DefaultTfRecordRowDecoder, DefaultTfRecordRowEncoder}
26 | import scala.collection.JavaConverters._
27 |
28 | class TensorflowSuite extends SharedSparkSessionSuite {
29 |
30 | "Spark TensorFlow module" should {
31 |
32 | "Test Import/Export" in {
33 |
34 | val path = s"$TF_SANDBOX_DIR/output25.tfr"
35 | val testRows: Array[Row] = Array(
36 | new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
37 | new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
38 |
39 | val schema = StructType(List(
40 | StructField("id", IntegerType),
41 | StructField("IntegerTypelabel", IntegerType),
42 | StructField("LongTypelabel", LongType),
43 | StructField("FloatTypelabel", FloatType),
44 | StructField("DoubleTypelabel", DoubleType),
45 | StructField("vectorlabel", ArrayType(DoubleType, true)),
46 | StructField("name", StringType)))
47 |
48 | val rdd = spark.sparkContext.parallelize(testRows)
49 |
50 | val df: DataFrame = spark.createDataFrame(rdd, schema)
51 | df.write.format("tensorflow").save(path)
52 |
53 |     //If a schema is not provided, it is inferred automatically; here an explicit schema is supplied
54 | val importedDf: DataFrame = spark.read.format("tensorflow").schema(schema).load(path)
55 | val actualDf = importedDf.select("id", "IntegerTypelabel", "LongTypelabel", "FloatTypelabel", "DoubleTypelabel", "vectorlabel", "name").sort("name")
56 |
57 | val expectedRows = df.collect()
58 | val actualRows = actualDf.collect()
59 |
60 | expectedRows should equal(actualRows)
61 | }
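    // A sketch of the same round trip in user code, relying on schema
    // inference instead of an explicit schema (the path is hypothetical):
    //   df.write.format("tensorflow").save("some/path.tfr")
    //   val inferred: DataFrame = spark.read.format("tensorflow").load("some/path.tfr")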
62 |
63 | "Encode given Row as TensorFlow example" in {
64 | val schemaStructType = StructType(Array(
65 | StructField("IntegerTypelabel", IntegerType),
66 | StructField("LongTypelabel", LongType),
67 | StructField("FloatTypelabel", FloatType),
68 | StructField("DoubleTypelabel", DoubleType),
69 | StructField("vectorlabel", ArrayType(DoubleType, true)),
70 | StructField("strlabel", StringType)
71 | ))
72 | val doubleArray = Array(1.1, 111.1, 11111.1)
73 | val expectedFloatArray = Array(1.1F, 111.1F, 11111.1F)
74 |
75 | val rowWithSchema = new GenericRowWithSchema(Array[Any](1, 23L, 10.0F, 14.0, doubleArray, "r1"), schemaStructType)
76 |
77 |     //Encode SQL Row as a TensorFlow Example
78 | val example = DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema)
79 | import org.tensorflow.example.Feature
80 |
81 |     //Verify that each Spark SQL data type maps to the expected TensorFlow feature type
82 | val featureMap = example.getFeatures.getFeatureMap.asScala
83 | assert(featureMap("IntegerTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER)
84 | assert(featureMap("IntegerTypelabel").getInt64List.getValue(0).toInt == 1)
85 |
86 | assert(featureMap("LongTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER)
87 | assert(featureMap("LongTypelabel").getInt64List.getValue(0).toInt == 23)
88 |
89 | assert(featureMap("FloatTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
90 | assert(featureMap("FloatTypelabel").getFloatList.getValue(0) == 10.0F)
91 |
92 | assert(featureMap("DoubleTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
93 | assert(featureMap("DoubleTypelabel").getFloatList.getValue(0) == 14.0F)
94 |
95 | assert(featureMap("vectorlabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
96 | assert(featureMap("vectorlabel").getFloatList.getValueList.toArray === expectedFloatArray)
97 |
98 | assert(featureMap("strlabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
99 | assert(featureMap("strlabel").getBytesList.toByteString.toStringUtf8.trim == "r1")
100 |
101 | }
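    // Summary of the type mapping the assertions above exercise:
    //   IntegerType, LongType                     -> Int64List
    //   FloatType, DoubleType, ArrayType(Double)  -> FloatList
    //   StringType                                -> BytesList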
102 |
103 | "Throw an exception for a vector with null values during Encode" in {
104 | intercept[Exception] {
105 | val schemaStructType = StructType(Array(
106 | StructField("vectorlabel", ArrayType(DoubleType, true))
107 | ))
108 | val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
109 |
110 | val rowWithSchema = new GenericRowWithSchema(Array[Any](doubleArray), schemaStructType)
111 |
112 | //Throws NullPointerException
113 | DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema)
114 | }
115 | }
116 |
117 | "Decode given TensorFlow Example as Row" in {
118 |
119 |     //Vectors containing nulls are not supported
120 | val expectedRow = new GenericRow(Array[Any](1, 23L, 10.0F, 14.0, Seq(1.0, 2.0), "r1"))
121 |
122 | val schema = StructType(List(
123 | StructField("IntegerTypelabel", IntegerType),
124 | StructField("LongTypelabel", LongType),
125 | StructField("FloatTypelabel", FloatType),
126 | StructField("DoubleTypelabel", DoubleType),
127 | StructField("vectorlabel", ArrayType(DoubleType)),
128 | StructField("strlabel", StringType)))
129 |
130 | //Build example
131 | val intFeature = Int64List.newBuilder().addValue(1)
132 | val longFeature = Int64List.newBuilder().addValue(23L)
133 | val floatFeature = FloatList.newBuilder().addValue(10.0F)
134 | val doubleFeature = FloatList.newBuilder().addValue(14.0F)
135 | val vectorFeature = FloatList.newBuilder().addValue(1F).addValue(2F).build()
136 | val strFeature = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()
137 | val features = Features.newBuilder()
138 | .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature).build())
139 | .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature).build())
140 | .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature).build())
141 | .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature).build())
142 | .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature).build())
143 | .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature).build())
144 | .build()
145 | val example = Example.newBuilder()
146 | .setFeatures(features)
147 | .build()
148 |
149 |     //Decode TensorFlow Example into a SQL Row
150 | val actualRow = DefaultTfRecordRowDecoder.decodeTfRecord(example, schema)
151 | actualRow should equal(expectedRow)
152 | }
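    // Note that encode and decode compose: decodeTfRecord(encodeTfRecord(row), schema)
    // returns an equivalent row for the supported types, which the Import/Export
    // test above verifies end-to-end through the file format.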
153 |
154 | "Check infer schema" in {
155 |
156 | //Build example1
157 | val intFeature1 = Int64List.newBuilder().addValue(1)
158 | val longFeature1 = Int64List.newBuilder().addValue(Int.MaxValue + 10L)
159 | val floatFeature1 = FloatList.newBuilder().addValue(10.0F)
160 | val doubleFeature1 = FloatList.newBuilder().addValue(14.0F)
161 | val vectorFeature1 = FloatList.newBuilder().addValue(1F).build()
162 | val strFeature1 = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()
163 | val features1 = Features.newBuilder()
164 | .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature1).build())
165 | .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature1).build())
166 | .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature1).build())
167 | .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature1).build())
168 | .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature1).build())
169 | .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature1).build())
170 | .build()
171 | val example1 = Example.newBuilder()
172 | .setFeatures(features1)
173 | .build()
174 |
175 | //Build example2
176 | val intFeature2 = Int64List.newBuilder().addValue(2)
177 | val longFeature2 = Int64List.newBuilder().addValue(24)
178 | val floatFeature2 = FloatList.newBuilder().addValue(12.0F)
179 | val doubleFeature2 = FloatList.newBuilder().addValue(Float.MaxValue + 15)
180 | val vectorFeature2 = FloatList.newBuilder().addValue(2F).addValue(2F).build()
181 | val strFeature2 = BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes)).build()
182 | val features2 = Features.newBuilder()
183 | .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature2).build())
184 | .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature2).build())
185 | .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature2).build())
186 | .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature2).build())
187 | .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature2).build())
188 | .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature2).build())
189 | .build()
190 | val example2 = Example.newBuilder()
191 | .setFeatures(features2)
192 | .build()
193 |
194 | val exampleRDD: RDD[Example] = spark.sparkContext.parallelize(List(example1, example2))
195 |
196 | val actualSchema = TensorflowInferSchema(exampleRDD)
197 |
198 |       //Verify each TensorFlow data type is inferred as the expected Spark SQL type
199 |       actualSchema.fields.foreach { column =>
200 |         column.name match {
201 |           case "IntegerTypelabel" => assert(column.dataType === IntegerType)
202 |           case "LongTypelabel" => assert(column.dataType === LongType)
203 |           case "FloatTypelabel" | "DoubleTypelabel" | "vectorlabel" => assert(column.dataType === FloatType)
204 |           case "strlabel" => assert(column.dataType === StringType)
205 |         }
206 |       }
207 | }
208 | }
209 | }
210 |
211 |
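The inference entry point exercised in the last test is the same mechanism the reader falls back on when no schema is supplied. A sketch of driving it directly (reusing the exampleRDD built above; the load path is hypothetical):

    import org.apache.spark.sql.types.StructType

    val inferredSchema: StructType = TensorflowInferSchema(exampleRDD)
    // Equivalent to letting the reader infer the schema itself:
    //   val df = spark.read.format("tensorflow").load("some/path.tfr")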
--------------------------------------------------------------------------------
/src/test/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/FeatureDecoderTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.trustedanalytics.spark.datasources.tensorflow.serde
17 |
18 | import org.scalatest.{Matchers, WordSpec}
19 | import org.tensorflow.example.{BytesList, FloatList, Feature, Int64List}
20 | import org.tensorflow.hadoop.shaded.protobuf.ByteString
21 |
22 | class FeatureDecoderTest extends WordSpec with Matchers {
23 |
24 | "Int Feature decoder" should {
25 |
26 | "Decode Feature to Int" in {
27 | val int64List = Int64List.newBuilder().addValue(4).build()
28 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
29 | IntFeatureDecoder.decode(intFeature) should equal(4)
30 | }
31 |
32 | "Throw an exception if length of feature array exceeds 1" in {
33 | intercept[Exception] {
34 | val int64List = Int64List.newBuilder().addValue(4).addValue(7).build()
35 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
36 | IntFeatureDecoder.decode(intFeature)
37 | }
38 | }
39 |
40 | "Throw an exception if feature is not an Int64List" in {
41 | intercept[Exception] {
42 | val floatList = FloatList.newBuilder().addValue(4).build()
43 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
44 | IntFeatureDecoder.decode(floatFeature)
45 | }
46 | }
47 | }
48 |
49 | "Int List Feature decoder" should {
50 |
51 | "Decode Feature to Int List" in {
52 | val int64List = Int64List.newBuilder().addValue(3).addValue(9).build()
53 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
54 | IntListFeatureDecoder.decode(intFeature) should equal(Seq(3,9))
55 | }
56 |
57 | "Throw an exception if feature is not an Int64List" in {
58 | intercept[Exception] {
59 | val floatList = FloatList.newBuilder().addValue(4).build()
60 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
61 | IntListFeatureDecoder.decode(floatFeature)
62 | }
63 | }
64 | }
65 |
66 | "Long Feature decoder" should {
67 |
68 | "Decode Feature to Long" in {
69 | val int64List = Int64List.newBuilder().addValue(5L).build()
70 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
71 | LongFeatureDecoder.decode(intFeature) should equal(5L)
72 | }
73 |
74 | "Throw an exception if length of feature array exceeds 1" in {
75 | intercept[Exception] {
76 | val int64List = Int64List.newBuilder().addValue(4L).addValue(10L).build()
77 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
78 | LongFeatureDecoder.decode(intFeature)
79 | }
80 | }
81 |
82 | "Throw an exception if feature is not an Int64List" in {
83 | intercept[Exception] {
84 | val floatList = FloatList.newBuilder().addValue(4).build()
85 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
86 | LongFeatureDecoder.decode(floatFeature)
87 | }
88 | }
89 | }
90 |
91 | "Long List Feature decoder" should {
92 |
93 | "Decode Feature to Long List" in {
94 | val int64List = Int64List.newBuilder().addValue(3L).addValue(Int.MaxValue+10L).build()
95 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
96 | LongListFeatureDecoder.decode(intFeature) should equal(Seq(3L,Int.MaxValue+10L))
97 | }
98 |
99 | "Throw an exception if feature is not an Int64List" in {
100 | intercept[Exception] {
101 | val floatList = FloatList.newBuilder().addValue(4).build()
102 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
103 | LongListFeatureDecoder.decode(floatFeature)
104 | }
105 | }
106 | }
107 |
108 | "Float Feature decoder" should {
109 |
110 | "Decode Feature to Float" in {
111 | val floatList = FloatList.newBuilder().addValue(2.5F).build()
112 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
113 | FloatFeatureDecoder.decode(floatFeature) should equal(2.5F)
114 | }
115 |
116 | "Throw an exception if length of feature array exceeds 1" in {
117 | intercept[Exception] {
118 | val floatList = FloatList.newBuilder().addValue(1.5F).addValue(3.33F).build()
119 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
120 | FloatFeatureDecoder.decode(floatFeature)
121 | }
122 | }
123 |
124 | "Throw an exception if feature is not a FloatList" in {
125 | intercept[Exception] {
126 | val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
127 | val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
128 | FloatFeatureDecoder.decode(bytesFeature)
129 | }
130 | }
131 | }
132 |
133 | "Float List Feature decoder" should {
134 |
135 | "Decode Feature to Float List" in {
136 | val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.3F).build()
137 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
138 | FloatListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5F, 4.3F))
139 | }
140 |
141 | "Throw an exception if feature is not a FloatList" in {
142 | intercept[Exception] {
143 | val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
144 | val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
145 | FloatListFeatureDecoder.decode(bytesFeature)
146 | }
147 | }
148 | }
149 |
150 | "Double Feature decoder" should {
151 |
152 | "Decode Feature to Double" in {
153 | val floatList = FloatList.newBuilder().addValue(2.5F).build()
154 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
155 | DoubleFeatureDecoder.decode(floatFeature) should equal(2.5d)
156 | }
157 |
158 | "Throw an exception if length of feature array exceeds 1" in {
159 | intercept[Exception] {
160 | val floatList = FloatList.newBuilder().addValue(1.5F).addValue(3.33F).build()
161 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
162 | DoubleFeatureDecoder.decode(floatFeature)
163 | }
164 | }
165 |
166 | "Throw an exception if feature is not a FloatList" in {
167 | intercept[Exception] {
168 | val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
169 | val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
170 | DoubleFeatureDecoder.decode(bytesFeature)
171 | }
172 | }
173 | }
174 |
175 | "Double List Feature decoder" should {
176 |
177 | "Decode Feature to Double List" in {
178 | val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build()
179 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
180 | DoubleListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5d, 4.0d))
181 | }
182 |
183 | "Throw an exception if feature is not a DoubleList" in {
184 | intercept[Exception] {
185 | val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
186 | val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
187 |         DoubleListFeatureDecoder.decode(bytesFeature)
188 | }
189 | }
190 | }
191 |
192 | "Bytes List Feature decoder" should {
193 |
194 | "Decode Feature to Bytes List" in {
195 | val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
196 | val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
197 | StringFeatureDecoder.decode(bytesFeature) should equal("str-input")
198 | }
199 |
200 | "Throw an exception if feature is not a BytesList" in {
201 | intercept[Exception] {
202 | val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build()
203 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
204 | StringFeatureDecoder.decode(floatFeature)
205 | }
206 | }
207 | }
208 | }
209 |
210 |
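Each decoder above handles exactly one Spark SQL type. A hypothetical helper that selects a decoder from a data type, mirroring the cases these tests cover (decoderFor is illustrative, not part of the library):

    import org.apache.spark.sql.types._
    import org.trustedanalytics.spark.datasources.tensorflow.serde._

    def decoderFor(dataType: DataType) = dataType match {
      case IntegerType               => IntFeatureDecoder
      case LongType                  => LongFeatureDecoder
      case FloatType                 => FloatFeatureDecoder
      case DoubleType                => DoubleFeatureDecoder
      case StringType                => StringFeatureDecoder
      case ArrayType(IntegerType, _) => IntListFeatureDecoder
      case ArrayType(LongType, _)    => LongListFeatureDecoder
      case ArrayType(FloatType, _)   => FloatListFeatureDecoder
      case ArrayType(DoubleType, _)  => DoubleListFeatureDecoder
      case other => throw new RuntimeException(s"Unsupported data type: $other")
    }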
--------------------------------------------------------------------------------
/src/test/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/FeatureEncoderTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) 2016 Intel Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.trustedanalytics.spark.datasources.tensorflow.serde
17 |
18 | import org.scalatest.{Matchers, WordSpec}
19 | import scala.collection.JavaConverters._
20 |
21 | class FeatureEncoderTest extends WordSpec with Matchers {
22 |
23 | "Int64List feature encoder" should {
24 | "Encode inputs to Int64List" in {
25 | val intFeature = Int64ListFeatureEncoder.encode(5)
26 | val longFeature = Int64ListFeatureEncoder.encode(10L)
27 | val longListFeature = Int64ListFeatureEncoder.encode(Seq(3L,5L,6L))
28 |
29 | intFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(5L))
30 | longFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(10L))
31 | longListFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(3L, 5L, 6L))
32 | }
33 |
34 | "Throw an exception when inputs contain null" in {
35 | intercept[Exception] {
36 | Int64ListFeatureEncoder.encode(null)
37 | }
38 | intercept[Exception] {
39 | Int64ListFeatureEncoder.encode(Seq(3,null,6))
40 | }
41 | }
42 |
43 | "Throw an exception for non-numeric inputs" in {
44 | intercept[Exception] {
45 | Int64ListFeatureEncoder.encode("bad-input")
46 | }
47 | }
48 | }
49 |
50 | "FloatList feature encoder" should {
51 | "Encode inputs to FloatList" in {
52 | val intFeature = FloatListFeatureEncoder.encode(5)
53 | val longFeature = FloatListFeatureEncoder.encode(10L)
54 | val floatFeature = FloatListFeatureEncoder.encode(2.5F)
55 | val doubleFeature = FloatListFeatureEncoder.encode(14.6)
56 | val floatListFeature = FloatListFeatureEncoder.encode(Seq(1.5F,6.8F,-3.2F))
57 |
58 | intFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(5F))
59 | longFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(10F))
60 | floatFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(2.5F))
61 | doubleFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(14.6F))
62 | floatListFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(1.5F,6.8F,-3.2F))
63 | }
64 |
65 | "Throw an exception when inputs contain null" in {
66 | intercept[Exception] {
67 | FloatListFeatureEncoder.encode(null)
68 | }
69 | intercept[Exception] {
70 | FloatListFeatureEncoder.encode(Seq(3,null,6))
71 | }
72 | }
73 |
74 | "Throw an exception for non-numeric inputs" in {
75 | intercept[Exception] {
76 | FloatListFeatureEncoder.encode("bad-input")
77 | }
78 | }
79 | }
80 |
81 | "ByteList feature encoder" should {
82 | "Encode inputs to ByteList" in {
83 | val longFeature = BytesListFeatureEncoder.encode(10L)
84 | val longListFeature = BytesListFeatureEncoder.encode(Seq(3L,5L,6L))
85 | val strFeature = BytesListFeatureEncoder.encode("str-input")
86 |
87 | longFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("10")
88 | longListFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("List(3, 5, 6)")
89 | strFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("str-input")
90 | }
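    // Note: BytesListFeatureEncoder serializes any input via toString, so a Seq
    // round-trips as its Scala rendering ("List(3, 5, 6)") rather than as
    // element-wise byte strings.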
91 |
92 | "Throw an exception when inputs contain null" in {
93 | intercept[Exception] {
94 | BytesListFeatureEncoder.encode(null)
95 | }
96 | }
97 | }
98 | }
99 |
--------------------------------------------------------------------------------