├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── build.sbt
├── pom.xml
├── project
│   ├── build.properties
│   └── plugins.sbt
└── src
    ├── main
    │   ├── resources
    │   │   └── META-INF
    │   │       └── services
    │   │           └── org.apache.spark.sql.sources.DataSourceRegister
    │   └── scala
    │       └── org
    │           └── trustedanalytics
    │               └── spark
    │                   └── datasources
    │                       └── tensorflow
    │                           ├── DataTypesConvertor.scala
    │                           ├── DefaultSource.scala
    │                           ├── TensorflowInferSchema.scala
    │                           ├── TensorflowRelation.scala
    │                           └── serde
    │                               ├── DefaultTfRecordRowDecoder.scala
    │                               ├── DefaultTfRecordRowEncoder.scala
    │                               ├── FeatureDecoder.scala
    │                               └── FeatureEncoder.scala
    └── test
        └── scala
            └── org
                └── trustedanalytics
                    └── spark
                        └── datasources
                            └── tensorflow
                                ├── SharedSparkSessionSuite.scala
                                ├── TensorflowSuite.scala
                                └── serde
                                    ├── FeatureDecoderTest.scala
                                    └── FeatureEncoderTest.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.iml
3 | *.jar
4 | *.log
5 | target
6 | tf-sandbox
7 | spark-warehouse/
8 | metastore_db/
9 | project/project/
10 | test-output.tfr
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: scala
2 |
3 | # Cache settings here are based on latest SBT documentation.
4 | cache:
5 |   directories:
6 |     - $HOME/.ivy2/cache
7 |     - $HOME/.sbt/boot/
8 |
9 | before_cache:
10 |   # Tricks to avoid unnecessary cache updates
11 |   - find $HOME/.ivy2 -name "ivydata-*.properties" -delete
12 |   - find $HOME/.sbt -name "*.lock" -delete
13 |
14 | scala:
15 |   - 2.11.8
16 |
17 | jdk:
18 |   - oraclejdk8
19 |
20 | script:
21 |   - sbt ++$TRAVIS_SCALA_VERSION clean publish-local
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 |                                  Apache License
3 |                            Version 2.0, January 2004
4 |                         http://www.apache.org/licenses/
5 |
6 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 |    1. Definitions.
9 |
10 |       "License" shall mean the terms and conditions for use, reproduction,
11 |       and distribution as defined by Sections 1 through 9 of this document.
12 |
13 |       "Licensor" shall mean the copyright owner or entity authorized by
14 |       the copyright owner that is granting the License.
15 |
16 |       "Legal Entity" shall mean the union of the acting entity and all
17 |       other entities that control, are controlled by, or are under common
18 |       control with that entity. For the purposes of this definition,
19 |       "control" means (i) the power, direct or indirect, to cause the
20 |       direction or management of such entity, whether by contract or
21 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 |       outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 |       "You" (or "Your") shall mean an individual or Legal Entity
25 |       exercising permissions granted by this License.
26 |
27 |       "Source" form shall mean the preferred form for making modifications,
28 |       including but not limited to software source code, documentation
29 |       source, and configuration files.
30 |
31 |       "Object" form shall mean any form resulting from mechanical
32 |       transformation or translation of a Source form, including but
33 |       not limited to compiled object code, generated documentation,
34 |       and conversions to other media types.
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/tapanalyticstoolkit/spark-tensorflow-connector.svg?branch=sbt)](https://travis-ci.org/tapanalyticstoolkit/spark-tensorflow-connector) 2 | 3 | # spark-tensorflow-connector 4 | 5 | __NOTE: This repo has been contributed to the TensorFlow ecosystem, and is no longer maintained here. 
Please go to [spark-tensorflow-connector](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector) in the TensorFlow ecosystem for the latest version.__
6 |
7 |
8 | This repo contains a library for loading and storing TensorFlow records with [Apache Spark](http://spark.apache.org/).
9 | The library implements data import from the standard TensorFlow record format ([TFRecords](https://www.tensorflow.org/how_tos/reading_data/)) into Spark SQL DataFrames,
10 | and data export from DataFrames to TensorFlow records.
11 |
12 | ## What's new
13 |
14 | This is the initial release of the `spark-tensorflow-connector` repo.
15 |
16 | ## Known issues
17 |
18 | None.
19 |
20 | ## Prerequisites
21 |
22 | 1. [Apache Spark 2.0 (or later)](http://spark.apache.org/)
23 |
24 | 2. [Apache Maven](https://maven.apache.org/)
25 |
26 | ## Building the library
27 | You can build the library with either the Maven or SBT build tool.
28 |
29 | #### Maven
30 | Build the library using Maven (3.3) as shown below:
31 |
32 | ```sh
33 | mvn clean install
34 | ```
35 |
36 | #### SBT
37 | Build the library using SBT (0.13.13) as shown below:
38 | ```sh
39 | sbt clean assembly
40 | ```
41 |
42 | ## Using Spark Shell
43 | Run this library in Spark using the `--jars` command line option in `spark-shell` or `spark-submit`. For example:
44 |
45 | Maven Jars
46 | ```sh
47 | $SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar
48 | ```
49 |
50 | SBT Jars
51 | ```sh
52 | $SPARK_HOME/bin/spark-shell --jars target/scala-2.11/spark-tensorflow-connector-assembly-1.0.0.jar
53 | ```
54 |
55 | The following code snippet demonstrates usage.
56 |
57 | ```scala
58 | import org.apache.commons.io.FileUtils
59 | import org.apache.spark.sql.{ DataFrame, Row }
60 | import org.apache.spark.sql.catalyst.expressions.GenericRow
61 | import org.apache.spark.sql.types._
62 |
63 | val path = "test-output.tfr"
64 | val testRows: Array[Row] = Array(
65 | new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
66 | new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
67 | val schema = StructType(List(StructField("id", IntegerType),
68 | StructField("IntegerTypelabel", IntegerType),
69 | StructField("LongTypelabel", LongType),
70 | StructField("FloatTypelabel", FloatType),
71 | StructField("DoubleTypelabel", DoubleType),
72 | StructField("vectorlabel", ArrayType(DoubleType, true)),
73 | StructField("name", StringType)))
74 |
75 | val rdd = spark.sparkContext.parallelize(testRows)
76 |
77 | //Save DataFrame as TFRecords
78 | val df: DataFrame = spark.createDataFrame(rdd, schema)
79 | df.write.format("tensorflow").save(path)
80 |
81 | //Read TFRecords into DataFrame.
82 | //The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
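//Note: DefaultTfRecordRowEncoder maps IntegerType/LongType columns to Int64List,
//FloatType/DoubleType to FloatList, and other types such as StringType to BytesList,
//so a schema-less round trip may read some columns back with different (inferred) types.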
83 | val importedDf1: DataFrame = spark.read.format("tensorflow").load(path)
84 | importedDf1.show()
85 |
86 | //Read TFRecords into DataFrame using custom schema
87 | val importedDf2: DataFrame = spark.read.format("tensorflow").schema(schema).load(path)
88 | importedDf2.show()
89 |
90 | ```
91 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | name := "spark-tensorflow-connector"
2 |
3 | organization := "org.trustedanalytics"
4 |
5 | scalaVersion in Global := "2.11.8"
6 |
7 | spName := "tapanalyticstoolkit/spark-tensorflow-connector"
8 |
9 | sparkVersion := "2.1.0"
10 |
11 | sparkComponents ++= Seq("sql", "mllib")
12 |
13 | version := "1.0.0"
14 |
15 | def ProjectName(name: String, path: String): Project = Project(name, file(path))
16 |
17 | resolvers in Global ++= Seq("https://tap.jfrog.io/tap/public" at "https://tap.jfrog.io/tap/public" ,
18 | "https://tap.jfrog.io/tap/public-snapshots" at "https://tap.jfrog.io/tap/public-snapshots" ,
19 | "https://repo.maven.apache.org/maven2" at "https://repo.maven.apache.org/maven2" )
20 |
21 | val `junit_junit` = "junit" % "junit" % "4.12"
22 |
23 | val `org.apache.hadoop_hadoop-yarn-api` = "org.apache.hadoop" % "hadoop-yarn-api" % "2.7.3"
24 |
25 | val `org.apache.spark_spark-core_2.11` = "org.apache.spark" % "spark-core_2.11" % "2.1.0"
26 |
27 | val `org.apache.spark_spark-sql_2.11` = "org.apache.spark" % "spark-sql_2.11" % "2.1.0"
28 |
29 | val `org.apache.spark_spark-mllib_2.11` = "org.apache.spark" % "spark-mllib_2.11" % "2.1.0"
30 |
31 | val `org.scalatest_scalatest_2.11` = "org.scalatest" % "scalatest_2.11" % "2.2.6"
32 |
33 | val `org.tensorflow_tensorflow-hadoop` = "org.tensorflow" % "tensorflow-hadoop" % "1.0-01232017-SNAPSHOT"
34 |
35 | libraryDependencies in Global ++= Seq(`org.tensorflow_tensorflow-hadoop` classifier "shaded-protobuf",
36 | `org.scalatest_scalatest_2.11` % "test" ,
37 | `org.apache.spark_spark-sql_2.11` % "provided" ,
38 | `org.apache.spark_spark-mllib_2.11` % "test" classifier "tests",
39 | `org.apache.spark_spark-core_2.11` % "provided" ,
40 | `org.apache.hadoop_hadoop-yarn-api` % "provided" ,
41 | `junit_junit` % "test" )
42 |
43 | assemblyExcludedJars in assembly := {
44 | val cp = (fullClasspath in assembly).value
45 | cp filterNot {x => List("spark-tensorflow-connector-1.0-SNAPSHOT.jar",
46 | "tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar").contains(x.data.getName)}
47 | }
48 |
49 | /********************
50 |  * Release settings *
51 |  ********************/
52 |
53 | spIgnoreProvided := true
54 |
55 | spAppendScalaVersion := true
56 |
57 | // If you published your package to Maven Central for this release (must be done prior to spPublish)
58 | spIncludeMaven := false
59 |
60 | publishMavenStyle := true
61 |
62 | licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))
63 |
64 | pomExtra :=
65 |   <url>https://github.com/tapanalyticstoolkit/spark-tensorflow-connector</url>
66 |   <scm>
67 |     <url>git@github.com:tapanalyticstoolkit/spark-tensorflow-connector.git</url>
68 |     <connection>scm:git:git@github.com:tapanalyticstoolkit/spark-tensorflow-connector.git</connection>
69 |   </scm>
70 |   <developers>
71 |     <developer>
72 |       <id>karthikvadla</id>
73 |       <name>Karthik Vadla</name>
74 |       <url>https://github.com/karthikvadla</url>
75 |     </developer>
76 |     <developer>
77 |       <id>skavulya</id>
78 |       <name>Soila Kavulya</name>
79 |       <url>https://github.com/skavulya</url>
80 |     </developer>
81 |     <developer>
82 |       <id>joyeshmishra</id>
83 |       <name>Joyesh Mishra</name>
84 |       <url>https://github.com/joyeshmishra</url>
85 |     </developer>
86 |   </developers>
87 |
88 | credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") // A file
containing credentials 89 | 90 | // Add assembly jar to Spark package 91 | test in assembly := {} 92 | 93 | spShade := true 94 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | org.trustedanalytics 8 | spark-tensorflow-connector 9 | jar 10 | 1.0-SNAPSHOT 11 | 12 | 13 | 14 | central1 15 | http://central1.maven.org/maven2 16 | 17 | true 18 | 19 | 20 | false 21 | 22 | 23 | 24 | 25 | tap 26 | https://tap.jfrog.io/tap/public 27 | 28 | false 29 | 30 | 31 | true 32 | 33 | 34 | 35 | tap-snapshots 36 | https://tap.jfrog.io/tap/public-snapshots 37 | 38 | true 39 | 40 | 41 | false 42 | 43 | 44 | 45 | 46 | 47 | 48 | compile 49 | 50 | true 51 | 52 | !NEVERSETME 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | true 61 | net.alchim31.maven 62 | scala-maven-plugin 63 | 3.1.6 64 | 65 | 66 | compile 67 | 68 | add-source 69 | compile 70 | 71 | 72 | 73 | -Xms256m 74 | -Xmx512m 75 | 76 | 77 | -g:vars 78 | -deprecation 79 | -feature 80 | -unchecked 81 | -Xfatal-warnings 82 | -language:implicitConversions 83 | -language:existentials 84 | 85 | 86 | 87 | 88 | test 89 | 90 | add-source 91 | testCompile 92 | 93 | 94 | 95 | 96 | incremental 97 | true 98 | 2.11 99 | false 100 | 101 | 102 | 103 | org.apache.maven.plugins 104 | maven-dependency-plugin 105 | 106 | 107 | copy-dependencies 108 | process-resources 109 | 110 | copy-dependencies 111 | 112 | 113 | provided 114 | true 115 | org.apache.spark,junit,org.scalatest 116 | ${project.build.directory}/lib 117 | 118 | 119 | 120 | 121 | 122 | 123 | org.codehaus.mojo 124 | properties-maven-plugin 125 | 1.0.0 126 | 127 | 128 | generate-resources 129 | 130 | write-project-properties 131 | 132 | 133 | ${project.build.outputDirectory}/maven.properties 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | net.alchim31.maven 143 | scala-maven-plugin 144 | 145 | 146 | 147 | 148 | 149 | 150 | test 151 | 152 | true 153 | 154 | !NEVERSETME 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | true 163 | net.alchim31.maven 164 | scala-maven-plugin 165 | 3.2.2 166 | 167 | 168 | compile 169 | 170 | 171 | 172 | 173 | true 174 | org.scalatest 175 | scalatest-maven-plugin 176 | 1.0 177 | 178 | ${project.build.directory}/surefire-reports 179 | . 
180 | WDF TestSuite.txt 181 | false 182 | FTD 183 | -Xmx1024m -XX:PermSize=256m -XX:MaxDirectMemorySize=1000m 184 | 185 | 186 | 187 | scalaTest 188 | test 189 | 190 | test 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | net.alchim31.maven 201 | scala-maven-plugin 202 | 203 | 204 | 205 | 206 | 207 | 208 | org.scalatest 209 | scalatest_2.11 210 | 2.2.6 211 | test 212 | 213 | 214 | 215 | 216 | 217 | 218 | org.scalatest 219 | scalatest_2.11 220 | test 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | net.alchim31.maven 231 | scala-maven-plugin 232 | 233 | 234 | org.apache.maven.plugins 235 | maven-dependency-plugin 236 | 237 | 238 | org.scalatest 239 | scalatest-maven-plugin 240 | 241 | 242 | org.apache.maven.plugins 243 | maven-compiler-plugin 244 | 3.0 245 | 246 | 1.8 247 | 1.8 248 | 249 | 250 | 251 | 252 | 253 | 254 | core/src/main/resources 255 | 256 | reference.conf 257 | 258 | 259 | 260 | core/src/test/resources 261 | 262 | 263 | 264 | 265 | 266 | 267 | org.tensorflow 268 | tensorflow-hadoop 269 | 1.0-01232017-SNAPSHOT 270 | shaded-protobuf 271 | 272 | 273 | org.apache.spark 274 | spark-core_2.11 275 | 2.1.0 276 | provided 277 | 278 | 279 | org.apache.spark 280 | spark-sql_2.11 281 | 2.1.0 282 | provided 283 | 284 | 285 | org.apache.hadoop 286 | hadoop-yarn-api 287 | 2.7.3 288 | provided 289 | 290 | 291 | 292 | org.apache.spark 293 | spark-mllib_2.11 294 | 2.1.0 295 | test-jar 296 | test 297 | 298 | 299 | junit 300 | junit 301 | 4.12 302 | test 303 | 304 | 305 | 306 | 307 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.13 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/" 2 | 3 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3") 4 | 5 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.5") 6 | -------------------------------------------------------------------------------- /src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: -------------------------------------------------------------------------------- 1 | org.trustedanalytics.spark.datasources.tensorflow.DefaultSource -------------------------------------------------------------------------------- /src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/DataTypesConvertor.scala: -------------------------------------------------------------------------------- 1 | package org.trustedanalytics.spark.datasources.tensorflow 2 | 3 | /** 4 | * DataTypes supported 5 | */ 6 | object DataTypesConvertor { 7 | 8 | def toLong(value: Any): Long = { 9 | value match { 10 | case null => throw new IllegalArgumentException("null cannot be converted to Long") 11 | case i: Int => i.toLong 12 | case l: Long => l 13 | case f: Float => f.toLong 14 | case d: Double => d.toLong 15 | case bd: BigDecimal => bd.toLong 16 | case s: String => s.trim().toLong 17 | case _ => throw new RuntimeException(s"${value.getClass.getName} toLong is not implemented") 18 | } 19 | } 20 | 21 | def toFloat(value: Any): Float = { 22 | value match { 23 | case null => throw new IllegalArgumentException("null cannot be converted to Float") 24 | case i: Int => i.toFloat 25 | case l: Long => l.toFloat 26 | case f: Float => 
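// Float input is returned as-is; the remaining cases below convert explicitly.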
f 27 | case d: Double => d.toFloat 28 | case bd: BigDecimal => bd.toFloat 29 | case s: String => s.trim().toFloat 30 | case _ => throw new RuntimeException(s"${value.getClass.getName} toFloat is not implemented") 31 | } 32 | } 33 | } 34 | 35 | -------------------------------------------------------------------------------- /src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 Intel Corporation  3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.trustedanalytics.spark.datasources.tensorflow 17 | 18 | import org.apache.hadoop.io.{BytesWritable, NullWritable} 19 | import org.apache.spark.sql._ 20 | import org.apache.spark.sql.sources._ 21 | import org.apache.spark.sql.types.StructType 22 | import org.tensorflow.hadoop.io.TFRecordFileOutputFormat 23 | import org.trustedanalytics.spark.datasources.tensorflow.serde.DefaultTfRecordRowEncoder 24 | 25 | /** 26 | * Provides access to TensorFlow record source 27 | */ 28 | class DefaultSource extends DataSourceRegister 29 | with CreatableRelationProvider 30 | with RelationProvider 31 | with SchemaRelationProvider{ 32 | 33 | /** 34 | * Short alias for spark-tensorflow data source. 35 | */ 36 | override def shortName(): String = "tensorflow" 37 | 38 | // Writes DataFrame as TensorFlow Records 39 | override def createRelation( 40 | sqlContext: SQLContext, 41 | mode: SaveMode, 42 | parameters: Map[String, String], 43 | data: DataFrame): BaseRelation = { 44 | 45 | val path = parameters("path") 46 | 47 | //Export DataFrame as TFRecords 48 | val features = data.rdd.map(row => { 49 | val example = DefaultTfRecordRowEncoder.encodeTfRecord(row) 50 | (new BytesWritable(example.toByteArray), NullWritable.get()) 51 | }) 52 | features.saveAsNewAPIHadoopFile[TFRecordFileOutputFormat](path) 53 | 54 | TensorflowRelation(parameters)(sqlContext.sparkSession) 55 | } 56 | 57 | override def createRelation(sqlContext: SQLContext, 58 | parameters: Map[String, String], 59 | schema: StructType): BaseRelation = { 60 | TensorflowRelation(parameters, Some(schema))(sqlContext.sparkSession) 61 | } 62 | 63 | // Reads TensorFlow Records into DataFrame 64 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): TensorflowRelation = { 65 | TensorflowRelation(parameters)(sqlContext.sparkSession) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/TensorflowInferSchema.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 Intel Corporation  3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.trustedanalytics.spark.datasources.tensorflow 17 | 18 | import org.apache.spark.rdd.RDD 19 | import org.apache.spark.sql.types._ 20 | import org.tensorflow.example.{Example, Feature} 21 | import scala.collection.mutable.Map 22 | import scala.util.control.Exception._ 23 | import scala.collection.JavaConverters._ 24 | 25 | object TensorflowInferSchema { 26 | 27 | /** 28 | * Similar to the JSON schema inference. 29 | * [[org.apache.spark.sql.execution.datasources.json.InferSchema]] 30 | * 1. Infer type of each row 31 | * 2. Merge row types to find common type 32 | * 3. Replace any null types with string type 33 | */ 34 | def apply(exampleRdd: RDD[Example]): StructType = { 35 | val startType: Map[String, DataType] = Map.empty[String, DataType] 36 | val rootTypes: Map[String, DataType] = exampleRdd.aggregate(startType)(inferRowType, mergeFieldTypes) 37 | val columnsList = rootTypes.map { 38 | case (featureName, featureType) => 39 | if (featureType == null) { 40 | StructField(featureName, StringType) 41 | } 42 | else { 43 | StructField(featureName, featureType) 44 | } 45 | } 46 | StructType(columnsList.toSeq) 47 | } 48 | 49 | private def inferRowType(schemaSoFar: Map[String, DataType], next: Example): Map[String, DataType] = { 50 | next.getFeatures.getFeatureMap.asScala.map { 51 | case (featureName, feature) => { 52 | val currentType = inferField(feature) 53 | if (schemaSoFar.contains(featureName)) { 54 | val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType) 55 | schemaSoFar(featureName) = updatedType.getOrElse(null) 56 | } 57 | else { 58 | schemaSoFar += (featureName -> currentType) 59 | } 60 | } 61 | } 62 | schemaSoFar 63 | } 64 | 65 | private def mergeFieldTypes(first: Map[String, DataType], second: Map[String, DataType]): Map[String, DataType] = { 66 | //Merge two maps and do the comparison. 
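//A key present in only one of the two maps pairs with null on the other side;
//findTightestCommonType(null, t) and findTightestCommonType(t, null) both return
//the non-null type, so such fields keep their existing type through the merge.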
67 | val mutMap = collection.mutable.Map[String, DataType]((first.keySet ++ second.keySet) 68 | .map(key => (key, findTightestCommonType(first.getOrElse(key, null), second.getOrElse(key, null)).get)) 69 | .toSeq: _*) 70 | mutMap 71 | } 72 | 73 | /** 74 | * Infer Feature datatype based on field number 75 | */ 76 | private def inferField(feature: Feature): DataType = { 77 | feature.getKindCase.getNumber match { 78 | case Feature.BYTES_LIST_FIELD_NUMBER => { 79 | StringType 80 | } 81 | case Feature.INT64_LIST_FIELD_NUMBER => { 82 | parseInt64List(feature) 83 | } 84 | case Feature.FLOAT_LIST_FIELD_NUMBER => { 85 | parseFloatList(feature) 86 | } 87 | case _ => throw new RuntimeException("unsupported type ...") 88 | } 89 | } 90 | 91 | private def parseInt64List(feature: Feature): DataType = { 92 | val int64List = feature.getInt64List.getValueList.asScala.toArray 93 | val length = int64List.size 94 | if (length == 0) { 95 | null 96 | } 97 | else if (length > 1) { 98 | ArrayType(LongType) 99 | } 100 | else { 101 | val fieldValue = int64List(0).toString 102 | parseInteger(fieldValue) 103 | } 104 | } 105 | 106 | private def parseFloatList(feature: Feature): DataType = { 107 | val floatList = feature.getFloatList.getValueList.asScala.toArray 108 | val length = floatList.size 109 | if (length == 0) { 110 | null 111 | } 112 | else if (length > 1) { 113 | ArrayType(DoubleType) 114 | } 115 | else { 116 | val fieldValue = floatList(0).toString 117 | parseFloat(fieldValue) 118 | } 119 | } 120 | 121 | private def parseInteger(field: String): DataType = if (allCatch.opt(field.toInt).isDefined) { 122 | IntegerType 123 | } 124 | else { 125 | parseLong(field) 126 | } 127 | 128 | private def parseLong(field: String): DataType = if (allCatch.opt(field.toLong).isDefined) { 129 | LongType 130 | } 131 | else { 132 | throw new RuntimeException("Unable to parse field datatype to int64...") 133 | } 134 | 135 | private def parseFloat(field: String): DataType = { 136 | if ((allCatch opt field.toFloat).isDefined) { 137 | FloatType 138 | } 139 | else { 140 | parseDouble(field) 141 | } 142 | } 143 | 144 | private def parseDouble(field: String): DataType = if (allCatch.opt(field.toDouble).isDefined) { 145 | DoubleType 146 | } 147 | else { 148 | throw new RuntimeException("Unable to parse field datatype to float64...") 149 | } 150 | /** 151 | * Copied from internal Spark api 152 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] 153 | */ 154 | private val numericPrecedence: IndexedSeq[DataType] = 155 | IndexedSeq[DataType](IntegerType, 156 | LongType, 157 | FloatType, 158 | DoubleType, 159 | StringType) 160 | 161 | private def getNumericPrecedence(dataType: DataType): Int = { 162 | dataType match { 163 | case x if x.equals(IntegerType) => 0 164 | case x if x.equals(LongType) => 1 165 | case x if x.equals(FloatType) => 2 166 | case x if x.equals(DoubleType) => 3 167 | case x if x.equals(ArrayType(LongType)) => 4 168 | case x if x.equals(ArrayType(DoubleType)) => 5 169 | case x if x.equals(StringType) => 6 170 | case _ => throw new RuntimeException("Unable to get the precedence for given datatype...") 171 | } 172 | } 173 | 174 | /** 175 | * Copied from internal Spark api 176 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] 177 | */ 178 | private val findTightestCommonType: (DataType, DataType) => Option[DataType] = { 179 | case (t1, t2) if t1 == t2 => Some(t1) 180 | case (null, t2) => Some(t2) 181 | case (t1, null) => Some(t1) 182 | case (t1, t2) if t1.equals(ArrayType(LongType)) && 
t2.equals(ArrayType(DoubleType)) => Some(ArrayType(DoubleType)) 183 | case (t1, t2) if t1.equals(ArrayType(DoubleType)) && t2.equals(ArrayType(LongType)) => Some(ArrayType(DoubleType)) 184 | case (StringType, t2) => Some(StringType) 185 | case (t1, StringType) => Some(StringType) 186 | 187 | // Promote numeric types to the highest of the two and all numeric types to unlimited decimal 188 | case (t1, t2) => 189 | val t1Precedence = getNumericPrecedence(t1) 190 | val t2Precedence = getNumericPrecedence(t2) 191 | val newType = if (t1Precedence > t2Precedence) t1 else t2 192 | Some(newType) 193 | case _ => None 194 | } 195 | } 196 | 197 | -------------------------------------------------------------------------------- /src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/TensorflowRelation.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 Intel Corporation  3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.trustedanalytics.spark.datasources.tensorflow 17 | 18 | import org.apache.hadoop.io.{BytesWritable, NullWritable} 19 | import org.apache.spark.rdd.RDD 20 | import org.apache.spark.sql.sources.{BaseRelation, TableScan} 21 | import org.apache.spark.sql.types.StructType 22 | import org.apache.spark.sql.{Row, SQLContext, SparkSession} 23 | import org.tensorflow.example.Example 24 | import org.tensorflow.hadoop.io.TFRecordFileInputFormat 25 | import org.trustedanalytics.spark.datasources.tensorflow.serde.DefaultTfRecordRowDecoder 26 | 27 | 28 | case class TensorflowRelation(options: Map[String, String], customSchema: Option[StructType]=None)(@transient val session: SparkSession) extends BaseRelation with TableScan { 29 | 30 | //Import TFRecords as DataFrame happens here 31 | lazy val (tf_rdd, tf_schema) = { 32 | val rdd = session.sparkContext.newAPIHadoopFile(options("path"), classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable]) 33 | 34 | val exampleRdd = rdd.map { 35 | case (bytesWritable, nullWritable) => Example.parseFrom(bytesWritable.getBytes) 36 | } 37 | 38 | val finalSchema = customSchema.getOrElse(TensorflowInferSchema(exampleRdd)) 39 | 40 | (exampleRdd.map(example => DefaultTfRecordRowDecoder.decodeTfRecord(example, finalSchema)), finalSchema) 41 | } 42 | 43 | override def sqlContext: SQLContext = session.sqlContext 44 | 45 | override def schema: StructType = tf_schema 46 | 47 | override def buildScan(): RDD[Row] = tf_rdd 48 | } 49 | 50 | -------------------------------------------------------------------------------- /src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/DefaultTfRecordRowDecoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 Intel Corporation  3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.trustedanalytics.spark.datasources.tensorflow.serde 17 | 18 | import org.apache.spark.sql.types._ 19 | import org.apache.spark.sql.Row 20 | import org.tensorflow.example._ 21 | import scala.collection.JavaConverters._ 22 | 23 | trait TfRecordRowDecoder { 24 | /** 25 | * Decodes each TensorFlow "Example" as DataFrame "Row" 26 | * 27 | * Maps each feature in Example to element in Row with DataType based on custom schema or 28 | * default mapping of Int64List, FloatList, BytesList to column data type 29 | * 30 | * @param example TensorFlow Example to decode 31 | * @param schema Decode Example using specified schema 32 | * @return a DataFrame row 33 | */ 34 | def decodeTfRecord(example: Example, schema: StructType): Row 35 | } 36 | 37 | object DefaultTfRecordRowDecoder extends TfRecordRowDecoder { 38 | 39 | /** 40 | * Decodes each TensorFlow "Example" as DataFrame "Row" 41 | * 42 | * Maps each feature in Example to element in Row with DataType based on custom schema 43 | * 44 | * @param example TensorFlow Example to decode 45 | * @param schema Decode Example using specified schema 46 | * @return a DataFrame row 47 | */ 48 | def decodeTfRecord(example: Example, schema: StructType): Row = { 49 | val row = Array.fill[Any](schema.length)(null) 50 | example.getFeatures.getFeatureMap.asScala.foreach { 51 | case (featureName, feature) => 52 | val index = schema.fieldIndex(featureName) 53 | val colDataType = schema.fields(index).dataType 54 | row(index) = colDataType match { 55 | case IntegerType => IntFeatureDecoder.decode(feature) 56 | case LongType => LongFeatureDecoder.decode(feature) 57 | case FloatType => FloatFeatureDecoder.decode(feature) 58 | case DoubleType => DoubleFeatureDecoder.decode(feature) 59 | case ArrayType(IntegerType, true) => IntListFeatureDecoder.decode(feature) 60 | case ArrayType(LongType, _) => LongListFeatureDecoder.decode(feature) 61 | case ArrayType(FloatType, _) => FloatListFeatureDecoder.decode(feature) 62 | case ArrayType(DoubleType, _) => DoubleListFeatureDecoder.decode(feature) 63 | case StringType => StringFeatureDecoder.decode(feature) 64 | case _ => throw new RuntimeException(s"Cannot convert feature to unsupported data type ${colDataType}") 65 | } 66 | } 67 | Row.fromSeq(row) 68 | } 69 | } 70 | 71 | -------------------------------------------------------------------------------- /src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/DefaultTfRecordRowEncoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 Intel Corporation  3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.trustedanalytics.spark.datasources.tensorflow.serde 17 | 18 | import org.apache.spark.sql.Row 19 | import org.apache.spark.sql.types._ 20 | import org.tensorflow.example._ 21 | 22 | trait TfRecordRowEncoder { 23 | /** 24 | * Encodes each Row as TensorFlow "Example" 25 | * 26 | * Maps each column in Row to one of Int64List, FloatList, BytesList based on the column data type 27 | * 28 | * @param row a DataFrame row 29 | * @return TensorFlow Example 30 | */ 31 | def encodeTfRecord(row: Row): Example 32 | } 33 | 34 | object DefaultTfRecordRowEncoder extends TfRecordRowEncoder { 35 | 36 | /** 37 | * Encodes each Row as TensorFlow "Example" 38 | * 39 | * Maps each column in Row to one of Int64List, FloatList, BytesList based on the column data type 40 | * 41 | * @param row a DataFrame row 42 | * @return TensorFlow Example 43 | */ 44 | def encodeTfRecord(row: Row): Example = { 45 | val features = Features.newBuilder() 46 | val example = Example.newBuilder() 47 | 48 | row.schema.zipWithIndex.map { 49 | case (structField, index) => 50 | val value = row.get(index) 51 | val feature = structField.dataType match { 52 | case IntegerType | LongType => Int64ListFeatureEncoder.encode(value) 53 | case FloatType | DoubleType => FloatListFeatureEncoder.encode(value) 54 | case ArrayType(IntegerType, _) | ArrayType(LongType, _) => Int64ListFeatureEncoder.encode(value) 55 | case ArrayType(DoubleType, _) => FloatListFeatureEncoder.encode(value) 56 | case _ => BytesListFeatureEncoder.encode(value) 57 | } 58 | features.putFeature(structField.name, feature) 59 | } 60 | 61 | features.build() 62 | example.setFeatures(features) 63 | example.build() 64 | } 65 | } 66 | 67 | -------------------------------------------------------------------------------- /src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/FeatureDecoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 Intel Corporation  3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package org.trustedanalytics.spark.datasources.tensorflow.serde 17 | 18 | import org.tensorflow.example.Feature 19 | import scala.collection.JavaConverters._ 20 | 21 | trait FeatureDecoder[T] { 22 | /** 23 | * Decodes each TensorFlow "Feature" to desired Scala type 24 | * 25 | * @param feature TensorFlow Feature 26 | * @return Decoded feature 27 | */ 28 | def decode(feature: Feature): T 29 | } 30 | 31 | /** 32 | * Decode TensorFlow "Feature" to Integer 33 | */ 34 | object IntFeatureDecoder extends FeatureDecoder[Int] { 35 | override def decode(feature: Feature): Int = { 36 | require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") 37 | try { 38 | val int64List = feature.getInt64List.getValueList 39 | require(int64List.size() == 1, "Length of Int64List must equal 1") 40 | int64List.get(0).intValue() 41 | } 42 | catch { 43 | case ex: Exception => 44 | throw new RuntimeException(s"Cannot convert feature to Int.", ex) 45 | } 46 | } 47 | } 48 | 49 | /** 50 | * Decode TensorFlow "Feature" to Seq[Int] 51 | */ 52 | object IntListFeatureDecoder extends FeatureDecoder[Seq[Int]] { 53 | override def decode(feature: Feature): Seq[Int] = { 54 | require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") 55 | try { 56 | val array = feature.getInt64List.getValueList.asScala.toArray 57 | array.map(_.toInt) 58 | } 59 | catch { 60 | case ex: Exception => 61 | throw new RuntimeException(s"Cannot convert feature to Seq[Int].", ex) 62 | } 63 | } 64 | } 65 | 66 | /** 67 | * Decode TensorFlow "Feature" to Long 68 | */ 69 | object LongFeatureDecoder extends FeatureDecoder[Long] { 70 | override def decode(feature: Feature): Long = { 71 | require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") 72 | try { 73 | val int64List = feature.getInt64List.getValueList 74 | require(int64List.size() == 1, "Length of Int64List must equal 1") 75 | int64List.get(0).longValue() 76 | } 77 | catch { 78 | case ex: Exception => 79 | throw new RuntimeException(s"Cannot convert feature to Long.", ex) 80 | } 81 | } 82 | } 83 | 84 | /** 85 | * Decode TensorFlow "Feature" to Seq[Long] 86 | */ 87 | object LongListFeatureDecoder extends FeatureDecoder[Seq[Long]] { 88 | override def decode(feature: Feature): Seq[Long] = { 89 | require(feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") 90 | try { 91 | val array = feature.getInt64List.getValueList.asScala.toArray 92 | array.map(_.toLong) 93 | } 94 | catch { 95 | case ex: Exception => 96 | throw new RuntimeException(s"Cannot convert feature to Array[Long].", ex) 97 | } 98 | } 99 | } 100 | 101 | /** 102 | * Decode TensorFlow "Feature" to Float 103 | */ 104 | object FloatFeatureDecoder extends FeatureDecoder[Float] { 105 | override def decode(feature: Feature): Float = { 106 | require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") 107 | try { 108 | val floatList = feature.getFloatList.getValueList 109 | require(floatList.size() == 1, "Length of FloatList must equal 1") 110 | floatList.get(0).floatValue() 111 | } 112 | catch { 113 | case ex: Exception => 114 | throw new RuntimeException(s"Cannot convert feature to Float.", ex) 115 | } 116 | } 117 | } 118 | 119 | /** 120 | * Decode TensorFlow "Feature" to Seq[Float] 121 | */ 122 | object FloatListFeatureDecoder extends FeatureDecoder[Seq[Float]] { 123 | override def 
decode(feature: Feature): Seq[Float] = { 124 | require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") 125 | try { 126 | val array = feature.getFloatList.getValueList.asScala.toArray 127 | array.map(_.toFloat) 128 | } 129 | catch { 130 | case ex: Exception => 131 | throw new RuntimeException(s"Cannot convert feature to Array[Float].", ex) 132 | } 133 | } 134 | } 135 | 136 | /** 137 | * Decode TensorFlow "Feature" to Double 138 | */ 139 | object DoubleFeatureDecoder extends FeatureDecoder[Double] { 140 | override def decode(feature: Feature): Double = { 141 | require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") 142 | try { 143 | val floatList = feature.getFloatList.getValueList 144 | require(floatList.size() == 1, "Length of FloatList must equal 1") 145 | floatList.get(0).doubleValue() 146 | } 147 | catch { 148 | case ex: Exception => 149 | throw new RuntimeException(s"Cannot convert feature to Double.", ex) 150 | } 151 | } 152 | } 153 | 154 | /** 155 | * Decode TensorFlow "Feature" to Seq[Double] 156 | */ 157 | object DoubleListFeatureDecoder extends FeatureDecoder[Seq[Double]] { 158 | override def decode(feature: Feature): Seq[Double] = { 159 | require(feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") 160 | try { 161 | val array = feature.getFloatList.getValueList.asScala.toArray 162 | array.map(_.toDouble) 163 | } 164 | catch { 165 | case ex: Exception => 166 | throw new RuntimeException(s"Cannot convert feature to Array[Double].", ex) 167 | } 168 | } 169 | } 170 | 171 | /** 172 | * Decode TensorFlow "Feature" to String 173 | */ 174 | object StringFeatureDecoder extends FeatureDecoder[String] { 175 | override def decode(feature: Feature): String = { 176 | require(feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList") 177 | try { 178 | feature.getBytesList.toByteString.toStringUtf8.trim 179 | } 180 | catch { 181 | case ex: Exception => 182 | throw new RuntimeException(s"Cannot convert feature to String.", ex) 183 | } 184 | } 185 | } 186 | 187 | -------------------------------------------------------------------------------- /src/main/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/FeatureEncoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 Intel Corporation  3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package org.trustedanalytics.spark.datasources.tensorflow.serde 17 | 18 | import org.tensorflow.example.{BytesList, Feature, FloatList, Int64List} 19 | import org.tensorflow.hadoop.shaded.protobuf.ByteString 20 | import org.trustedanalytics.spark.datasources.tensorflow.DataTypesConvertor 21 | 22 | trait FeatureEncoder { 23 | /** 24 | * Encodes input value as TensorFlow "Feature" 25 | * 26 | * Maps input value to one of Int64List, FloatList, BytesList 27 | * 28 | * @param value Input value 29 | * @return TensorFlow Feature 30 | */ 31 | def encode(value: Any): Feature 32 | } 33 | 34 | /** 35 | * Encode input value to Int64List 36 | */ 37 | object Int64ListFeatureEncoder extends FeatureEncoder { 38 | override def encode(value: Any): Feature = { 39 | try { 40 | val int64List = value match { 41 | case i: Int => Int64List.newBuilder().addValue(i.toLong).build() 42 | case l: Long => Int64List.newBuilder().addValue(l).build() 43 | case arr: scala.collection.mutable.WrappedArray[_] => toInt64List(arr.toArray[Any]) 44 | case arr: Array[_] => toInt64List(arr) 45 | case seq: Seq[_] => toInt64List(seq.toArray[Any]) 46 | case _ => throw new RuntimeException(s"Cannot convert object $value to Int64List") 47 | } 48 | Feature.newBuilder().setInt64List(int64List).build() 49 | } 50 | catch { 51 | case ex: Exception => 52 | throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to Int64List feature.", ex) 53 | } 54 | } 55 | 56 | private def toInt64List[T](arr: Array[T]): Int64List = { 57 | val intListBuilder = Int64List.newBuilder() 58 | arr.foreach(x => { 59 | require(x != null, "Int64List with null values is not supported") 60 | val longValue = DataTypesConvertor.toLong(x) 61 | intListBuilder.addValue(longValue) 62 | }) 63 | intListBuilder.build() 64 | } 65 | } 66 | 67 | /** 68 | * Encode input value to FloatList 69 | */ 70 | object FloatListFeatureEncoder extends FeatureEncoder { 71 | override def encode(value: Any): Feature = { 72 | try { 73 | val floatList = value match { 74 | case i: Int => FloatList.newBuilder().addValue(i.toFloat).build() 75 | case l: Long => FloatList.newBuilder().addValue(l.toFloat).build() 76 | case f: Float => FloatList.newBuilder().addValue(f).build() 77 | case d: Double => FloatList.newBuilder().addValue(d.toFloat).build() 78 | case arr: scala.collection.mutable.WrappedArray[_] => toFloatList(arr.toArray[Any]) 79 | case arr: Array[_] => toFloatList(arr) 80 | case seq: Seq[_] => toFloatList(seq.toArray[Any]) 81 | case _ => throw new RuntimeException(s"Cannot convert object $value to FloatList") 82 | } 83 | Feature.newBuilder().setFloatList(floatList).build() 84 | } 85 | catch { 86 | case ex: Exception => 87 | throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to FloatList feature.", ex) 88 | } 89 | } 90 | 91 | private def toFloatList[T](arr: Array[T]): FloatList = { 92 | val floatListBuilder = FloatList.newBuilder() 93 | arr.foreach(x => { 94 | require(x != null, "FloatList with null values is not supported") 95 | val longValue = DataTypesConvertor.toFloat(x) 96 | floatListBuilder.addValue(longValue) 97 | }) 98 | floatListBuilder.build() 99 | } 100 | } 101 | 102 | /** 103 | * Encode input value to ByteList 104 | */ 105 | object BytesListFeatureEncoder extends FeatureEncoder { 106 | override def encode(value: Any): Feature = { 107 | try { 108 | val byteList = BytesList.newBuilder().addValue(ByteString.copyFrom(value.toString.getBytes)).build() 109 | Feature.newBuilder().setBytesList(byteList).build() 110 | } 
111 | catch { 112 | case ex: Exception => 113 | throw new RuntimeException(s"Cannot convert object $value of type ${value.getClass} to ByteList feature.", ex) 114 | } 115 | } 116 | } 117 | 118 | 119 | -------------------------------------------------------------------------------- /src/test/scala/org/trustedanalytics/spark/datasources/tensorflow/SharedSparkSessionSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 Intel Corporation  3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package org.trustedanalytics.spark.datasources.tensorflow 18 | 19 | import java.io.File 20 | 21 | import org.apache.commons.io.FileUtils 22 | import org.apache.spark.SharedSparkSession 23 | import org.junit.{After, Before} 24 | import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} 25 | 26 | 27 | trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll 28 | 29 | class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite { 30 | val TF_SANDBOX_DIR = "tf-sandbox" 31 | val file = new File(TF_SANDBOX_DIR) 32 | 33 | @Before 34 | override def beforeAll() = { 35 | super.setUp() 36 | FileUtils.deleteQuietly(file) 37 | file.mkdirs() 38 | } 39 | 40 | @After 41 | override def afterAll() = { 42 | FileUtils.deleteQuietly(file) 43 | super.tearDown() 44 | } 45 | } 46 | 47 | -------------------------------------------------------------------------------- /src/test/scala/org/trustedanalytics/spark/datasources/tensorflow/TensorflowSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016 Intel Corporation  3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 |  */
16 |
17 | package org.trustedanalytics.spark.datasources.tensorflow
18 |
19 | import org.apache.spark.rdd.RDD
20 | import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
21 | import org.apache.spark.sql.types._
22 | import org.apache.spark.sql.{DataFrame, Row}
23 | import org.tensorflow.example._
24 | import org.tensorflow.hadoop.shaded.protobuf.ByteString
25 | import org.trustedanalytics.spark.datasources.tensorflow.serde.{DefaultTfRecordRowDecoder, DefaultTfRecordRowEncoder}
26 | import scala.collection.JavaConverters._
27 |
28 | class TensorflowSuite extends SharedSparkSessionSuite {
29 |
30 |   "Spark TensorFlow module" should {
31 |
32 |     "Test Import/Export" in {
33 |
34 |       val path = s"$TF_SANDBOX_DIR/output25.tfr"
35 |       val testRows: Array[Row] = Array(
36 |         new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
37 |         new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
38 |
39 |       val schema = StructType(List(
40 |         StructField("id", IntegerType),
41 |         StructField("IntegerTypelabel", IntegerType),
42 |         StructField("LongTypelabel", LongType),
43 |         StructField("FloatTypelabel", FloatType),
44 |         StructField("DoubleTypelabel", DoubleType),
45 |         StructField("vectorlabel", ArrayType(DoubleType, true)),
46 |         StructField("name", StringType)))
47 |
48 |       val rdd = spark.sparkContext.parallelize(testRows)
49 |
50 |       val df: DataFrame = spark.createDataFrame(rdd, schema)
51 |       df.write.format("tensorflow").save(path)
52 |
53 |       //A schema is passed here; if none is provided, the data source infers it automatically
54 |       val importedDf: DataFrame = spark.read.format("tensorflow").schema(schema).load(path)
55 |       val actualDf = importedDf.select("id", "IntegerTypelabel", "LongTypelabel", "FloatTypelabel", "DoubleTypelabel", "vectorlabel", "name").sort("name")
56 |
57 |       val expectedRows = df.collect()
58 |       val actualRows = actualDf.collect()
59 |
60 |       expectedRows should equal(actualRows)
61 |     }
62 |
63 |     "Encode given Row as TensorFlow example" in {
64 |       val schemaStructType = StructType(Array(
65 |         StructField("IntegerTypelabel", IntegerType),
66 |         StructField("LongTypelabel", LongType),
67 |         StructField("FloatTypelabel", FloatType),
68 |         StructField("DoubleTypelabel", DoubleType),
69 |         StructField("vectorlabel", ArrayType(DoubleType, true)),
70 |         StructField("strlabel", StringType)
71 |       ))
72 |       val doubleArray = Array(1.1, 111.1, 11111.1)
73 |       val expectedFloatArray = Array(1.1F, 111.1F, 11111.1F)
74 |
75 |       val rowWithSchema = new GenericRowWithSchema(Array[Any](1, 23L, 10.0F, 14.0, doubleArray, "r1"), schemaStructType)
76 |
77 |       //Encode SQL Row as a TensorFlow Example
78 |       val example = DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema)
79 |       import org.tensorflow.example.Feature
80 |
81 |       //Verify that each datatype is converted to the matching TensorFlow datatype
82 |       val featureMap = example.getFeatures.getFeatureMap.asScala
83 |       assert(featureMap("IntegerTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER)
84 |       assert(featureMap("IntegerTypelabel").getInt64List.getValue(0).toInt == 1)
85 |
86 |       assert(featureMap("LongTypelabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER)
87 |       assert(featureMap("LongTypelabel").getInt64List.getValue(0).toInt == 23)
88 |
89 |       assert(featureMap("FloatTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
90 |       assert(featureMap("FloatTypelabel").getFloatList.getValue(0) == 10.0F)
91 |
92 |       assert(featureMap("DoubleTypelabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
93 |       assert(featureMap("DoubleTypelabel").getFloatList.getValue(0) == 14.0F)
94 |
95 |       assert(featureMap("vectorlabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
96 |       assert(featureMap("vectorlabel").getFloatList.getValueList.toArray === expectedFloatArray)
97 |
98 |       assert(featureMap("strlabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
99 |       assert(featureMap("strlabel").getBytesList.toByteString.toStringUtf8.trim == "r1")
100 |
101 |     }
102 |
103 |     "Throw an exception for a vector with null values during Encode" in {
104 |       intercept[Exception] {
105 |         val schemaStructType = StructType(Array(
106 |           StructField("vectorlabel", ArrayType(DoubleType, true))
107 |         ))
108 |         val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
109 |
110 |         val rowWithSchema = new GenericRowWithSchema(Array[Any](doubleArray), schemaStructType)
111 |
112 |         //Encoding throws a NullPointerException because of the null entries
113 |         DefaultTfRecordRowEncoder.encodeTfRecord(rowWithSchema)
114 |       }
115 |     }
116 |
117 |     "Decode given TensorFlow Example as Row" in {
118 |
119 |       //Vectors containing nulls are not supported here
120 |       val expectedRow = new GenericRow(Array[Any](1, 23L, 10.0F, 14.0, Seq(1.0, 2.0), "r1"))
121 |
122 |       val schema = StructType(List(
123 |         StructField("IntegerTypelabel", IntegerType),
124 |         StructField("LongTypelabel", LongType),
125 |         StructField("FloatTypelabel", FloatType),
126 |         StructField("DoubleTypelabel", DoubleType),
127 |         StructField("vectorlabel", ArrayType(DoubleType)),
128 |         StructField("strlabel", StringType)))
129 |
130 |       //Build example
131 |       val intFeature = Int64List.newBuilder().addValue(1)
132 |       val longFeature = Int64List.newBuilder().addValue(23L)
133 |       val floatFeature = FloatList.newBuilder().addValue(10.0F)
134 |       val doubleFeature = FloatList.newBuilder().addValue(14.0F)
135 |       val vectorFeature = FloatList.newBuilder().addValue(1F).addValue(2F).build()
136 |       val strFeature = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()
137 |       val features = Features.newBuilder()
138 |         .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature).build())
139 |         .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature).build())
140 |         .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature).build())
141 |         .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature).build())
142 |         .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature).build())
143 |         .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature).build())
144 |         .build()
145 |       val example = Example.newBuilder()
146 |         .setFeatures(features)
147 |         .build()
148 |
149 |       //Decode TensorFlow Example as SQL Row
150 |       val actualRow = DefaultTfRecordRowDecoder.decodeTfRecord(example, schema)
151 |       actualRow should equal(expectedRow)
152 |     }
153 |
154 |     "Check infer schema" in {
155 |
156 |       //Build example1
157 |       val intFeature1 = Int64List.newBuilder().addValue(1)
158 |       val longFeature1 = Int64List.newBuilder().addValue(Int.MaxValue + 10L)
159 |       val floatFeature1 = FloatList.newBuilder().addValue(10.0F)
160 |       val doubleFeature1 = FloatList.newBuilder().addValue(14.0F)
161 |       val vectorFeature1 = FloatList.newBuilder().addValue(1F).build()
162 |       val strFeature1 = BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()
163 |       val features1 = Features.newBuilder()
164 |         .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature1).build())
165 |         .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature1).build())
166 |         .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature1).build())
167 |         .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature1).build())
168 |         .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature1).build())
169 |         .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature1).build())
170 |         .build()
171 |       val example1 = Example.newBuilder()
172 |         .setFeatures(features1)
173 |         .build()
174 |
175 |       //Build example2
176 |       val intFeature2 = Int64List.newBuilder().addValue(2)
177 |       val longFeature2 = Int64List.newBuilder().addValue(24)
178 |       val floatFeature2 = FloatList.newBuilder().addValue(12.0F)
179 |       val doubleFeature2 = FloatList.newBuilder().addValue(Float.MaxValue + 15)
180 |       val vectorFeature2 = FloatList.newBuilder().addValue(2F).addValue(2F).build()
181 |       val strFeature2 = BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes)).build()
182 |       val features2 = Features.newBuilder()
183 |         .putFeature("IntegerTypelabel", Feature.newBuilder().setInt64List(intFeature2).build())
184 |         .putFeature("LongTypelabel", Feature.newBuilder().setInt64List(longFeature2).build())
185 |         .putFeature("FloatTypelabel", Feature.newBuilder().setFloatList(floatFeature2).build())
186 |         .putFeature("DoubleTypelabel", Feature.newBuilder().setFloatList(doubleFeature2).build())
187 |         .putFeature("vectorlabel", Feature.newBuilder().setFloatList(vectorFeature2).build())
188 |         .putFeature("strlabel", Feature.newBuilder().setBytesList(strFeature2).build())
189 |         .build()
190 |       val example2 = Example.newBuilder()
191 |         .setFeatures(features2)
192 |         .build()
193 |
194 |       val exampleRDD: RDD[Example] = spark.sparkContext.parallelize(List(example1, example2))
195 |
196 |       val actualSchema = TensorflowInferSchema(exampleRDD)
197 |
198 |       //Verify that each TensorFlow datatype is inferred as the expected Spark datatype
199 |       actualSchema.fields.foreach { column =>
200 |         column.name match {
201 |           case "IntegerTypelabel" => column.dataType should equal(IntegerType)
202 |           case "LongTypelabel" => column.dataType should equal(LongType)
203 |           case "FloatTypelabel" | "DoubleTypelabel" | "vectorlabel" => column.dataType should equal(FloatType)
204 |           case "strlabel" => column.dataType should equal(StringType)
205 |         }
206 |       }
207 |     }
208 |   }
209 | }
210 |
211 |
--------------------------------------------------------------------------------
/src/test/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/FeatureDecoderTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright (c) 2016 Intel Corporation
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *       http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | package org.trustedanalytics.spark.datasources.tensorflow.serde
17 |
18 | import org.scalatest.{Matchers, WordSpec}
19 | import org.tensorflow.example.{BytesList, FloatList, Feature, Int64List}
20 | import org.tensorflow.hadoop.shaded.protobuf.ByteString
21 |
22 | class FeatureDecoderTest extends WordSpec with Matchers {
23 |
24 |   "Int Feature decoder" should {
25 |
26 |     "Decode Feature to Int" in {
27 |       val int64List = Int64List.newBuilder().addValue(4).build()
28 |       val intFeature = Feature.newBuilder().setInt64List(int64List).build()
29 |       IntFeatureDecoder.decode(intFeature) should equal(4)
30 |     }
31 |
32 |     "Throw an exception if length of feature array exceeds 1" in {
33 |       intercept[Exception] {
34 |         val int64List = Int64List.newBuilder().addValue(4).addValue(7).build()
35 |         val intFeature = Feature.newBuilder().setInt64List(int64List).build()
36 |         IntFeatureDecoder.decode(intFeature)
37 |       }
38 |     }
39 |
40 |     "Throw an exception if feature is not an Int64List" in {
41 |       intercept[Exception] {
42 |         val floatList = FloatList.newBuilder().addValue(4).build()
43 |         val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
44 |         IntFeatureDecoder.decode(floatFeature)
45 |       }
46 |     }
47 |   }
48 |
49 |   "Int List Feature decoder" should {
50 |
51 |     "Decode Feature to Int List" in {
52 |       val int64List = Int64List.newBuilder().addValue(3).addValue(9).build()
53 |       val intFeature = Feature.newBuilder().setInt64List(int64List).build()
54 |       IntListFeatureDecoder.decode(intFeature) should equal(Seq(3,9))
55 |     }
56 |
57 |     "Throw an exception if feature is not an Int64List" in {
58 |       intercept[Exception] {
59 |         val floatList = FloatList.newBuilder().addValue(4).build()
60 |         val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
61 |         IntListFeatureDecoder.decode(floatFeature)
62 |       }
63 |     }
64 |   }
65 |
66 |   "Long Feature decoder" should {
67 |
68 |     "Decode Feature to Long" in {
69 |       val int64List = Int64List.newBuilder().addValue(5L).build()
70 |       val intFeature = Feature.newBuilder().setInt64List(int64List).build()
71 |       LongFeatureDecoder.decode(intFeature) should equal(5L)
72 |     }
73 |
74 |     "Throw an exception if length of feature array exceeds 1" in {
75 |       intercept[Exception] {
76 |         val int64List = Int64List.newBuilder().addValue(4L).addValue(10L).build()
77 |         val intFeature = Feature.newBuilder().setInt64List(int64List).build()
78 |         LongFeatureDecoder.decode(intFeature)
79 |       }
80 |     }
81 |
82 |     "Throw an exception if feature is not an Int64List" in {
83 |       intercept[Exception] {
84 |         val floatList = FloatList.newBuilder().addValue(4).build()
85 |         val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
86 |         LongFeatureDecoder.decode(floatFeature)
87 |       }
88 |     }
89 |   }
90 |
91 |   "Long List Feature decoder" should {
92 |
93 |     "Decode Feature to Long List" in {
94 |       val int64List = Int64List.newBuilder().addValue(3L).addValue(Int.MaxValue+10L).build()
95 |       val intFeature = Feature.newBuilder().setInt64List(int64List).build()
96 |       LongListFeatureDecoder.decode(intFeature) should equal(Seq(3L,Int.MaxValue+10L))
97 |     }
98 |
99 |     "Throw an exception if feature is not an Int64List" in {
100 |       intercept[Exception] {
101 |         val floatList = FloatList.newBuilder().addValue(4).build()
102 |         val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
103 |         LongListFeatureDecoder.decode(floatFeature)
104 |       }
105 |     }
106 |   }
107 |
108 |   "Float Feature decoder" should {
109 |
110 |     "Decode Feature to Float" in {
111 |       val floatList = FloatList.newBuilder().addValue(2.5F).build()
112 |       val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
113 |       FloatFeatureDecoder.decode(floatFeature) should equal(2.5F)
114 |     }
115 |
116 |     "Throw an exception if length of feature array exceeds 1" in {
117 |       intercept[Exception] {
118 |         val floatList = FloatList.newBuilder().addValue(1.5F).addValue(3.33F).build()
119 |         val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
120 |         FloatFeatureDecoder.decode(floatFeature)
121 |       }
122 |     }
123 |
124 |     "Throw an exception if feature is not a FloatList" in {
125 |       intercept[Exception] {
126 |         val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
127 |         val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
128 |         FloatFeatureDecoder.decode(bytesFeature)
129 |       }
130 |     }
131 |   }
132 |
133 |   "Float List Feature decoder" should {
134 |
135 |     "Decode Feature to Float List" in {
136 |       val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.3F).build()
137 |       val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
138 |       FloatListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5F, 4.3F))
139 |     }
140 |
141 |     "Throw an exception if feature is not a FloatList" in {
142 |       intercept[Exception] {
143 |         val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
144 |         val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
145 |         FloatListFeatureDecoder.decode(bytesFeature)
146 |       }
147 |     }
148 |   }
149 |
150 |   "Double Feature decoder" should {
151 |
152 |     "Decode Feature to Double" in {
153 |       val floatList = FloatList.newBuilder().addValue(2.5F).build()
154 |       val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
155 |       DoubleFeatureDecoder.decode(floatFeature) should equal(2.5d)
156 |     }
157 |
158 |     "Throw an exception if length of feature array exceeds 1" in {
159 |       intercept[Exception] {
160 |         val floatList = FloatList.newBuilder().addValue(1.5F).addValue(3.33F).build()
161 |         val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
162 |         DoubleFeatureDecoder.decode(floatFeature)
163 |       }
164 |     }
165 |
166 |     "Throw an exception if feature is not a FloatList" in {
167 |       intercept[Exception] {
168 |         val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
169 |         val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
170 |         DoubleFeatureDecoder.decode(bytesFeature)
171 |       }
172 |     }
173 |   }
174 |
175 |   "Double List Feature decoder" should {
176 |
177 |     "Decode Feature to Double List" in {
178 |       val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build()
179 |       val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
180 |       DoubleListFeatureDecoder.decode(floatFeature) should equal(Seq(2.5d, 4.0d))
181 |     }
182 |
183 |     "Throw an exception if feature is not a FloatList" in {
184 |       intercept[Exception] {
185 |         val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
186 |         val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
187 |         DoubleListFeatureDecoder.decode(bytesFeature)
188 |       }
189 |     }
190 |   }
191 |
192 |   "Bytes List Feature decoder" should {
193 |
194 |     "Decode Feature to Bytes List" in {
195 |       val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
196 |       val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
197 |       StringFeatureDecoder.decode(bytesFeature) should equal("str-input")
198 |     }
199 |
200 |     "Throw an exception if feature is not a BytesList" in {
201 |       intercept[Exception] {
202 |         val floatList = FloatList.newBuilder().addValue(2.5F).addValue(4.0F).build()
203 |         val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
204 |         StringFeatureDecoder.decode(floatFeature)
205 |       }
206 |     }
207 |   }
208 | }
209 |
210 |
--------------------------------------------------------------------------------
/src/test/scala/org/trustedanalytics/spark/datasources/tensorflow/serde/FeatureEncoderTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright (c) 2016 Intel Corporation
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *       http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | package org.trustedanalytics.spark.datasources.tensorflow.serde
17 |
18 | import org.scalatest.{Matchers, WordSpec}
19 | import scala.collection.JavaConverters._
20 |
21 | class FeatureEncoderTest extends WordSpec with Matchers {
22 |
23 |   "Int64List feature encoder" should {
24 |     "Encode inputs to Int64List" in {
25 |       val intFeature = Int64ListFeatureEncoder.encode(5)
26 |       val longFeature = Int64ListFeatureEncoder.encode(10L)
27 |       val longListFeature = Int64ListFeatureEncoder.encode(Seq(3L,5L,6L))
28 |
29 |       intFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(5L))
30 |       longFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(10L))
31 |       longListFeature.getInt64List.getValueList.asScala.toSeq should equal (Seq(3L, 5L, 6L))
32 |     }
33 |
34 |     "Throw an exception when inputs contain null" in {
35 |       intercept[Exception] {
36 |         Int64ListFeatureEncoder.encode(null)
37 |       }
38 |       intercept[Exception] {
39 |         Int64ListFeatureEncoder.encode(Seq(3,null,6))
40 |       }
41 |     }
42 |
43 |     "Throw an exception for non-numeric inputs" in {
44 |       intercept[Exception] {
45 |         Int64ListFeatureEncoder.encode("bad-input")
46 |       }
47 |     }
48 |   }
49 |
50 |   "FloatList feature encoder" should {
51 |     "Encode inputs to FloatList" in {
52 |       val intFeature = FloatListFeatureEncoder.encode(5)
53 |       val longFeature = FloatListFeatureEncoder.encode(10L)
54 |       val floatFeature = FloatListFeatureEncoder.encode(2.5F)
55 |       val doubleFeature = FloatListFeatureEncoder.encode(14.6)
56 |       val floatListFeature = FloatListFeatureEncoder.encode(Seq(1.5F,6.8F,-3.2F))
57 |
58 |       intFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(5F))
59 |       longFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(10F))
60 |       floatFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(2.5F))
61 |       doubleFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(14.6F))
62 |       floatListFeature.getFloatList.getValueList.asScala.toSeq should equal (Seq(1.5F,6.8F,-3.2F))
63 |     }
64 |
65 |     "Throw an exception when inputs contain null" in {
66 |       intercept[Exception] {
67 |         FloatListFeatureEncoder.encode(null)
68 |       }
69 |       intercept[Exception] {
70 |         FloatListFeatureEncoder.encode(Seq(3,null,6))
71 |       }
72 |     }
73 |
74 |     "Throw an exception for non-numeric inputs" in {
75 |       intercept[Exception] {
76 |         FloatListFeatureEncoder.encode("bad-input")
77 |       }
78 |     }
79 |   }
80 |
81 |   "BytesList feature encoder" should {
82 |     "Encode inputs to BytesList" in {
83 |       val longFeature = BytesListFeatureEncoder.encode(10L)
84 |       val longListFeature = BytesListFeatureEncoder.encode(Seq(3L,5L,6L))
85 |       val strFeature = BytesListFeatureEncoder.encode("str-input")
86 |
87 |       longFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("10")
88 |       longListFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("List(3, 5, 6)")
89 |       strFeature.getBytesList.toByteString.toStringUtf8.trim should equal ("str-input")
90 |     }
91 |
92 |     "Throw an exception when inputs contain null" in {
93 |       intercept[Exception] {
94 |         BytesListFeatureEncoder.encode(null)
95 |       }
96 |     }
97 |   }
98 | }
99 |
--------------------------------------------------------------------------------
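Appendix: encoder/decoder usage sketch. The serde objects above map a single Spark SQL value to a TensorFlow Feature proto and back. The round trip below is a minimal sketch that uses only the classes exercised in FeatureEncoderTest and FeatureDecoderTest; the FeatureRoundTrip object itself is illustrative and not part of the repository.

import org.trustedanalytics.spark.datasources.tensorflow.serde._

object FeatureRoundTrip extends App {
  // Encode plain Scala values into TensorFlow Feature protos
  val intFeature   = Int64ListFeatureEncoder.encode(Seq(3L, 5L, 6L)) // Int64List
  val floatFeature = FloatListFeatureEncoder.encode(Seq(1.5F, 2.5F)) // FloatList
  val strFeature   = BytesListFeatureEncoder.encode("str-input")     // BytesList

  // Decode them back with the matching decoders from FeatureDecoder.scala
  println(LongListFeatureDecoder.decode(intFeature))    // expected: List(3, 5, 6)
  println(FloatListFeatureDecoder.decode(floatFeature)) // expected: List(1.5, 2.5)
  println(StringFeatureDecoder.decode(strFeature))      // expected: str-input
}

Note that the encoders are lossy by design: FloatListFeatureEncoder narrows doubles to floats, and BytesListFeatureEncoder serializes value.toString, which is why FeatureEncoderTest expects "List(3, 5, 6)" for a Seq input.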
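At the DataFrame level, TensorflowSuite demonstrates the registered "tensorflow" data source: df.write.format("tensorflow").save(path) writes TFRecords, and spark.read.format("tensorflow").load(path) reads them back, inferring a schema when none is given. A minimal end-to-end sketch, assuming a local SparkSession and a hypothetical output path:

import org.apache.spark.sql.SparkSession

object TFRecordRoundTrip extends App {
  // Illustrative local session; any SparkSession works
  val spark = SparkSession.builder().master("local[2]").appName("tfrecord-demo").getOrCreate()
  import spark.implicits._

  val path = "example-output.tfr" // hypothetical output path

  // Write a DataFrame as TFRecords
  val df = Seq((1, 23L, 10.0F, 14.0, "r1"), (2, 24L, 12.0F, 15.0, "r2"))
    .toDF("id", "LongTypelabel", "FloatTypelabel", "DoubleTypelabel", "name")
  df.write.format("tensorflow").save(path)

  // Read back without an explicit schema; per the "Check infer schema" test,
  // doubles come back as FloatType because Example stores them in a FloatList
  val imported = spark.read.format("tensorflow").load(path)
  imported.printSchema()
  imported.show()

  spark.stop()
}

Passing .schema(schema) on the reader, as the import/export test does, skips inference and preserves the original Spark types.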