├── .gitignore ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE.md ├── NOTICE ├── README.md ├── pom.xml ├── project ├── build.properties └── plugins.sbt └── src ├── main ├── resources │ └── META-INF │ │ └── services │ │ └── org.apache.spark.sql.sources.DataSourceRegister └── scala │ └── com │ └── linkedin │ └── spark │ └── datasources │ └── tfrecord │ ├── DefaultSource.scala │ ├── TFRecordDeserializer.scala │ ├── TFRecordFileReader.scala │ ├── TFRecordOutputWriter.scala │ ├── TFRecordSerializer.scala │ └── TensorFlowInferSchema.scala └── test └── scala └── com └── linkedin └── spark └── datasources └── tfrecord ├── InferSchemaSuite.scala ├── SharedSparkSessionSuite.scala ├── TFRecordDeserializerTest.scala ├── TFRecordIOSuite.scala ├── TFRecordSerializerTest.scala └── TestingUtils.scala /.gitignore: -------------------------------------------------------------------------------- 1 | out 2 | *.egg 3 | *.egg-info/ 4 | *.iml 5 | *.ipr 6 | *.iws 7 | *.pyc 8 | *.pyo 9 | *.sublime-* 10 | .*.swo 11 | .*.swp 12 | .cache/ 13 | .coverage 14 | .direnv/ 15 | .env 16 | .envrc 17 | .gradle/ 18 | .idea/ 19 | target/ 20 | .tox* 21 | .venv* 22 | /*/*pinned.txt 23 | /*/MANIFEST 24 | /*/activate 25 | /*/build/ 26 | /*/config 27 | /*/coverage.xml 28 | /*/dist/ 29 | /*/htmlcov/ 30 | /*/product-spec.json 31 | /build/ 32 | /config/ 33 | /dist/ 34 | /ligradle/ 35 | TEST-*.xml 36 | __pycache__/ 37 | /*/build 38 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | language: scala 3 | git: 4 | depth: 3 5 | jdk: 6 | - oraclejdk8 7 | 8 | matrix: 9 | include: 10 | - scala: 2.12.12 11 | script: "mvn test -B -Pscala-2.12" 12 | 13 | - scala: 2.13.8 14 | script: "mvn test -B -Pscala-2.13" 15 | 16 | # safelist 17 | branches: 18 | only: 19 | - master 20 | - spark-2.3 21 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contribution Agreement 2 | ====================== 3 | 4 | As a contributor, you represent that the code you submit is your original work or 5 | that of your employer (in which case you represent you have the right to bind your 6 | employer). By submitting code, you (and, if applicable, your employer) are 7 | licensing the submitted code to LinkedIn and the open source community subject 8 | to the BSD 2-Clause license. 9 | 10 | Responsible Disclosure of Security Vulnerabilities 11 | ================================================== 12 | 13 | **Do not file an issue on Github for security issues.** Please review 14 | the [guidelines for disclosure][disclosure_guidelines]. Reports should 15 | be encrypted using PGP ([public key][pubkey]) and sent to 16 | [security@linkedin.com][disclosure_email] preferably with the title 17 | "Vulnerability in Github LinkedIn/spark-tfrecord - <short summary>". 18 | 19 | Tips for Getting Your Pull Request Accepted 20 | =========================================== 21 | 22 | 1. Make sure all new features are tested and the tests pass. 23 | 2. Bug fixes must include a test case demonstrating the error that it fixes. 24 | 3. Open an issue first and seek advice for your change before submitting 25 | a pull request. Large features which have never been discussed are 26 | unlikely to be accepted. 
**You have been warned.** 27 | 28 | [disclosure_guidelines]: https://www.linkedin.com/help/linkedin/answer/62924 29 | [pubkey]: https://www.linkedin.com/help/linkedin/answer/79676 30 | [disclosure_email]: mailto:security@linkedin.com?subject=Vulnerability%20in%20Github%20LinkedIn/spark-tfrecord%20-%20%3Csummary%3E 31 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | BSD 2-CLAUSE LICENSE 2 | 3 | Copyright 2020 LinkedIn Corporation 4 | All Rights Reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following 7 | conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following 12 | disclaimer in the documentation and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 15 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 16 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2020 LinkedIn Corporation 2 | All Rights Reserved. 3 | 4 | Licensed under the BSD 2-Clause License (the "License"). 5 | See LICENSE in the project root for license information. 6 | 7 | This product includes: 8 | - spark-tensorflow-connector: Apache 2.0 License, available at https://github.com/tensorflow/ecosystem/blob/master/LICENSE 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spark-TFRecord 2 | 3 | A library for reading and writing [Tensorflow TFRecord](https://www.tensorflow.org/how_tos/reading_data/) data from [Apache Spark](http://spark.apache.org/). 4 | The implementation is based on [Spark Tensorflow Connector](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector), but it is rewritten in Spark FileFormat trait to provide the partitioning function. 5 | 6 | ## Including the library 7 | 8 | The artifacts are published to [bintray](https://bintray.com/linkedin/maven/spark-tfrecord) and [maven central](https://search.maven.org/search?q=spark-tfrecord) repositories. 
9 |
10 | - Version 0.1.x targets Spark 2.3 and Scala 2.11
11 | - Version 0.2.x targets Spark 2.4 and both Scala 2.11 and 2.12
12 | - Version 0.3.x targets Spark 3.0 and Scala 2.12
13 | - Version 0.4.x targets Spark 3.2 and Scala 2.12
14 | - Version 0.5.x targets Spark 3.2 and Scala 2.13
15 | - Version 0.6.x targets Spark 3.4 and both Scala 2.12 and 2.13
16 | - Version 0.7.x targets Spark 3.5 and both Scala 2.12 and 2.13
17 |
18 | To use the package, include the following dependency:
19 |
20 | ```xml
21 | <dependency>
22 |   <groupId>com.linkedin.sparktfrecord</groupId>
23 |   <artifactId>spark-tfrecord_2.12</artifactId>
24 |   <version>your.version</version>
25 | </dependency>
26 | ```
27 |
28 | ## Building the library
29 | The library can be built with Maven 3.3.9 or newer as shown below:
30 |
31 | ```sh
32 | # Build Spark-TFRecord
33 | git clone https://github.com/linkedin/spark-tfrecord.git
34 | cd spark-tfrecord
35 | mvn -Pscala-2.12 clean install
36 |
37 | # One can specify the Spark version and the tensorflow-hadoop version, for example:
38 | mvn -Pscala-2.12 clean install -Dspark.version=3.0.0 -Dtensorflow.hadoop.version=1.15.0
39 | ```
40 |
41 | ## Using Spark Shell
42 | Run this library in Spark using the `--jars` command line option in `spark-shell`, `pyspark` or `spark-submit`. For example:
43 |
44 | ```sh
45 | $SPARK_HOME/bin/spark-shell --jars target/spark-tfrecord_2.12-0.3.0.jar
46 | ```
47 |
48 | ## Features
49 | This library allows reading TensorFlow records from a local or distributed filesystem as [Spark DataFrames](https://spark.apache.org/docs/latest/sql-programming-guide.html).
50 | When reading TensorFlow records into a Spark DataFrame, the API accepts several options:
51 | * `load`: input path to TensorFlow records. As elsewhere in Spark, standard Hadoop globbing expressions are accepted.
52 | * `schema`: schema of TensorFlow records. Optional schema defined using Spark StructType. If not provided, the schema is inferred from the TensorFlow records.
53 | * `recordType`: input format of TensorFlow records. By default it is Example. Possible values are:
54 |   * `Example`: TensorFlow [Example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
55 |   * `SequenceExample`: TensorFlow [SequenceExample](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
56 |   * `ByteArray`: `Array[Byte]` type in Scala.
57 |
58 | When writing a Spark DataFrame to TensorFlow records, the API accepts several options:
59 | * `save`: output path to TensorFlow records on a local or distributed filesystem.
60 | * `codec`: compression codec class for the output (for example, `org.apache.hadoop.io.compress.GzipCodec`); by default no compression is applied. While reading compressed TensorFlow records, `codec` can be inferred automatically, so this option is not required for reading.
61 | * `recordType`: output format of TensorFlow records. By default it is Example. Possible values are:
62 |   * `Example`: TensorFlow [Example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
63 |   * `SequenceExample`: TensorFlow [SequenceExample](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
64 |   * `ByteArray`: `Array[Byte]` type in Scala. For use cases when writing objects other than TensorFlow Example or SequenceExample. For example, [protos](https://developers.google.com/protocol-buffers) can be transformed to byte arrays using `.toByteArray`.
65 |
66 | The writer supports the `partitionBy` operation, so the following command will partition the output by "partitionColumn".
67 | ```
68 | df.write.mode(SaveMode.Overwrite).partitionBy("partitionColumn").format("tfrecord").option("recordType", "Example").save(output_dir)
69 | ```
70 | Note that we use `format("tfrecord")` instead of `format("tfrecords")`. So if you migrate from Spark-Tensorflow-Connector, make sure this is changed accordingly.
71 |
72 | ## Schema inference
73 | This library supports automatic schema inference when reading TensorFlow records into Spark DataFrames.
74 | Schema inference is expensive since it requires an extra pass through the data.
75 |
76 | The schema inference rules are described in the table below:
77 |
78 | | TFRecordType             | Feature Type             | Inferred Spark Data Type |
79 | | ------------------------ |:-------------------------|:-------------------------|
80 | | Example, SequenceExample | Int64List                | LongType if all lists have length=1, else ArrayType(LongType) |
81 | | Example, SequenceExample | FloatList                | FloatType if all lists have length=1, else ArrayType(FloatType) |
82 | | Example, SequenceExample | BytesList                | StringType if all lists have length=1, else ArrayType(StringType) |
83 | | SequenceExample          | FeatureList of Int64List | ArrayType(ArrayType(LongType)) |
84 | | SequenceExample          | FeatureList of FloatList | ArrayType(ArrayType(FloatType)) |
85 | | SequenceExample          | FeatureList of BytesList | ArrayType(ArrayType(StringType)) |
86 |
87 | ## Supported data types
88 |
89 | The supported Spark data types are listed in the table below:
90 |
91 | | Type            | Spark DataTypes |
92 | | --------------- |:----------------|
93 | | Scalar          | IntegerType, LongType, FloatType, DoubleType, DecimalType, StringType, BinaryType |
94 | | Array           | VectorType, ArrayType of IntegerType, LongType, FloatType, DoubleType, DecimalType, BinaryType, or StringType |
95 | | Array of Arrays | ArrayType of ArrayType of IntegerType, LongType, FloatType, DoubleType, DecimalType, BinaryType, or StringType |
96 |
97 | ## Usage Examples
98 |
99 | ### Python API
100 |
101 | #### TFRecord import/export
102 |
103 | Run PySpark with the spark-tfrecord jar in the `--jars` argument as shown below:
104 |
105 | `$SPARK_HOME/bin/pyspark --jars target/spark-tfrecord_2.12-0.3.0.jar`
106 |
107 | The following Python code snippet demonstrates usage on test data.
108 |
109 | ```python
110 | from pyspark.sql.types import *
111 |
112 | path = "test-output.tfrecord"
113 |
114 | fields = [StructField("id", IntegerType()), StructField("IntegerCol", IntegerType()),
115 |           StructField("LongCol", LongType()), StructField("FloatCol", FloatType()),
116 |           StructField("DoubleCol", DoubleType()), StructField("VectorCol", ArrayType(DoubleType(), True)),
117 |           StructField("StringCol", StringType())]
118 | schema = StructType(fields)
119 | test_rows = [[11, 1, 23, 10.0, 14.0, [1.0, 2.0], "r1"], [21, 2, 24, 12.0, 15.0, [2.0, 2.0], "r2"]]
120 | rdd = spark.sparkContext.parallelize(test_rows)
121 | df = spark.createDataFrame(rdd, schema)
122 | df.write.mode("overwrite").format("tfrecord").option("recordType", "Example").save(path)
123 | df = spark.read.format("tfrecord").option("recordType", "Example").load(path)
124 | df.show()
125 | ```
126 |
127 | ### Scala API
128 | Run the Spark shell with the spark-tfrecord jar in the `--jars` argument as shown below:
129 | ```sh
130 | $SPARK_HOME/bin/spark-shell --jars target/spark-tfrecord_2.12-0.3.0.jar
131 | ```
132 |
133 | The following Scala code snippet demonstrates usage on test data.
134 |
135 | ```scala
136 | import org.apache.commons.io.FileUtils
137 | import org.apache.spark.sql.{ DataFrame, Row }
138 | import org.apache.spark.sql.catalyst.expressions.GenericRow
139 | import org.apache.spark.sql.types._
140 |
141 | val path = "test-output.tfrecord"
142 | val testRows: Array[Row] = Array(
143 |   new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
144 |   new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
145 | val schema = StructType(List(StructField("id", IntegerType),
146 |   StructField("IntegerCol", IntegerType),
147 |   StructField("LongCol", LongType),
148 |   StructField("FloatCol", FloatType),
149 |   StructField("DoubleCol", DoubleType),
150 |   StructField("VectorCol", ArrayType(DoubleType, true)),
151 |   StructField("StringCol", StringType)))
152 |
153 | val rdd = spark.sparkContext.parallelize(testRows)
154 |
155 | // Save the DataFrame as TFRecords
156 | val df: DataFrame = spark.createDataFrame(rdd, schema)
157 | df.write.format("tfrecord").option("recordType", "Example").save(path)
158 |
159 | // Read TFRecords into a DataFrame.
160 | // The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
161 | val importedDf1: DataFrame = spark.read.format("tfrecord").option("recordType", "Example").load(path)
162 | importedDf1.show()
163 |
164 | // Read TFRecords into a DataFrame using a custom schema
165 | val importedDf2: DataFrame = spark.read.format("tfrecord").schema(schema).load(path)
166 | importedDf2.show()
167 | ```
168 |
169 | #### Use partitionBy
170 | The following example shows how to use partitionBy, which is not supported by [Spark Tensorflow Connector](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector).
171 |
172 | ```scala
173 |
174 | // launch spark-shell with the following command:
175 | // $SPARK_HOME/bin/spark-shell --jars target/spark-tfrecord_2.12-0.3.0.jar
176 |
177 | import org.apache.spark.sql.SaveMode
178 |
179 | val df = Seq((8, "bat"), (8, "abc"), (1, "xyz"), (2, "aaa")).toDF("number", "word")
180 | df.show
181 |
182 | // scala> df.show
183 | // +------+----+
184 | // |number|word|
185 | // +------+----+
186 | // |     8| bat|
187 | // |     8| abc|
188 | // |     1| xyz|
189 | // |     2| aaa|
190 | // +------+----+
191 |
192 | val tf_output_dir = "/tmp/tfrecord-test"
193 |
194 | // dump the tfrecords to files.
195 | df.repartition(3, col("number")).write.mode(SaveMode.Overwrite).partitionBy("number").format("tfrecord").option("recordType", "Example").save(tf_output_dir)
196 |
197 | // ls /tmp/tfrecord-test
198 | // _SUCCESS number=1 number=2 number=8
199 |
200 | // read back the tfrecords from files.
201 | val new_df = spark.read.format("tfrecord").option("recordType", "Example").load(tf_output_dir)
202 | new_df.show
203 |
204 | // scala> new_df.show
205 | // +----+------+
206 | // |word|number|
207 | // +----+------+
208 | // | bat|     8|
209 | // | abc|     8|
210 | // | xyz|     1|
211 | // | aaa|     2|
212 | ```
213 | ## Contributing
214 |
215 | Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduct, and the process for submitting pull requests to us.
216 | 217 | ## License 218 | 219 | This project is licensed under the BSD 2-CLAUSE LICENSE - see the [LICENSE.md](LICENSE.md) file for details 220 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | com.linkedin.sparktfrecord 7 | spark-tfrecord_${scala.binary.version} 8 | jar 9 | 0.7.0 10 | spark-tfrecord 11 | https://github.com/linkedin/spark-tfrecord 12 | TensorFlow TFRecord data source for Apache Spark 13 | 14 | 15 | 16 | BSD 2-CLAUSE LICENSE 17 | https://github.com/linkedin/spark-tfrecord/blob/master/LICENSE.md 18 | repo 19 | 20 | 21 | 22 | 23 | https://github.com/linkedin/spark-tfrecord.git 24 | git@github.com:linkedin/spark-tfrecord.git 25 | scm:git:https://github.com/linkedin/spark-tfrecord.git 26 | 27 | 28 | 29 | UTF-8 30 | 3.2.2 31 | 1.0 32 | 3.0.8 33 | 3.5.1 34 | 3.0 35 | 1.8 36 | 4.13.1 37 | 1.15.0 38 | 39 | 40 | 41 | 42 | 43 | 44 | true 45 | net.alchim31.maven 46 | scala-maven-plugin 47 | ${scala.maven.version} 48 | 49 | 50 | compile 51 | 52 | add-source 53 | compile 54 | 55 | 56 | 57 | -Xms256m 58 | -Xmx512m 59 | 60 | 61 | -g:vars 62 | -deprecation 63 | -feature 64 | -unchecked 65 | -Xfatal-warnings 66 | -language:implicitConversions 67 | -language:existentials 68 | 69 | 70 | 71 | 72 | test 73 | 74 | add-source 75 | testCompile 76 | 77 | 78 | 79 | attach-javadocs 80 | 81 | doc-jar 82 | 83 | 84 | 85 | 86 | incremental 87 | true 88 | ${scala.version} 89 | false 90 | 91 | 92 | 93 | true 94 | org.scalatest 95 | scalatest-maven-plugin 96 | ${scalatest.maven.version} 97 | 98 | 99 | scalaTest 100 | test 101 | 102 | test 103 | 104 | 105 | 106 | 107 | 108 | 109 | maven-shade-plugin 110 | 3.1.0 111 | 112 | 113 | package 114 | 115 | shade 116 | 117 | 118 | true 119 | 120 | 121 | com.google.protobuf:protobuf-java 122 | org.tensorflow:tensorflow-hadoop 123 | org.tensorflow:proto 124 | 125 | 126 | 127 | 128 | 129 | com.google.protobuf:protobuf-java 130 | 131 | **/*.java 132 | 133 | 134 | 135 | 136 | 137 | com.google.protobuf 138 | 139 | com.linkedin.spark.shaded.com.google.protobuf 140 | 141 | 142 | 143 | org.tensorflow.hadoop 144 | 145 | com.linkedin.spark.shaded.org.tensorflow.hadoop 146 | 147 | 148 | 149 | org.tensorflow.example 150 | 151 | com.linkedin.spark.shaded.org.tensorflow.example 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | org.apache.maven.plugins 162 | maven-gpg-plugin 163 | 1.5 164 | 165 | 166 | sign-artifacts 167 | verify 168 | 169 | sign 170 | 171 | 172 | 173 | 174 | 175 | 176 | org.spurint.maven.plugins 177 | scala-cross-maven-plugin 178 | 0.2.1 179 | 180 | 181 | rewrite-pom 182 | 183 | rewrite-pom 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | net.alchim31.maven 193 | scala-maven-plugin 194 | 195 | 196 | org.apache.maven.plugins 197 | maven-shade-plugin 198 | 199 | 200 | org.scalatest 201 | scalatest-maven-plugin 202 | 203 | 204 | org.apache.maven.plugins 205 | maven-compiler-plugin 206 | ${maven.compiler.version} 207 | 208 | ${java.version} 209 | ${java.version} 210 | 211 | 212 | 213 | org.apache.maven.plugins 214 | maven-source-plugin 215 | 2.2.1 216 | 217 | 218 | attach-sources 219 | 220 | jar-no-fork 221 | 222 | 223 | 224 | 225 | 226 | org.apache.maven.plugins 227 | maven-javadoc-plugin 228 | 2.9.1 229 | 230 | 231 | attach-javadocs 232 | 233 | jar 234 | 235 | 236 | 237 | 238 | 239 | org.spurint.maven.plugins 240 | scala-cross-maven-plugin 241 | 242 | 243 | 244 | 245 | 246 | 247 | apache.snapshots 248 | 
Apache Development Snapshot Repository 249 | https://repository.apache.org/content/repositories/snapshots/ 250 | 251 | false 252 | 253 | 254 | true 255 | 256 | 257 | 258 | 259 | 260 | 261 | test 262 | 263 | true 264 | 265 | !NEVERSETME 266 | 267 | 268 | 269 | 270 | 271 | net.alchim31.maven 272 | scala-maven-plugin 273 | 274 | 275 | 276 | 277 | 278 | 279 | org.scalatest 280 | scalatest_${scala.binary.version} 281 | ${scala.test.version} 282 | test 283 | 284 | 285 | 286 | 287 | 288 | org.scalatest 289 | scalatest_${scala.binary.version} 290 | test 291 | 292 | 293 | 294 | 295 | 298 | 299 | ossrh 300 | 301 | 302 | 303 | ossrh 304 | https://oss.sonatype.org/content/repositories/snapshots 305 | 306 | 307 | ossrh 308 | https://oss.sonatype.org/service/local/staging/deploy/maven2/ 309 | 310 | 311 | 312 | 313 | 314 | org.apache.maven.plugins 315 | maven-gpg-plugin 316 | 317 | 318 | 319 | 320 | 321 | 322 | bintray 323 | 324 | 325 | 326 | bintray-linkedin-maven 327 | linkedin-maven 328 | https://api.bintray.com/maven/linkedin/maven/spark-tfrecord/;publish=1 329 | 330 | 331 | 332 | 333 | 334 | org.apache.maven.plugins 335 | maven-gpg-plugin 336 | 337 | 338 | 339 | 340 | 341 | 342 | scala-2.12 343 | 344 | 2.12 345 | 2.12.12 346 | 347 | 348 | 349 | 353 | 354 | scala-2.13 355 | 356 | 2.13 357 | 2.13.8 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | spark-tfrecord developers 366 | LinkedIn 367 | http://www.linkedin.com 368 | 369 | 370 | 371 | 372 | 373 | org.tensorflow 374 | tensorflow-hadoop 375 | ${tensorflow.hadoop.version} 376 | 377 | 378 | org.apache.spark 379 | spark-core_${scala.binary.version} 380 | ${spark.version} 381 | provided 382 | 383 | 384 | org.apache.spark 385 | spark-sql_${scala.binary.version} 386 | ${spark.version} 387 | provided 388 | 389 | 390 | org.apache.spark 391 | spark-mllib_${scala.binary.version} 392 | ${spark.version} 393 | provided 394 | 395 | 396 | org.apache.spark 397 | spark-mllib_${scala.binary.version} 398 | ${spark.version} 399 | test-jar 400 | test 401 | 402 | 403 | junit 404 | junit 405 | ${junit.version} 406 | test 407 | 408 | 409 | 410 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.13 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/" 2 | 3 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3") 4 | 5 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.5") 6 | -------------------------------------------------------------------------------- /src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: -------------------------------------------------------------------------------- 1 | com.linkedin.spark.datasources.tfrecord.DefaultSource -------------------------------------------------------------------------------- /src/main/scala/com/linkedin/spark/datasources/tfrecord/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.spark.datasources.tfrecord 2 | 3 | import java.io.{DataInputStream, DataOutputStream, IOException, ObjectInputStream, ObjectOutputStream, Serializable} 4 | 5 | import com.esotericsoftware.kryo.{Kryo, KryoSerializable} 6 | import 
com.esotericsoftware.kryo.io.{Input, Output} 7 | import org.apache.hadoop.conf.Configuration 8 | import org.apache.hadoop.fs.{FileStatus, Path} 9 | import org.apache.hadoop.io.SequenceFile.CompressionType 10 | import org.apache.hadoop.io.{BytesWritable, NullWritable} 11 | import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} 12 | import org.apache.spark.sql.SparkSession 13 | import org.apache.spark.sql.catalyst.InternalRow 14 | import org.apache.spark.sql.execution.datasources._ 15 | import org.apache.spark.sql.sources._ 16 | import org.apache.spark.sql.types._ 17 | import org.slf4j.LoggerFactory 18 | import org.tensorflow.example.{Example, SequenceExample} 19 | import org.tensorflow.hadoop.io.TFRecordFileInputFormat 20 | 21 | import scala.util.control.NonFatal 22 | 23 | class DefaultSource extends FileFormat with DataSourceRegister { 24 | override val shortName: String = "tfrecord" 25 | 26 | override def isSplitable( 27 | sparkSession: SparkSession, 28 | options: Map[String, String], 29 | path: Path): Boolean = false 30 | 31 | override def inferSchema( 32 | sparkSession: SparkSession, 33 | options: Map[String, String], 34 | files: Seq[FileStatus]): Option[StructType] = { 35 | val recordType = options.getOrElse("recordType", "Example") 36 | files.collectFirst { 37 | case f if hasSchema(sparkSession, f, recordType) => getSchemaFromFile(sparkSession, f, recordType) 38 | } 39 | } 40 | 41 | /** 42 | * Get schema from a file 43 | * @param sparkSession A spark session. 44 | * @param file The file where schema to be extracted. 45 | * @param recordType Example or SequenceExample 46 | * @return the extracted schema (a StructType). 47 | */ 48 | private def getSchemaFromFile( 49 | sparkSession: SparkSession, 50 | file: FileStatus, 51 | recordType: String): StructType = { 52 | val rdd = sparkSession.sparkContext.newAPIHadoopFile(file.getPath.toString, 53 | classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable]) 54 | recordType match { 55 | case "ByteArray" => 56 | TensorFlowInferSchema.getSchemaForByteArray() 57 | case "Example" => 58 | val exampleRdd = rdd.map{case (bytesWritable, nullWritable) => 59 | Example.parseFrom(bytesWritable.getBytes) 60 | } 61 | TensorFlowInferSchema(exampleRdd) 62 | case "SequenceExample" => 63 | val sequenceExampleRdd = rdd.map{case (bytesWritable, nullWritable) => 64 | SequenceExample.parseFrom(bytesWritable.getBytes) 65 | } 66 | TensorFlowInferSchema(sequenceExampleRdd) 67 | case _ => 68 | throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be ByteArray, Example or SequenceExample") 69 | } 70 | } 71 | 72 | /** 73 | * Check if a non-empty schema can be extracted from a file. 74 | * The schema is empty if one of the following is true: 75 | * 1. The file size is zero. 76 | * 2. The file size is non-zero, but the schema is empty (e.g. empty .gz file) 77 | * @param sparkSession A spark session. 78 | * @param file The file where schema to be extracted. 79 | * @param recordType Example or SequenceExample 80 | * @return true if schema is non-empty. 
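   *         Note: answering this requires running schema inference over the file, so it costs an extra read of that file.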
81 | */ 82 | private def hasSchema( 83 | sparkSession: SparkSession, 84 | file: FileStatus, 85 | recordType: String): Boolean = { 86 | (file.getLen > 0) && (getSchemaFromFile(sparkSession, file, recordType).length > 0) 87 | } 88 | 89 | override def prepareWrite( 90 | sparkSession: SparkSession, 91 | job: Job, 92 | options: Map[String, String], 93 | dataSchema: StructType): OutputWriterFactory = { 94 | val conf = job.getConfiguration 95 | val codec = options.getOrElse("codec", "") 96 | if (!codec.isEmpty) { 97 | conf.set("mapreduce.output.fileoutputformat.compress", "true") 98 | conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) 99 | conf.set("mapreduce.output.fileoutputformat.compress.codec", codec) 100 | conf.set("mapreduce.map.output.compress", "true") 101 | conf.set("mapreduce.map.output.compress.codec", codec) 102 | } 103 | 104 | new OutputWriterFactory { 105 | override def newInstance( 106 | path: String, 107 | dataSchema: StructType, 108 | context: TaskAttemptContext): OutputWriter = { 109 | new TFRecordOutputWriter(path, options, dataSchema, context) 110 | } 111 | 112 | override def getFileExtension(context: TaskAttemptContext): String = { 113 | ".tfrecord" + CodecStreams.getCompressionExtension(context) 114 | } 115 | } 116 | } 117 | 118 | override def buildReader( 119 | sparkSession: SparkSession, 120 | dataSchema: StructType, 121 | partitionSchema: StructType, 122 | requiredSchema: StructType, 123 | filters: Seq[Filter], 124 | options: Map[String, String], 125 | hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { 126 | val broadcastedHadoopConf = 127 | sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) 128 | 129 | (file: PartitionedFile) => { 130 | TFRecordFileReader.readFile( 131 | broadcastedHadoopConf.value.value, 132 | options, 133 | file, 134 | requiredSchema) 135 | } 136 | } 137 | 138 | override def toString: String = "TFRECORD" 139 | 140 | override def hashCode(): Int = getClass.hashCode() 141 | 142 | override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] 143 | } 144 | 145 | private [tfrecord] class SerializableConfiguration(@transient var value: Configuration) 146 | extends Serializable with KryoSerializable { 147 | @transient private[tfrecord] lazy val log = LoggerFactory.getLogger(getClass) 148 | 149 | private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException { 150 | out.defaultWriteObject() 151 | value.write(out) 152 | } 153 | 154 | private def readObject(in: ObjectInputStream): Unit = tryOrIOException { 155 | value = new Configuration(false) 156 | value.readFields(in) 157 | } 158 | 159 | private def tryOrIOException[T](block: => T): T = { 160 | try { 161 | block 162 | } catch { 163 | case e: IOException => 164 | log.error("Exception encountered", e) 165 | throw e 166 | case NonFatal(e) => 167 | log.error("Exception encountered", e) 168 | throw new IOException(e) 169 | } 170 | } 171 | 172 | def write(kryo: Kryo, out: Output): Unit = { 173 | val dos = new DataOutputStream(out) 174 | value.write(dos) 175 | dos.flush() 176 | } 177 | 178 | def read(kryo: Kryo, in: Input): Unit = { 179 | value = new Configuration(false) 180 | value.readFields(new DataInputStream(in)) 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala: -------------------------------------------------------------------------------- 1 | package 
com.linkedin.spark.datasources.tfrecord 2 | 3 | import org.apache.spark.sql.catalyst.InternalRow 4 | import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow, UnsafeArrayData} 5 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} 6 | import org.apache.spark.sql.types.{DecimalType, DoubleType, _} 7 | import org.apache.spark.unsafe.types.UTF8String 8 | import org.tensorflow.example._ 9 | 10 | import scala.jdk.CollectionConverters._ 11 | 12 | /** 13 | * Creates a TFRecord deserializer to deserialize Tfrecord example or sequenceExample to Spark InternalRow 14 | */ 15 | class TFRecordDeserializer(dataSchema: StructType) { 16 | 17 | def deserializeByteArray(byteArray: Array[Byte]): InternalRow = { 18 | InternalRow(byteArray) 19 | } 20 | 21 | def deserializeExample(example: Example): InternalRow = { 22 | val featureMap = example.getFeatures.getFeatureMap.asScala 23 | val resultRow = new SpecificInternalRow(dataSchema.map(_.dataType)) 24 | dataSchema.zipWithIndex.foreach { 25 | case (field, index) => 26 | val feature = featureMap.get(field.name) 27 | feature match { 28 | case Some(ft) => 29 | val featureWriter = newFeatureWriter(field.dataType, new RowUpdater(resultRow)) 30 | featureWriter(index, ft) 31 | case None => if (!field.nullable) throw new NullPointerException(s"Field ${field.name} does not allow null values") 32 | } 33 | } 34 | resultRow 35 | } 36 | 37 | def deserializeSequenceExample(sequenceExample: SequenceExample): InternalRow = { 38 | 39 | val featureMap = sequenceExample.getContext.getFeatureMap.asScala 40 | val featureListMap = sequenceExample.getFeatureLists.getFeatureListMap.asScala 41 | val resultRow = new SpecificInternalRow(dataSchema.map(_.dataType)) 42 | 43 | dataSchema.zipWithIndex.foreach { 44 | case (field, index) => 45 | val feature = featureMap.get(field.name) 46 | feature match { 47 | case Some(ft) => 48 | val featureWriter = newFeatureWriter(field.dataType, new RowUpdater(resultRow)) 49 | featureWriter(index, ft) 50 | case None => 51 | val featureList = featureListMap.get(field.name) 52 | featureList match { 53 | case Some(ftList) => 54 | val featureListWriter = newFeatureListWriter(field.dataType, new RowUpdater(resultRow)) 55 | featureListWriter(index, ftList) 56 | case None => if (!field.nullable) throw new NullPointerException(s"Field ${field.name} does not allow null values") 57 | } 58 | } 59 | } 60 | resultRow 61 | } 62 | 63 | private type arrayElementConverter = (SpecializedGetters, Int) => Any 64 | 65 | /** 66 | * Creates a writer to write Tfrecord Feature values to Catalyst data structure at the given ordinal. 
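   * Scalar Spark types read only the first value of the feature's Int64List/FloatList/BytesList, while ArrayType columns read the full list.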
67 | */ 68 | private def newFeatureWriter( 69 | dataType: DataType, updater: CatalystDataUpdater): (Int, Feature) => Unit = 70 | dataType match { 71 | case NullType => (ordinal, _) => 72 | updater.setNullAt(ordinal) 73 | 74 | case IntegerType => (ordinal, feature) => 75 | updater.setInt(ordinal, Int64ListFeature2SeqLong(feature).head.toInt) 76 | 77 | case LongType => (ordinal, feature) => 78 | updater.setLong(ordinal, Int64ListFeature2SeqLong(feature).head) 79 | 80 | case FloatType => (ordinal, feature) => 81 | updater.setFloat(ordinal, floatListFeature2SeqFloat(feature).head.toFloat) 82 | 83 | case DoubleType => (ordinal, feature) => 84 | updater.setDouble(ordinal, floatListFeature2SeqFloat(feature).head.toDouble) 85 | 86 | case DecimalType() => (ordinal, feature) => 87 | updater.setDecimal(ordinal, Decimal(floatListFeature2SeqFloat(feature).head.toDouble)) 88 | 89 | case StringType => (ordinal, feature) => 90 | val value = bytesListFeature2SeqString(feature).head 91 | updater.set(ordinal, UTF8String.fromString(value)) 92 | 93 | case BinaryType => (ordinal, feature) => 94 | val value = bytesListFeature2SeqArrayByte(feature).head 95 | updater.set(ordinal, value) 96 | 97 | case ArrayType(elementType, _) => (ordinal, feature) => 98 | 99 | elementType match { 100 | case IntegerType | LongType | FloatType | DoubleType | DecimalType() | StringType | BinaryType => 101 | val valueList = elementType match { 102 | case IntegerType => Int64ListFeature2SeqLong(feature).map(_.toInt) 103 | case LongType => Int64ListFeature2SeqLong(feature) 104 | case FloatType => floatListFeature2SeqFloat(feature).map(_.toFloat) 105 | case DoubleType => floatListFeature2SeqFloat(feature).map(_.toDouble) 106 | case DecimalType() => floatListFeature2SeqFloat(feature).map(x => Decimal(x.toDouble)) 107 | case StringType => bytesListFeature2SeqString(feature) 108 | case BinaryType => bytesListFeature2SeqArrayByte(feature) 109 | } 110 | val len = valueList.length 111 | val result = createArrayData(elementType, len) 112 | val elementUpdater = new ArrayDataUpdater(result) 113 | val elementConverter = newArrayElementWriter(elementType, elementUpdater) 114 | for (idx <- 0 until len) { 115 | elementConverter(idx, valueList(idx)) 116 | } 117 | updater.set(ordinal, result) 118 | 119 | case _ => throw new scala.RuntimeException(s"Cannot convert Array type to unsupported data type ${elementType}") 120 | } 121 | 122 | case _ => 123 | throw new UnsupportedOperationException(s"$dataType is not supported yet.") 124 | } 125 | 126 | /** 127 | * Creates a writer to write Tfrecord FeatureList values to Catalyst data structure at the given ordinal. 128 | */ 129 | private def newFeatureListWriter( 130 | dataType: DataType, updater: CatalystDataUpdater): (Int, FeatureList) => Unit = 131 | dataType match { 132 | case ArrayType(elementType, _) => (ordinal, featureList) => 133 | val ftList = featureList.getFeatureList.asScala 134 | val len = ftList.length 135 | val resultArray = createArrayData(elementType, len) 136 | val elementUpdater = new ArrayDataUpdater(resultArray) 137 | val elementConverter = newFeatureWriter(elementType, elementUpdater) 138 | for (idx <- 0 until len) { 139 | elementConverter(idx, ftList(idx)) 140 | } 141 | updater.set(ordinal, resultArray) 142 | case _ => throw new scala.RuntimeException(s"Cannot convert FeatureList to unsupported data type ${dataType}") 143 | } 144 | 145 | /** 146 | * Creates a writer to write Tfrecord Feature array element to Catalyst data structure at the given ordinal. 
147 | */ 148 | private def newArrayElementWriter( 149 | dataType: DataType, updater: CatalystDataUpdater): (Int, Any) => Unit = 150 | dataType match { 151 | case NullType => null 152 | 153 | case IntegerType => (ordinal, value) => 154 | updater.setInt(ordinal, value.asInstanceOf[Int]) 155 | 156 | case LongType => (ordinal, value) => 157 | updater.setLong(ordinal, value.asInstanceOf[Long]) 158 | 159 | case FloatType => (ordinal, value) => 160 | updater.setFloat(ordinal, value.asInstanceOf[Float]) 161 | 162 | case DoubleType => (ordinal, value) => 163 | updater.setDouble(ordinal, value.asInstanceOf[Double]) 164 | 165 | case DecimalType() => (ordinal, value) => 166 | updater.setDecimal(ordinal, value.asInstanceOf[Decimal]) 167 | 168 | case StringType => (ordinal, value) => 169 | updater.set(ordinal, UTF8String.fromString(value.asInstanceOf[String])) 170 | 171 | case BinaryType => (ordinal, value) => 172 | updater.set(ordinal, value.asInstanceOf[Array[Byte]]) 173 | 174 | case _ => throw new RuntimeException(s"Cannot convert array element to unsupported data type ${dataType}") 175 | } 176 | 177 | def Int64ListFeature2SeqLong(feature: Feature): Seq[Long] = { 178 | require(feature != null && feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List") 179 | try { 180 | feature.getInt64List.getValueList.asScala.toSeq.map(_.toLong) 181 | } 182 | catch { 183 | case ex: Exception => 184 | throw new RuntimeException(s"Cannot convert feature to long.", ex) 185 | } 186 | } 187 | 188 | def floatListFeature2SeqFloat(feature: Feature): Seq[java.lang.Float] = { 189 | require(feature != null && feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList") 190 | try { 191 | val array = feature.getFloatList.getValueList.asScala.toSeq 192 | array 193 | } 194 | catch { 195 | case ex: Exception => 196 | throw new RuntimeException(s"Cannot convert feature to Float.", ex) 197 | } 198 | } 199 | 200 | def bytesListFeature2SeqArrayByte(feature: Feature): Seq[Array[Byte]] = { 201 | require(feature != null && feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList") 202 | try { 203 | feature.getBytesList.getValueList.asScala.toSeq.map((byteArray) => byteArray.asScala.toArray.map(_.toByte)) 204 | } 205 | catch { 206 | case ex: Exception => 207 | throw new RuntimeException(s"Cannot convert feature to byte array.", ex) 208 | } 209 | } 210 | 211 | def bytesListFeature2SeqString(feature: Feature): Seq[String] = { 212 | require(feature != null && feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList") 213 | try { 214 | val array = feature.getBytesList.getValueList.asScala.toSeq 215 | array.map(_.toStringUtf8) 216 | } 217 | catch { 218 | case ex: Exception => 219 | throw new RuntimeException(s"Cannot convert feature to String array.", ex) 220 | } 221 | } 222 | 223 | private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { 224 | case BooleanType => new GenericArrayData(new Array[Boolean](length)) 225 | case ByteType => new GenericArrayData(new Array[Byte](length)) 226 | case ShortType => new GenericArrayData(new Array[Short](length)) 227 | case IntegerType => new GenericArrayData(new Array[Int](length)) 228 | case LongType => new GenericArrayData(new Array[Long](length)) 229 | case FloatType => new GenericArrayData(new Array[Float](length)) 230 | case DoubleType => new GenericArrayData(new Array[Double](length)) 231 | 
case _ => new GenericArrayData(new Array[Any](length)) 232 | } 233 | 234 | /** 235 | * A base interface for updating values inside catalyst data structure like `InternalRow` and 236 | * `ArrayData`. 237 | */ 238 | sealed trait CatalystDataUpdater { 239 | def set(ordinal: Int, value: Any): Unit 240 | def setNullAt(ordinal: Int): Unit = set(ordinal, null) 241 | def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) 242 | def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) 243 | def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) 244 | def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) 245 | def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) 246 | def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) 247 | def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) 248 | def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) 249 | } 250 | 251 | final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { 252 | override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) 253 | override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) 254 | override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) 255 | override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) 256 | override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) 257 | override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) 258 | override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) 259 | override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) 260 | override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) 261 | override def setDecimal(ordinal: Int, value: Decimal): Unit = 262 | row.setDecimal(ordinal, value, value.precision) 263 | } 264 | 265 | final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { 266 | override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) 267 | override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) 268 | override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) 269 | override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) 270 | override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) 271 | override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) 272 | override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) 273 | override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) 274 | override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) 275 | override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordFileReader.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.spark.datasources.tfrecord 2 | 3 | import org.apache.hadoop.conf.Configuration 4 | import org.apache.hadoop.fs.Path 5 | import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} 6 | import org.apache.hadoop.mapreduce.lib.input.FileSplit 7 | import 
org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl 8 | import org.apache.spark.sql.catalyst.InternalRow 9 | import org.apache.spark.sql.execution.datasources.PartitionedFile 10 | import org.apache.spark.TaskContext 11 | import org.apache.spark.sql.types.StructType 12 | import org.tensorflow.example.{Example, SequenceExample} 13 | import org.tensorflow.hadoop.io.TFRecordFileInputFormat 14 | 15 | object TFRecordFileReader { 16 | def readFile( 17 | conf: Configuration, 18 | options: Map[String, String], 19 | file: PartitionedFile, 20 | schema: StructType): Iterator[InternalRow] = { 21 | 22 | val recordType = options.getOrElse("recordType", "Example") 23 | 24 | val inputSplit = new FileSplit( 25 | file.toPath, 26 | file.start, 27 | file.length, 28 | // The locality is decided by `getPreferredLocations` in `FileScanRDD`. 29 | Array.empty) 30 | val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) 31 | val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) 32 | val recordReader = new TFRecordFileInputFormat().createRecordReader(inputSplit, hadoopAttemptContext) 33 | 34 | // Ensure that the reader is closed even if the task fails or doesn't consume the entire 35 | // iterator of records. 36 | Option(TaskContext.get()).foreach { taskContext => 37 | taskContext.addTaskCompletionListener[Unit]((_: TaskContext) => 38 | recordReader.close() 39 | ) 40 | } 41 | 42 | recordReader.initialize(inputSplit, hadoopAttemptContext) 43 | 44 | val deserializer = new TFRecordDeserializer(schema) 45 | 46 | new Iterator[InternalRow] { 47 | private[this] var havePair = false 48 | private[this] var finished = false 49 | override def hasNext: Boolean = { 50 | if (!finished && !havePair) { 51 | finished = !recordReader.nextKeyValue 52 | if (finished) { 53 | // Close and release the reader here; close() will also be called when the task 54 | // completes, but for tasks that read from many files, it helps to release the 55 | // resources early. 
56 | recordReader.close() 57 | } 58 | havePair = !finished 59 | } 60 | !finished 61 | } 62 | 63 | override def next(): InternalRow = { 64 | if (!hasNext) { 65 | throw new java.util.NoSuchElementException("End of stream") 66 | } 67 | havePair = false 68 | val bytesWritable = recordReader.getCurrentKey 69 | recordType match { 70 | case "ByteArray" => 71 | deserializer.deserializeByteArray(bytesWritable.getBytes) 72 | case "Example" => 73 | val example = Example.parseFrom(bytesWritable.getBytes) 74 | deserializer.deserializeExample(example) 75 | case "SequenceExample" => 76 | val sequenceExample = SequenceExample.parseFrom(bytesWritable.getBytes) 77 | deserializer.deserializeSequenceExample(sequenceExample) 78 | case _ => 79 | throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be ByteArray, Example or SequenceExample") 80 | } 81 | } 82 | } 83 | } 84 | } -------------------------------------------------------------------------------- /src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordOutputWriter.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.spark.datasources.tfrecord 2 | 3 | import java.io.DataOutputStream 4 | import org.apache.hadoop.fs.Path 5 | import org.apache.hadoop.mapreduce.TaskAttemptContext 6 | import org.apache.spark.sql.catalyst.InternalRow 7 | import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} 8 | import org.apache.spark.sql.types.StructType 9 | import org.tensorflow.hadoop.util.TFRecordWriter 10 | 11 | 12 | class TFRecordOutputWriter( 13 | val path: String, 14 | options: Map[String, String], 15 | dataSchema: StructType, 16 | context: TaskAttemptContext) 17 | extends OutputWriter{ 18 | 19 | private val outputStream = CodecStreams.createOutputStream(context, new Path(path)) 20 | private val dataOutputStream = new DataOutputStream(outputStream) 21 | private val writer = new TFRecordWriter(dataOutputStream) 22 | private val recordType = options.getOrElse("recordType", "Example") 23 | 24 | private[this] val serializer = new TFRecordSerializer(dataSchema) 25 | 26 | override def write(row: InternalRow): Unit = { 27 | val record = recordType match { 28 | case "ByteArray" => 29 | serializer.serializeByteArray(row) 30 | case "Example" => 31 | serializer.serializeExample(row).toByteArray 32 | case "SequenceExample" => 33 | serializer.serializeSequenceExample(row).toByteArray 34 | case _ => 35 | throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be Byte Array, Example or SequenceExample") 36 | } 37 | writer.write(record) 38 | } 39 | 40 | override def close(): Unit = { 41 | dataOutputStream.close() 42 | outputStream.close() 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala: -------------------------------------------------------------------------------- 1 | package com.linkedin.spark.datasources.tfrecord 2 | 3 | import org.apache.spark.sql.catalyst.InternalRow 4 | import org.apache.spark.sql.types.{DecimalType, DoubleType, _} 5 | import org.tensorflow.example._ 6 | import org.apache.spark.sql.catalyst.expressions.SpecializedGetters 7 | import com.google.protobuf.ByteString 8 | 9 | /** 10 | * Creates a TFRecord serializer to serialize Spark InternalRow to Tfrecord example or sequenceExample 11 | */ 12 | class TFRecordSerializer(dataSchema: StructType) { 13 | 14 | private val 
featureConverters = dataSchema.map(_.dataType).map(newFeatureConverter(_)).toArray 15 | 16 | def serializeByteArray(row: InternalRow): Array[Byte] = { 17 | row.getBinary(0) 18 | } 19 | 20 | def serializeExample(row: InternalRow): Example = { 21 | val features = Features.newBuilder() 22 | val example = Example.newBuilder() 23 | for (idx <- featureConverters.indices) { 24 | val structField = dataSchema(idx) 25 | if (!row.isNullAt(idx)) { 26 | val feature = featureConverters(idx)(row, idx).asInstanceOf[Feature] 27 | features.putFeature(structField.name, feature) 28 | } 29 | else if (!dataSchema(idx).nullable) { 30 | throw new NullPointerException(s"${structField.name} does not allow null values") 31 | } 32 | } 33 | example.setFeatures(features.build()) 34 | example.build() 35 | } 36 | 37 | def serializeSequenceExample(row: InternalRow): SequenceExample = { 38 | val features = Features.newBuilder() 39 | val featureLists = FeatureLists.newBuilder() 40 | val sequenceExample = SequenceExample.newBuilder() 41 | for (idx <- featureConverters.indices) { 42 | val structField = dataSchema(idx) 43 | if (!row.isNullAt(idx)) { 44 | structField.dataType match { 45 | case ArrayType(ArrayType(_, _), _) => 46 | val featureList = featureConverters(idx)(row, idx).asInstanceOf[FeatureList] 47 | featureLists.putFeatureList(structField.name, featureList) 48 | case _ => 49 | val feature = featureConverters(idx)(row, idx).asInstanceOf[Feature] 50 | features.putFeature(structField.name, feature) 51 | } 52 | } 53 | else if (!dataSchema(idx).nullable) { 54 | throw new NullPointerException(s"${structField.name} does not allow null values") 55 | } 56 | } 57 | sequenceExample.setContext(features.build()) 58 | sequenceExample.setFeatureLists(featureLists.build()) 59 | sequenceExample.build() 60 | } 61 | 62 | private type FeatureConverter = (SpecializedGetters, Int) => Any 63 | private type arrayElementConverter = (SpecializedGetters, Int) => Any 64 | 65 | /** 66 | * Creates a converter to convert Catalyst data at the given ordinal to TFrecord Feature. 
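   * Scalar and one-dimensional array types map to a single Feature (Int64List, FloatList or BytesList); two-dimensional arrays map to a FeatureList, which is used for SequenceExample output.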
67 | */ 68 | private def newFeatureConverter( 69 | dataType: DataType): FeatureConverter = dataType match { 70 | case NullType => (getter, ordinal) => null 71 | 72 | case IntegerType => (getter, ordinal) => 73 | val value = getter.getInt(ordinal) 74 | Int64ListFeature(Seq(value.toLong)) 75 | 76 | case LongType => (getter, ordinal) => 77 | val value = getter.getLong(ordinal) 78 | Int64ListFeature(Seq(value)) 79 | 80 | case FloatType => (getter, ordinal) => 81 | val value = getter.getFloat(ordinal) 82 | floatListFeature(Seq(value)) 83 | 84 | case DoubleType => (getter, ordinal) => 85 | val value = getter.getDouble(ordinal) 86 | floatListFeature(Seq(value.toFloat)) 87 | 88 | case DecimalType() => (getter, ordinal) => 89 | val value = getter.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale) 90 | floatListFeature(Seq(value.toFloat)) 91 | 92 | case StringType => (getter, ordinal) => 93 | val value = getter.getUTF8String(ordinal).getBytes 94 | bytesListFeature(Seq(value)) 95 | 96 | case BinaryType => (getter, ordinal) => 97 | val value = getter.getBinary(ordinal) 98 | bytesListFeature(Seq(value)) 99 | 100 | case ArrayType(elementType, containsNull) => (getter, ordinal) => 101 | val arrayData = getter.getArray(ordinal) 102 | val featureOrFeatureList = elementType match { 103 | case IntegerType => 104 | Int64ListFeature(arrayData.toIntArray().toSeq.map(_.toLong)) 105 | 106 | case LongType => 107 | Int64ListFeature(arrayData.toLongArray().toSeq) 108 | 109 | case FloatType => 110 | floatListFeature(arrayData.toFloatArray().toSeq) 111 | 112 | case DoubleType => 113 | floatListFeature(arrayData.toDoubleArray().toSeq.map(_.toFloat)) 114 | 115 | case DecimalType() => 116 | val elementConverter = arrayElementConverter(elementType) 117 | val len = arrayData.numElements() 118 | val result = new Array[Decimal](len) 119 | for (idx <- 0 until len) { 120 | if (containsNull && arrayData.isNullAt(idx)) { 121 | result(idx) = null 122 | } else result(idx) = elementConverter(arrayData, idx).asInstanceOf[Decimal] 123 | } 124 | floatListFeature(result.toSeq.map(_.toFloat)) 125 | 126 | case StringType | BinaryType => 127 | val elementConverter = arrayElementConverter(elementType) 128 | val len = arrayData.numElements() 129 | val result = new Array[Array[Byte]](len) 130 | for (idx <- 0 until len) { 131 | if (containsNull && arrayData.isNullAt(idx)) { 132 | result(idx) = null 133 | } else result(idx) = elementConverter(arrayData, idx).asInstanceOf[Array[Byte]] 134 | } 135 | bytesListFeature(result.toSeq) 136 | 137 | // 2-dimensional array to TensorFlow "FeatureList" 138 | case ArrayType(_, _) => 139 | val elementConverter = newFeatureConverter(elementType) 140 | val featureList = FeatureList.newBuilder() 141 | for (idx <- 0 until arrayData.numElements()) { 142 | val feature = elementConverter(arrayData, idx).asInstanceOf[Feature] 143 | featureList.addFeature(feature) 144 | } 145 | featureList.build() 146 | 147 | case _ => throw new RuntimeException(s"Array element data type ${dataType} is unsupported") 148 | } 149 | featureOrFeatureList 150 | 151 | case _ => throw new RuntimeException(s"Cannot convert field to unsupported data type ${dataType}") 152 | } 153 | 154 | private def arrayElementConverter( 155 | dataType: DataType): arrayElementConverter = dataType match { 156 | case NullType => null 157 | 158 | case IntegerType => (getter, ordinal) => 159 | getter.getInt(ordinal) 160 | 161 | case LongType => (getter, ordinal) => 162 | getter.getLong(ordinal) 163 | 164 | case FloatType => 
(getter, ordinal) => 165 | getter.getFloat(ordinal) 166 | 167 | case DoubleType => (getter, ordinal) => 168 | getter.getDouble(ordinal) 169 | 170 | case DecimalType() => (getter, ordinal) => 171 | getter.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale) 172 | 173 | case StringType => (getter, ordinal) => 174 | getter.getUTF8String(ordinal).getBytes 175 | 176 | case BinaryType => (getter, ordinal) => 177 | getter.getBinary(ordinal) 178 | 179 | case _ => throw new RuntimeException(s"Cannot convert field to unsupported data type ${dataType}") 180 | } 181 | 182 | def Int64ListFeature(value: Seq[Long]): Feature = { 183 | val intListBuilder = Int64List.newBuilder() 184 | value.foreach {x => 185 | intListBuilder.addValue(x) 186 | } 187 | val int64List = intListBuilder.build() 188 | Feature.newBuilder().setInt64List(int64List).build() 189 | } 190 | 191 | def floatListFeature(value: Seq[Float]): Feature = { 192 | val floatListBuilder = FloatList.newBuilder() 193 | value.foreach {x => 194 | floatListBuilder.addValue(x) 195 | } 196 | val floatList = floatListBuilder.build() 197 | Feature.newBuilder().setFloatList(floatList).build() 198 | } 199 | 200 | def bytesListFeature(value: Seq[Array[Byte]]): Feature = { 201 | val bytesListBuilder = BytesList.newBuilder() 202 | value.foreach {x => 203 | bytesListBuilder.addValue(ByteString.copyFrom(x)) 204 | } 205 | val bytesList = bytesListBuilder.build() 206 | Feature.newBuilder().setBytesList(bytesList).build() 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.linkedin.spark.datasources.tfrecord 17 | 18 | import org.apache.spark.rdd.RDD 19 | import org.apache.spark.sql.types._ 20 | import org.tensorflow.example.{FeatureList, SequenceExample, Example, Feature} 21 | 22 | import scala.collection.mutable 23 | import scala.jdk.CollectionConverters._ 24 | import scala.reflect.runtime.universe._ 25 | 26 | object TensorFlowInferSchema { 27 | 28 | /** 29 | * Similar to the JSON schema inference. 30 | * [[org.apache.spark.sql.execution.datasources.json.InferSchema]] 31 | * 1. Infer type of each row 32 | * 2. Merge row types to find common type 33 | * 3. 
Replace any null types with string type 34 | */ 35 | def apply[T : TypeTag](rdd: RDD[T]): StructType = { 36 | val startType: mutable.Map[String, DataType] = mutable.Map.empty[String, DataType] 37 | 38 | val rootTypes: mutable.Map[String, DataType] = typeOf[T] match { 39 | case t if t =:= typeOf[Example] => { 40 | rdd.asInstanceOf[RDD[Example]].aggregate(startType)(inferExampleRowType, mergeFieldTypes) 41 | } 42 | case t if t =:= typeOf[SequenceExample] => { 43 | rdd.asInstanceOf[RDD[SequenceExample]].aggregate(startType)(inferSequenceExampleRowType, mergeFieldTypes) 44 | } 45 | case _ => throw new IllegalArgumentException(s"Unsupported recordType: recordType can be Example or SequenceExample") 46 | } 47 | 48 | val columnsList = rootTypes.map { 49 | case (featureName, featureType) => 50 | if (featureType == null) { 51 | StructField(featureName, NullType) 52 | } 53 | else { 54 | StructField(featureName, featureType) 55 | } 56 | } 57 | StructType(columnsList.toSeq) 58 | } 59 | 60 | def getSchemaForByteArray() : StructType = { 61 | StructType(Array( 62 | StructField("byteArray", BinaryType) 63 | )) 64 | } 65 | 66 | private def inferSequenceExampleRowType(schemaSoFar: mutable.Map[String, DataType], 67 | next: SequenceExample): mutable.Map[String, DataType] = { 68 | val featureMap = next.getContext.getFeatureMap.asScala 69 | val updatedSchema = inferFeatureTypes(schemaSoFar, featureMap) 70 | 71 | val featureListMap = next.getFeatureLists.getFeatureListMap.asScala 72 | inferFeatureListTypes(updatedSchema, featureListMap) 73 | } 74 | 75 | private def inferExampleRowType(schemaSoFar: mutable.Map[String, DataType], 76 | next: Example): mutable.Map[String, DataType] = { 77 | val featureMap = next.getFeatures.getFeatureMap.asScala 78 | inferFeatureTypes(schemaSoFar, featureMap) 79 | } 80 | 81 | private def inferFeatureTypes(schemaSoFar: mutable.Map[String, DataType], 82 | featureMap: mutable.Map[String, Feature]): mutable.Map[String, DataType] = { 83 | featureMap.foreach { 84 | case (featureName, feature) => { 85 | val currentType = inferField(feature) 86 | if (schemaSoFar.contains(featureName)) { 87 | val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType) 88 | schemaSoFar(featureName) = updatedType.orNull 89 | } 90 | else { 91 | schemaSoFar += (featureName -> currentType) 92 | } 93 | } 94 | } 95 | schemaSoFar 96 | } 97 | 98 | def inferFeatureListTypes(schemaSoFar: mutable.Map[String, DataType], 99 | featureListMap: mutable.Map[String, FeatureList]): mutable.Map[String, DataType] = { 100 | featureListMap.foreach { 101 | case (featureName, featureList) => { 102 | val featureType = featureList.getFeatureList.asScala.map(f => inferField(f)) 103 | .reduceLeft((a, b) => findTightestCommonType(a, b).orNull) 104 | val currentType = featureType match { 105 | case ArrayType(_, _) => ArrayType(featureType) 106 | case _ => ArrayType(ArrayType(featureType)) 107 | } 108 | if (schemaSoFar.contains(featureName)) { 109 | val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType) 110 | schemaSoFar(featureName) = updatedType.orNull 111 | } 112 | else { 113 | schemaSoFar += (featureName -> currentType) 114 | } 115 | } 116 | } 117 | schemaSoFar 118 | } 119 | 120 | private def mergeFieldTypes(first: mutable.Map[String, DataType], 121 | second: mutable.Map[String, DataType]): mutable.Map[String, DataType] = { 122 | //Merge two maps and do the comparison. 
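// For example (illustrative): merging {"a" -> LongType} with {"a" -> FloatType, "b" -> StringType} yields
// {"a" -> FloatType, "b" -> StringType}, since FloatType outranks LongType in numeric precedence and a key
// present on only one side keeps the type found there (a null entry merges to the non-null type).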
123 | val mutMap = collection.mutable.Map[String, DataType]((first.keySet ++ second.keySet) 124 | .map(key => (key, findTightestCommonType(first.getOrElse(key, null), second.getOrElse(key, null)).get)) 125 | .toSeq: _*) 126 | mutMap 127 | } 128 | 129 | /** 130 | * Infer Feature datatype based on field number 131 | */ 132 | private def inferField(feature: Feature): DataType = { 133 | feature.getKindCase.getNumber match { 134 | case Feature.BYTES_LIST_FIELD_NUMBER => { 135 | parseBytesList(feature) 136 | } 137 | case Feature.INT64_LIST_FIELD_NUMBER => { 138 | parseInt64List(feature) 139 | } 140 | case Feature.FLOAT_LIST_FIELD_NUMBER => { 141 | parseFloatList(feature) 142 | } 143 | case _ => throw new RuntimeException("unsupported type ...") 144 | } 145 | } 146 | 147 | private def parseBytesList(feature: Feature): DataType = { 148 | val length = feature.getBytesList.getValueCount 149 | 150 | if (length == 0) { 151 | null 152 | } 153 | else if (length > 1) { 154 | ArrayType(StringType) 155 | } 156 | else { 157 | StringType 158 | } 159 | } 160 | 161 | private def parseInt64List(feature: Feature): DataType = { 162 | val int64List = feature.getInt64List.getValueList.asScala.toArray 163 | val length = int64List.length 164 | 165 | if (length == 0) { 166 | null 167 | } 168 | else if (length > 1) { 169 | ArrayType(LongType) 170 | } 171 | else { 172 | LongType 173 | } 174 | } 175 | 176 | private def parseFloatList(feature: Feature): DataType = { 177 | val floatList = feature.getFloatList.getValueList.asScala.toArray 178 | val length = floatList.length 179 | if (length == 0) { 180 | null 181 | } 182 | else if (length > 1) { 183 | ArrayType(FloatType) 184 | } 185 | else { 186 | FloatType 187 | } 188 | } 189 | 190 | /** 191 | * Copied from internal Spark api 192 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] 193 | */ 194 | private def getNumericPrecedence(dataType: DataType): Int = { 195 | dataType match { 196 | case LongType => 1 197 | case FloatType => 2 198 | case StringType => 3 199 | case ArrayType(LongType, _) => 4 200 | case ArrayType(FloatType, _) => 5 201 | case ArrayType(StringType, _) => 6 202 | case ArrayType(ArrayType(LongType, _), _) => 7 203 | case ArrayType(ArrayType(FloatType, _), _) => 8 204 | case ArrayType(ArrayType(StringType, _), _) => 9 205 | case _ => throw new RuntimeException("Unable to get the precedence for given datatype...") 206 | } 207 | } 208 | 209 | /** 210 | * Copied from internal Spark api 211 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] 212 | */ 213 | private def findTightestCommonType(tt1: DataType, tt2: DataType) : Option[DataType] = { 214 | val currType = (tt1, tt2) match { 215 | case (t1, t2) if t1 == t2 => Some(t1) 216 | case (null, t2) => Some(t2) 217 | case (t1, null) => Some(t1) 218 | 219 | // Promote types based on numeric precedence 220 | case (t1, t2) => 221 | val t1Precedence = getNumericPrecedence(t1) 222 | val t2Precedence = getNumericPrecedence(t2) 223 | val newType = if (t1Precedence > t2Precedence) t1 else t2 224 | Some(newType) 225 | case _ => None 226 | } 227 | currType 228 | } 229 | } 230 | 231 | -------------------------------------------------------------------------------- /src/test/scala/com/linkedin/spark/datasources/tfrecord/InferSchemaSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.linkedin.spark.datasources.tfrecord 17 | 18 | import org.apache.spark.rdd.RDD 19 | import org.apache.spark.sql.types._ 20 | import org.tensorflow.example._ 21 | import com.google.protobuf.ByteString 22 | 23 | class InferSchemaSuite extends SharedSparkSessionSuite { 24 | 25 | val longFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(Int.MaxValue + 10L)).build() 26 | val floatFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(10.0F).build()).build() 27 | val strFeature = Feature.newBuilder().setBytesList( 28 | BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes))).build() 29 | 30 | val longList = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(-2L).addValue(20L).build()).build() 31 | val floatList = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(2.5F).addValue(7F).build()).build() 32 | val strList = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)) 33 | .addValue(ByteString.copyFrom("r2".getBytes)).build()).build() 34 | 35 | val emptyFloatList = Feature.newBuilder().setFloatList(FloatList.newBuilder().build()).build() 36 | 37 | "InferSchema" should { 38 | 39 | "Infer schema from Example records" in { 40 | //Build example1 41 | val features1 = Features.newBuilder() 42 | .putFeature("LongFeature", longFeature) 43 | .putFeature("FloatFeature", floatFeature) 44 | .putFeature("StrFeature", strFeature) 45 | .putFeature("LongList", longFeature) 46 | .putFeature("FloatList", floatFeature) 47 | .putFeature("StrList", strFeature) 48 | .putFeature("MixedTypeList", longList) 49 | .build() 50 | val example1 = Example.newBuilder() 51 | .setFeatures(features1) 52 | .build() 53 | 54 | //Example2 contains subset of features in example1 to test behavior with missing features 55 | val features2 = Features.newBuilder() 56 | .putFeature("StrFeature", strFeature) 57 | .putFeature("LongList", longList) 58 | .putFeature("FloatList", floatList) 59 | .putFeature("StrList", strList) 60 | .putFeature("MixedTypeList", floatList) 61 | .build() 62 | val example2 = Example.newBuilder() 63 | .setFeatures(features2) 64 | .build() 65 | 66 | val exampleRdd: RDD[Example] = spark.sparkContext.parallelize(List(example1, example2)) 67 | val inferredSchema = TensorFlowInferSchema(exampleRdd) 68 | 69 | //Verify each TensorFlow Datatype is inferred as one of our Datatype 70 | assert(inferredSchema.fields.length == 7) 71 | val schemaMap = inferredSchema.map(f => (f.name, f.dataType)).toMap 72 | assert(schemaMap("LongFeature") === LongType) 73 | assert(schemaMap("FloatFeature") === FloatType) 74 | assert(schemaMap("StrFeature") === StringType) 75 | assert(schemaMap("LongList") === ArrayType(LongType)) 76 | assert(schemaMap("FloatList") === ArrayType(FloatType)) 77 | assert(schemaMap("StrList") === ArrayType(StringType)) 78 | assert(schemaMap("MixedTypeList") === 
ArrayType(FloatType)) 79 | } 80 | 81 | "Infer schema from SequenceExample records" in { 82 | 83 | //Build sequence example1 84 | val features1 = Features.newBuilder() 85 | .putFeature("FloatFeature", floatFeature) 86 | 87 | val longFeatureList1 = FeatureList.newBuilder().addFeature(longFeature).addFeature(longList).build() 88 | val floatFeatureList1 = FeatureList.newBuilder().addFeature(floatFeature).addFeature(floatList).build() 89 | val strFeatureList1 = FeatureList.newBuilder().addFeature(strFeature).build() 90 | val mixedFeatureList1 = FeatureList.newBuilder().addFeature(floatFeature).addFeature(strList).build() 91 | 92 | val featureLists1 = FeatureLists.newBuilder() 93 | .putFeatureList("LongListOfLists", longFeatureList1) 94 | .putFeatureList("FloatListOfLists", floatFeatureList1) 95 | .putFeatureList("StringListOfLists", strFeatureList1) 96 | .putFeatureList("MixedListOfLists", mixedFeatureList1) 97 | .build() 98 | 99 | val seqExample1 = SequenceExample.newBuilder() 100 | .setContext(features1) 101 | .setFeatureLists(featureLists1) 102 | .build() 103 | 104 | //SequenceExample2 contains subset of features in example1 to test behavior with missing features 105 | val longFeatureList2 = FeatureList.newBuilder().addFeature(longList).build() 106 | val floatFeatureList2 = FeatureList.newBuilder().addFeature(floatFeature).build() 107 | val strFeatureList2 = FeatureList.newBuilder().addFeature(strFeature).build() //test both string lists of length=1 108 | val mixedFeatureList2 = FeatureList.newBuilder().addFeature(longFeature).addFeature(strFeature).build() 109 | 110 | val featureLists2 = FeatureLists.newBuilder() 111 | .putFeatureList("LongListOfLists", longFeatureList2) 112 | .putFeatureList("FloatListOfLists", floatFeatureList2) 113 | .putFeatureList("StringListOfLists", strFeatureList2) 114 | .putFeatureList("MixedListOfLists", mixedFeatureList2) 115 | .build() 116 | 117 | val seqExample2 = SequenceExample.newBuilder() 118 | .setFeatureLists(featureLists2) 119 | .build() 120 | 121 | val seqExampleRdd: RDD[SequenceExample] = spark.sparkContext.parallelize(List(seqExample1, seqExample2)) 122 | val inferredSchema = TensorFlowInferSchema(seqExampleRdd) 123 | 124 | //Verify each TensorFlow Datatype is inferred as one of our Datatype 125 | assert(inferredSchema.fields.length == 5) 126 | val schemaMap = inferredSchema.map(f => (f.name, f.dataType)).toMap 127 | assert(schemaMap("FloatFeature") === FloatType) 128 | assert(schemaMap("LongListOfLists") === ArrayType(ArrayType(LongType))) 129 | assert(schemaMap("FloatListOfLists") === ArrayType(ArrayType(FloatType))) 130 | assert(schemaMap("StringListOfLists") === ArrayType(ArrayType(StringType))) 131 | assert(schemaMap("MixedListOfLists") === ArrayType(ArrayType(StringType))) 132 | } 133 | } 134 | 135 | "Throw an exception for unsupported record types" in { 136 | intercept[Exception] { 137 | val rdd: RDD[Long] = spark.sparkContext.parallelize(List(5L, 6L)) 138 | TensorFlowInferSchema(rdd) 139 | } 140 | } 141 | 142 | "Should have a nullType if there are no elements" in { 143 | //Build sequence example1 144 | val features1 = Features.newBuilder().putFeature("emptyFloatFeature", emptyFloatList) 145 | 146 | val seqExample1 = SequenceExample.newBuilder() 147 | .setContext(features1) 148 | .build() 149 | 150 | val seqExampleRdd: RDD[SequenceExample] = spark.sparkContext.parallelize(List(seqExample1)) 151 | val inferredSchema = TensorFlowInferSchema(seqExampleRdd) 152 | assert(inferredSchema.fields.length == 1) 153 | val schemaMap = inferredSchema.map(f 
=> (f.name, f.dataType)).toMap 154 | assert(schemaMap("emptyFloatFeature") === NullType) 155 | } 156 | } 157 | 158 | -------------------------------------------------------------------------------- /src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.linkedin.spark.datasources.tfrecord 17 | 18 | import java.io.File 19 | 20 | import org.apache.commons.io.FileUtils 21 | import org.apache.spark.SharedSparkSession 22 | import org.junit.{After, Before} 23 | import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} 24 | 25 | 26 | trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll 27 | 28 | class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite { 29 | val TF_SANDBOX_DIR = "tf-sandbox" 30 | val file = new File(TF_SANDBOX_DIR) 31 | 32 | @Before 33 | override def beforeAll() = { 34 | super.setUp() 35 | FileUtils.deleteQuietly(file) 36 | file.mkdirs() 37 | } 38 | 39 | @After 40 | override def afterAll() = { 41 | FileUtils.deleteQuietly(file) 42 | super.tearDown() 43 | } 44 | } 45 | 46 | -------------------------------------------------------------------------------- /src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.linkedin.spark.datasources.tfrecord 17 | 18 | import com.google.protobuf.ByteString 19 | import org.apache.spark.sql.catalyst.InternalRow 20 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} 21 | import org.apache.spark.sql.types._ 22 | import org.apache.spark.unsafe.types.UTF8String 23 | import org.scalatest.{Matchers, WordSpec} 24 | import org.tensorflow.example._ 25 | import TestingUtils._ 26 | 27 | 28 | class TFRecordDeserializerTest extends WordSpec with Matchers { 29 | val intFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(1)).build() 30 | val longFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(23L)).build() 31 | val floatFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(10.0F)).build() 32 | val doubleFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(14.0F)).build() 33 | val decimalFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(2.5F)).build() 34 | val longArrFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(-2L).addValue(7L).build()).build() 35 | val doubleArrFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(1F).addValue(2F).build()).build() 36 | val decimalArrFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(3F).addValue(5F).build()).build() 37 | val strFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()).build() 38 | val strListFeature =Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes)) 39 | .addValue(ByteString.copyFrom("r3".getBytes)).build()).build() 40 | val binaryFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r4".getBytes))).build() 41 | val binaryListFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r5".getBytes)) 42 | .addValue(ByteString.copyFrom("r6".getBytes)).build()).build() 43 | 44 | private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) 45 | 46 | "Deserialize tfrecord to spark internalRow" should { 47 | 48 | "Deserialize ByteArray to internalRow" in { 49 | val schema = StructType(Array( 50 | StructField("ByteArray", BinaryType) 51 | )) 52 | 53 | val byteArray = Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte) 54 | val expectedInternalRow = InternalRow(byteArray) 55 | val deserializer = new TFRecordDeserializer(schema) 56 | val actualInternalRow = deserializer.deserializeByteArray(byteArray) 57 | 58 | assert(actualInternalRow ~== (expectedInternalRow,schema)) 59 | } 60 | 61 | "Serialize tfrecord example to spark internalRow" in { 62 | val schema = StructType(List( 63 | StructField("IntegerLabel", IntegerType), 64 | StructField("LongLabel", LongType), 65 | StructField("FloatLabel", FloatType), 66 | StructField("DoubleLabel", DoubleType), 67 | StructField("DecimalLabel", DataTypes.createDecimalType()), 68 | StructField("LongArrayLabel", ArrayType(LongType)), 69 | StructField("DoubleArrayLabel", ArrayType(DoubleType)), 70 | StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())), 71 | StructField("StrLabel", StringType), 72 | StructField("StrArrayLabel", ArrayType(StringType)), 73 | StructField("BinaryTypeLabel", BinaryType), 74 | StructField("BinaryTypeArrayLabel", ArrayType(BinaryType)) 75 | )) 76 | 77 | val expectedInternalRow = 
InternalRow.fromSeq( 78 | Array[Any](1, 23L, 10.0F, 14.0, Decimal(2.5d), 79 | createArray(-2L,7L), 80 | createArray(1.0, 2.0), 81 | createArray(Decimal(3.0), Decimal(5.0)), 82 | UTF8String.fromString("r1"), 83 | createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")), 84 | "r4".getBytes, 85 | createArray("r5".getBytes(), "r6".getBytes()) 86 | ) 87 | ) 88 | 89 | //Build example 90 | val features = Features.newBuilder() 91 | .putFeature("IntegerLabel", intFeature) 92 | .putFeature("LongLabel", longFeature) 93 | .putFeature("FloatLabel", floatFeature) 94 | .putFeature("DoubleLabel", doubleFeature) 95 | .putFeature("DecimalLabel", decimalFeature) 96 | .putFeature("LongArrayLabel", longArrFeature) 97 | .putFeature("DoubleArrayLabel", doubleArrFeature) 98 | .putFeature("DecimalArrayLabel", decimalArrFeature) 99 | .putFeature("StrLabel", strFeature) 100 | .putFeature("StrArrayLabel", strListFeature) 101 | .putFeature("BinaryTypeLabel", binaryFeature) 102 | .putFeature("BinaryTypeArrayLabel", binaryListFeature) 103 | .build() 104 | val example = Example.newBuilder() 105 | .setFeatures(features) 106 | .build() 107 | val deserializer = new TFRecordDeserializer(schema) 108 | val actualInternalRow = deserializer.deserializeExample(example) 109 | 110 | assert(actualInternalRow ~== (expectedInternalRow,schema)) 111 | } 112 | 113 | "Serialize spark internalRow to tfrecord sequenceExample" in { 114 | 115 | val schema = StructType(List( 116 | StructField("FloatLabel", FloatType), 117 | StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), 118 | StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))), 119 | StructField("DecimalArrayOfArrayLabel", ArrayType(ArrayType(DataTypes.createDecimalType()))), 120 | StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType))), 121 | StructField("ByteArrayOfArrayLabel", ArrayType(ArrayType(BinaryType))) 122 | )) 123 | 124 | val expectedInternalRow = InternalRow.fromSeq( 125 | Array[Any](10.0F, 126 | createArray(createArray(-2L, 7L)), 127 | createArray(createArray(10.0F), createArray(1.0F, 2.0F)), 128 | createArray(createArray(Decimal(3.0), Decimal(5.0))), 129 | createArray(createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")), 130 | createArray(UTF8String.fromString("r1"))), 131 | createArray(createArray("r5".getBytes, "r6".getBytes), createArray("r4".getBytes)) 132 | ) 133 | ) 134 | 135 | //Build sequence example 136 | val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() 137 | val floatFeatureList = FeatureList.newBuilder().addFeature(floatFeature).addFeature(doubleArrFeature).build() 138 | val decimalFeatureList = FeatureList.newBuilder().addFeature(decimalArrFeature).build() 139 | val stringFeatureList = FeatureList.newBuilder().addFeature(strListFeature).addFeature(strFeature).build() 140 | val binaryFeatureList = FeatureList.newBuilder().addFeature(binaryListFeature).addFeature(binaryFeature).build() 141 | 142 | 143 | val features = Features.newBuilder() 144 | .putFeature("FloatLabel", floatFeature) 145 | 146 | val featureLists = FeatureLists.newBuilder() 147 | .putFeatureList("LongArrayOfArrayLabel", int64FeatureList) 148 | .putFeatureList("FloatArrayOfArrayLabel", floatFeatureList) 149 | .putFeatureList("DecimalArrayOfArrayLabel", decimalFeatureList) 150 | .putFeatureList("StrArrayOfArrayLabel", stringFeatureList) 151 | .putFeatureList("ByteArrayOfArrayLabel", binaryFeatureList) 152 | .build() 153 | 154 | val seqExample = SequenceExample.newBuilder() 155 | 
.setContext(features) 156 | .setFeatureLists(featureLists) 157 | .build() 158 | 159 | val deserializer = new TFRecordDeserializer(schema) 160 | val actualInternalRow = deserializer.deserializeSequenceExample(seqExample) 161 | assert(actualInternalRow ~== (expectedInternalRow, schema)) 162 | } 163 | 164 | "Throw an exception for unsupported data types" in { 165 | 166 | val features = Features.newBuilder().putFeature("MapLabel1", intFeature) 167 | val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() 168 | val featureLists = FeatureLists.newBuilder().putFeatureList("MapLabel2", int64FeatureList) 169 | 170 | intercept[RuntimeException] { 171 | val example = Example.newBuilder() 172 | .setFeatures(features) 173 | .build() 174 | val schema = StructType(List(StructField("MapLabel1", TimestampType))) 175 | val deserializer = new TFRecordDeserializer(schema) 176 | deserializer.deserializeExample(example) 177 | } 178 | 179 | intercept[RuntimeException] { 180 | val seqExample = SequenceExample.newBuilder() 181 | .setContext(features) 182 | .setFeatureLists(featureLists) 183 | .build() 184 | val schema = StructType(List(StructField("MapLabel2", TimestampType))) 185 | val deserializer = new TFRecordDeserializer(schema) 186 | deserializer.deserializeSequenceExample(seqExample) 187 | } 188 | } 189 | 190 | "Throw an exception for non-nullable data types" in { 191 | val features = Features.newBuilder().putFeature("FloatLabel", floatFeature) 192 | val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() 193 | val featureLists = FeatureLists.newBuilder().putFeatureList("LongArrayOfArrayLabel", int64FeatureList) 194 | 195 | intercept[NullPointerException] { 196 | val example = Example.newBuilder() 197 | .setFeatures(features) 198 | .build() 199 | val schema = StructType(List(StructField("MissingLabel", FloatType, nullable = false))) 200 | val deserializer = new TFRecordDeserializer(schema) 201 | deserializer.deserializeExample(example) 202 | } 203 | 204 | intercept[NullPointerException] { 205 | val seqExample = SequenceExample.newBuilder() 206 | .setContext(features) 207 | .setFeatureLists(featureLists) 208 | .build() 209 | val schema = StructType(List(StructField("MissingLabel", ArrayType(ArrayType(LongType)), nullable = false))) 210 | val deserializer = new TFRecordDeserializer(schema) 211 | deserializer.deserializeSequenceExample(seqExample) 212 | } 213 | } 214 | 215 | 216 | "Return null fields for nullable data types" in { 217 | val features = Features.newBuilder().putFeature("FloatLabel", floatFeature) 218 | val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() 219 | val featureLists = FeatureLists.newBuilder().putFeatureList("LongArrayOfArrayLabel", int64FeatureList) 220 | 221 | // Deserialize Example 222 | val schema1 = StructType(List( 223 | StructField("FloatLabel", FloatType), 224 | StructField("MissingLabel", FloatType, nullable = true)) 225 | ) 226 | val expectedInternalRow1 = InternalRow.fromSeq( 227 | Array[Any](10.0F, null) 228 | ) 229 | val example = Example.newBuilder() 230 | .setFeatures(features) 231 | .build() 232 | val deserializer1 = new TFRecordDeserializer(schema1) 233 | val actualInternalRow1 = deserializer1.deserializeExample(example) 234 | assert(actualInternalRow1 ~== (expectedInternalRow1, schema1)) 235 | 236 | // Deserialize SequenceExample 237 | val schema2 = StructType(List( 238 | StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), 239 | StructField("MissingLabel", 
ArrayType(ArrayType(LongType)), nullable = true)) 240 | ) 241 | val expectedInternalRow2 = InternalRow.fromSeq( 242 | Array[Any]( 243 | createArray(createArray(-2L, 7L)), null) 244 | ) 245 | val seqExample = SequenceExample.newBuilder() 246 | .setContext(features) 247 | .setFeatureLists(featureLists) 248 | .build() 249 | val deserializer2 = new TFRecordDeserializer(schema2) 250 | val actualInternalRow2 = deserializer2.deserializeSequenceExample(seqExample) 251 | assert(actualInternalRow2 ~== (expectedInternalRow2, schema2)) 252 | 253 | } 254 | 255 | val schema = StructType(Array( 256 | StructField("LongLabel", LongType)) 257 | ) 258 | val deserializer = new TFRecordDeserializer(schema) 259 | 260 | "Test Int64ListFeature2SeqLong" in { 261 | val int64List = Int64List.newBuilder().addValue(5L).build() 262 | val intFeature = Feature.newBuilder().setInt64List(int64List).build() 263 | assert(deserializer.Int64ListFeature2SeqLong(intFeature).head === 5L) 264 | 265 | // Throw exception if type doesn't match 266 | intercept[RuntimeException] { 267 | val floatList = FloatList.newBuilder().addValue(2.5F).build() 268 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build() 269 | deserializer.Int64ListFeature2SeqLong(floatFeature) 270 | } 271 | } 272 | 273 | "Test floatListFeature2SeqFloat" in { 274 | val floatList = FloatList.newBuilder().addValue(2.5F).build() 275 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build() 276 | assert(deserializer.floatListFeature2SeqFloat(floatFeature).head === 2.5F) 277 | 278 | // Throw exception if type doesn't match 279 | intercept[RuntimeException] { 280 | val int64List = Int64List.newBuilder().addValue(5L).build() 281 | val intFeature = Feature.newBuilder().setInt64List(int64List).build() 282 | deserializer.floatListFeature2SeqFloat(intFeature) 283 | } 284 | } 285 | 286 | "Test bytesListFeature2SeqArrayByte" in { 287 | val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build() 288 | val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() 289 | assert(deserializer.bytesListFeature2SeqArrayByte(bytesFeature).head.sameElements("str-input".getBytes)) 290 | 291 | // Throw exception if type doesn't match 292 | intercept[RuntimeException] { 293 | val int64List = Int64List.newBuilder().addValue(5L).build() 294 | val intFeature = Feature.newBuilder().setInt64List(int64List).build() 295 | deserializer.bytesListFeature2SeqArrayByte(intFeature) 296 | } 297 | } 298 | 299 | "Test bytesListFeature2SeqString" in { 300 | val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("alice".getBytes)) 301 | .addValue(ByteString.copyFrom("bob".getBytes)).build() 302 | val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() 303 | assert(deserializer.bytesListFeature2SeqString(bytesFeature) === Seq("alice", "bob")) 304 | 305 | // Throw exception if type doesn't match 306 | intercept[RuntimeException] { 307 | val int64List = Int64List.newBuilder().addValue(5L).build() 308 | val intFeature = Feature.newBuilder().setInt64List(int64List).build() 309 | deserializer.bytesListFeature2SeqArrayByte(intFeature) 310 | } 311 | } 312 | 313 | "Test deserialize rows with different features should not inherit features from previous rows" in { 314 | // given features 315 | val floatFeatures = Features.newBuilder().putFeature("FloatLabel", floatFeature) 316 | val int64Features = Features.newBuilder().putFeature("IntLabel", intFeature) 317 | 318 | // Define common deserializer and schema 319 
| val schema = StructType(List( 320 | StructField("FloatLabel", FloatType), 321 | StructField("IntLabel", IntegerType), 322 | StructField("MissingLabel", FloatType, nullable = true)) 323 | ) 324 | val deserializer = new TFRecordDeserializer(schema) 325 | 326 | // given rows with different features - row 1 has only FloatLabel feature 327 | val expectedInternalRow1 = InternalRow.fromSeq( 328 | Array[Any](10.0F, null, null) 329 | ) 330 | val example1 = Example.newBuilder() 331 | .setFeatures(floatFeatures) 332 | .build() 333 | val actualInternalRow1 = deserializer.deserializeExample(example1) 334 | assert(actualInternalRow1 ~== (expectedInternalRow1, schema)) 335 | 336 | // .. and the second row has only IntLabel feature 337 | val expectedInternalRow2 = InternalRow.fromSeq( 338 | Array[Any](null, 1, null) 339 | ) 340 | val example2 = Example.newBuilder() 341 | .setFeatures(int64Features) 342 | .build() 343 | val actualInternalRow2 = deserializer.deserializeExample(example2) 344 | assert(actualInternalRow2 ~== (expectedInternalRow2, schema)) 345 | 346 | } 347 | } 348 | } 349 | -------------------------------------------------------------------------------- /src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordIOSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.linkedin.spark.datasources.tfrecord 17 | 18 | import org.apache.hadoop.fs.Path 19 | import org.apache.spark.sql.catalyst.expressions.GenericRow 20 | import org.apache.spark.sql.types._ 21 | import org.apache.spark.sql.{DataFrame, Row, SaveMode} 22 | 23 | import TestingUtils._ 24 | 25 | class TFRecordIOSuite extends SharedSparkSessionSuite { 26 | 27 | val exampleSchema = StructType(List( 28 | StructField("id", IntegerType), 29 | StructField("IntegerLabel", IntegerType), 30 | StructField("LongLabel", LongType), 31 | StructField("FloatLabel", FloatType), 32 | StructField("DoubleLabel", DoubleType), 33 | StructField("DecimalLabel", DataTypes.createDecimalType()), 34 | StructField("StrLabel", StringType), 35 | StructField("BinaryLabel", BinaryType), 36 | StructField("IntegerArrayLabel", ArrayType(IntegerType)), 37 | StructField("LongArrayLabel", ArrayType(LongType)), 38 | StructField("FloatArrayLabel", ArrayType(FloatType)), 39 | StructField("DoubleArrayLabel", ArrayType(DoubleType, true)), 40 | StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())), 41 | StructField("StrArrayLabel", ArrayType(StringType, true)), 42 | StructField("BinaryArrayLabel", ArrayType(BinaryType), true)) 43 | ) 44 | 45 | val exampleTestRows: Array[Row] = Array( 46 | new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, Decimal(1.1), "r1", Array[Byte](0xff.toByte, 0xf0.toByte), 47 | Seq(1, 2), 48 | Seq(11L, 12L), 49 | Seq(1.2F, 2.1F), 50 | Seq(1.1, 2.2), 51 | Seq(Decimal(1.1), Decimal(2.2)), 52 | Seq("str1", "str2"), 53 | Seq(Array[Byte](0xfa.toByte, 0xfb.toByte), Array[Byte](0xfa.toByte)))), 54 | new GenericRow(Array[Any](11, 1, 24L, 11.0F, 15.0, Decimal(2.1), "r2", Array[Byte](0xfa.toByte, 0xfb.toByte), 55 | Seq(3, 4), 56 | Seq(110L, 120L), 57 | Seq(1.2F, 2.1F), 58 | Seq(1.1, 2.2), 59 | Seq(Decimal(2.1), Decimal(3.2)), 60 | Seq("str3", "str4"), 61 | Seq(Array[Byte](0xf1.toByte, 0xf2.toByte), Array[Byte](0xfa.toByte)))), 62 | new GenericRow(Array[Any](21, 1, 23L, 10.0F, 14.0, Decimal(3.1), "r3", Array[Byte](0xfc.toByte, 0xfd.toByte), 63 | Seq(5, 6), 64 | Seq(111L, 112L), 65 | Seq(1.22F, 2.11F), 66 | Seq(11.1, 12.2), 67 | Seq(Decimal(3.1), Decimal(4.2)), 68 | Seq("str5", "str6"), 69 | Seq(Array[Byte](0xf4.toByte, 0xf2.toByte), Array[Byte](0xfa.toByte))))) 70 | 71 | val sequenceExampleTestRows: Array[Row] = Array( 72 | new GenericRow(Array[Any](23L, Seq(Seq(2, 4)), Seq(Seq(-1.1F, 0.1F)), Seq(Seq("r1", "r2")))), 73 | new GenericRow(Array[Any](24L, Seq(Seq(-1, 0)), Seq(Seq(-1.1F, 0.2F)), Seq(Seq("r3"))))) 74 | 75 | val sequenceExampleSchema = StructType(List( 76 | StructField("id",LongType), 77 | StructField("IntegerArrayOfArrayLabel", ArrayType(ArrayType(IntegerType))), 78 | StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))), 79 | StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType))) 80 | )) 81 | 82 | val byteArrayTestRows: Array[Row] = Array( 83 | new GenericRow(Array[Any](Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte))) 84 | ) 85 | 86 | private def createDataFrameForExampleTFRecord() : DataFrame = { 87 | val rdd = spark.sparkContext.parallelize(exampleTestRows) 88 | spark.createDataFrame(rdd, exampleSchema) 89 | } 90 | 91 | private def createDataFrameForSequenceExampleTFRecords() : DataFrame = { 92 | val rdd = spark.sparkContext.parallelize(sequenceExampleTestRows) 93 | spark.createDataFrame(rdd, sequenceExampleSchema) 94 | } 95 | 96 | private def createDataFrameForByteArrayTFRecords() : DataFrame = { 97 | val rdd = 
spark.sparkContext.parallelize(byteArrayTestRows) 98 | spark.createDataFrame(rdd, TensorFlowInferSchema.getSchemaForByteArray()) 99 | } 100 | 101 | private def pathExists(pathStr: String): Boolean = { 102 | val hadoopConf = spark.sparkContext.hadoopConfiguration 103 | val outputPath = new Path(pathStr) 104 | val fs = outputPath.getFileSystem(hadoopConf) 105 | val qualifiedPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) 106 | fs.exists(qualifiedPath) 107 | } 108 | 109 | private def getFileCount(pathStr: String): Long = { 110 | val hadoopConf = spark.sparkContext.hadoopConfiguration 111 | val outputPath = new Path(pathStr) 112 | val fs = outputPath.getFileSystem(hadoopConf) 113 | val qualifiedPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) 114 | fs.getContentSummary(qualifiedPath).getFileCount 115 | } 116 | 117 | "Spark tfrecord IO" should { 118 | "Test tfrecord example Read/write " in { 119 | 120 | val path = s"$TF_SANDBOX_DIR/example.tfrecord" 121 | 122 | val df: DataFrame = createDataFrameForExampleTFRecord() 123 | df.write.format("tfrecord").option("recordType", "Example").save(path) 124 | 125 | //If schema is not provided. It will automatically infer schema 126 | val importedDf: DataFrame = spark.read.format("tfrecord").option("recordType", "Example").schema(exampleSchema).load(path) 127 | 128 | val actualDf = importedDf.select("id", "IntegerLabel", "LongLabel", "FloatLabel", 129 | "DoubleLabel", "DecimalLabel", "StrLabel", "BinaryLabel", "IntegerArrayLabel", "LongArrayLabel", 130 | "FloatArrayLabel", "DoubleArrayLabel", "DecimalArrayLabel", "StrArrayLabel", "BinaryArrayLabel").sort("StrLabel") 131 | 132 | val expectedRows = df.collect() 133 | val actualRows = actualDf.collect() 134 | 135 | expectedRows.zip(actualRows).foreach { case (expected: Row, actual: Row) => 136 | assert(expected ~== actual, exampleSchema) 137 | } 138 | } 139 | 140 | "Test tfrecord partition by id" in { 141 | val output = s"$TF_SANDBOX_DIR/example-partition-by-id.tfrecord" 142 | val df: DataFrame = createDataFrameForExampleTFRecord() 143 | df.write.format("tfrecord").partitionBy("id").option("recordType", "Example").save(output) 144 | assert(pathExists(output)) 145 | val partition1Path = s"$output/id=11" 146 | val partition2Path = s"$output/id=21" 147 | assert(pathExists(partition1Path)) 148 | assert(pathExists(partition2Path)) 149 | assert(getFileCount(partition1Path) == 2) 150 | assert(getFileCount(partition2Path) == 1) 151 | } 152 | 153 | "Test tfrecord read/write SequenceExample" in { 154 | 155 | val path = s"$TF_SANDBOX_DIR/sequenceExample.tfrecord" 156 | 157 | val df: DataFrame = createDataFrameForSequenceExampleTFRecords() 158 | df.write.format("tfrecord").option("recordType", "SequenceExample").save(path) 159 | 160 | val importedDf: DataFrame = spark.read.format("tfrecord").option("recordType", "SequenceExample").schema(sequenceExampleSchema).load(path) 161 | val actualDf = importedDf.select("id", "IntegerArrayOfArrayLabel", "FloatArrayOfArrayLabel", "StrArrayOfArrayLabel").sort("id") 162 | 163 | val expectedRows = df.collect() 164 | val actualRows = actualDf.collect() 165 | 166 | assert(expectedRows === actualRows) 167 | } 168 | 169 | "Test tfrecord read/write ByteArray" in { 170 | 171 | val path = s"$TF_SANDBOX_DIR/byteArray.tfrecord" 172 | 173 | val df: DataFrame = createDataFrameForByteArrayTFRecords() 174 | df.write.format("tfrecord").option("recordType", "ByteArray").save(path) 175 | 176 | val importedDf: DataFrame = 
spark.read.format("tfrecord").option("recordType", "ByteArray").load(path) 177 | 178 | val expectedRows = df.collect() 179 | val actualRows = importedDf.collect() 180 | 181 | assert(expectedRows === actualRows) 182 | } 183 | 184 | "Test tfrecord write overwrite mode " in { 185 | 186 | val path = s"$TF_SANDBOX_DIR/example_overwrite.tfrecord" 187 | 188 | val df: DataFrame = createDataFrameForExampleTFRecord() 189 | df.write.format("tfrecord").option("recordType", "Example").save(path) 190 | 191 | df.write.format("tfrecord").mode(SaveMode.Overwrite).option("recordType", "Example").save(path) 192 | 193 | //If schema is not provided. It will automatically infer schema 194 | val importedDf: DataFrame = spark.read.format("tfrecord").option("recordType", "Example").schema(exampleSchema).load(path) 195 | 196 | val actualDf = importedDf.select("id", "IntegerLabel", "LongLabel", "FloatLabel", 197 | "DoubleLabel", "DecimalLabel", "StrLabel", "BinaryLabel", "IntegerArrayLabel", "LongArrayLabel", 198 | "FloatArrayLabel", "DoubleArrayLabel", "DecimalArrayLabel", "StrArrayLabel", "BinaryArrayLabel").sort("StrLabel") 199 | 200 | val expectedRows = df.collect() 201 | val actualRows = actualDf.collect() 202 | 203 | expectedRows.zip(actualRows).foreach { case (expected: Row, actual: Row) => 204 | assert(expected ~== actual, exampleSchema) 205 | } 206 | } 207 | 208 | "Test tfrecord write append mode" in { 209 | 210 | val path = s"$TF_SANDBOX_DIR/example_append.tfrecord" 211 | 212 | val df: DataFrame = createDataFrameForExampleTFRecord() 213 | df.write.format("tfrecord").option("recordType", "Example").save(path) 214 | df.write.format("tfrecord").mode(SaveMode.Append).option("recordType", "Example").save(path) 215 | } 216 | 217 | "Test tfrecord write ignore mode" in { 218 | 219 | val path = s"$TF_SANDBOX_DIR/example_ignore.tfrecord" 220 | 221 | val hadoopConf = spark.sparkContext.hadoopConfiguration 222 | val outputPath = new Path(path) 223 | val fs = outputPath.getFileSystem(hadoopConf) 224 | val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) 225 | 226 | val df: DataFrame = createDataFrameForExampleTFRecord() 227 | df.write.format("tfrecord").mode(SaveMode.Ignore).option("recordType", "Example").save(path) 228 | 229 | assert(fs.exists(qualifiedOutputPath)) 230 | val timestamp1 = fs.getFileStatus(qualifiedOutputPath).getModificationTime 231 | 232 | df.write.format("tfrecord").mode(SaveMode.Ignore).option("recordType", "Example").save(path) 233 | 234 | val timestamp2 = fs.getFileStatus(qualifiedOutputPath).getModificationTime 235 | 236 | assert(timestamp1 == timestamp2, "SaveMode.Ignore Error: File was overwritten. Timestamps do not match") 237 | } 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.linkedin.spark.datasources.tfrecord 17 | 18 | import org.apache.spark.sql.catalyst.InternalRow 19 | import org.tensorflow.example._ 20 | import org.apache.spark.sql.types.{StructField, _} 21 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} 22 | import org.apache.spark.unsafe.types.UTF8String 23 | import org.scalatest.{Matchers, WordSpec} 24 | 25 | import scala.collection.JavaConverters._ 26 | import TestingUtils._ 27 | 28 | class TFRecordSerializerTest extends WordSpec with Matchers { 29 | 30 | private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) 31 | 32 | "Serialize spark internalRow to tfrecord" should { 33 | 34 | "Serialize ByteArray internalRow to ByteArray" in { 35 | val serializer = new TFRecordSerializer(new StructType()) 36 | 37 | val byteArray = Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte) 38 | val internalRow = InternalRow(byteArray) 39 | val serializedByteArray = serializer.serializeByteArray(internalRow) 40 | 41 | //two byte arrays should be the same, since serialization just gets byte array from internal row 42 | assert(byteArray.length == serializedByteArray.length) 43 | assert(byteArray.sameElements(serializedByteArray)) 44 | } 45 | 46 | "Serialize decimal internalRow to tfrecord example" in { 47 | val schemaStructType = StructType(Array( 48 | StructField("DecimalLabel", DataTypes.createDecimalType()), 49 | StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())) 50 | )) 51 | val serializer = new TFRecordSerializer(schemaStructType) 52 | 53 | val decimalArray = Array(Decimal(4.0), Decimal(8.0)) 54 | val rowArray = Array[Any](Decimal(6.5), ArrayData.toArrayData(decimalArray)) 55 | val internalRow = InternalRow.fromSeq(rowArray) 56 | 57 | //Encode Sql InternalRow to TensorFlow example 58 | val example = serializer.serializeExample(internalRow) 59 | 60 | //Verify each Datatype converted to TensorFlow datatypes 61 | val featureMap = example.getFeatures.getFeatureMap.asScala 62 | assert(featureMap.size == rowArray.length) 63 | 64 | assert(featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) 65 | assert(featureMap("DecimalLabel").getFloatList.getValue(0) == 6.5F) 66 | 67 | assert(featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) 68 | assert(featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== decimalArray.map(_.toFloat)) 69 | } 70 | 71 | "Serialize complex internalRow to tfrecord example" in { 72 | val schemaStructType = StructType(Array( 73 | StructField("IntegerLabel", IntegerType), 74 | StructField("LongLabel", LongType), 75 | StructField("FloatLabel", FloatType), 76 | StructField("DoubleLabel", DoubleType), 77 | StructField("DecimalLabel", DataTypes.createDecimalType()), 78 | StructField("DoubleArrayLabel", ArrayType(DoubleType)), 79 | StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())), 80 | StructField("StrLabel", StringType), 81 | StructField("StrArrayLabel", ArrayType(StringType)), 82 | StructField("BinaryLabel", BinaryType), 83 | StructField("BinaryArrayLabel", ArrayType(BinaryType)) 84 | )) 85 | val doubleArray = Array(1.1, 111.1, 11111.1) 86 | val decimalArray = Array(Decimal(4.0), Decimal(8.0)) 87 | val byteArray = Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte) 88 | val byteArray1 = 
Array[Byte](-128, 23, 127) 89 | 90 | val rowArray = Array[Any](1, 23L, 10.0F, 14.0, Decimal(6.5), 91 | ArrayData.toArrayData(doubleArray), 92 | ArrayData.toArrayData(decimalArray), 93 | UTF8String.fromString("r1"), 94 | ArrayData.toArrayData(Array(UTF8String.fromString("r2"), UTF8String.fromString("r3"))), 95 | byteArray, 96 | ArrayData.toArrayData(Array(byteArray, byteArray1)) 97 | ) 98 | 99 | val internalRow = InternalRow.fromSeq(rowArray) 100 | 101 | val serializer = new TFRecordSerializer(schemaStructType) 102 | val example = serializer.serializeExample(internalRow) 103 | 104 | //Verify each Datatype converted to TensorFlow datatypes 105 | val featureMap = example.getFeatures.getFeatureMap.asScala 106 | assert(featureMap.size == rowArray.length) 107 | 108 | assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) 109 | assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 1) 110 | 111 | assert(featureMap("LongLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) 112 | assert(featureMap("LongLabel").getInt64List.getValue(0).toInt == 23) 113 | 114 | assert(featureMap("FloatLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) 115 | assert(featureMap("FloatLabel").getFloatList.getValue(0) == 10.0F) 116 | 117 | assert(featureMap("DoubleLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) 118 | assert(featureMap("DoubleLabel").getFloatList.getValue(0) == 14.0F) 119 | 120 | assert(featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) 121 | assert(featureMap("DecimalLabel").getFloatList.getValue(0) == 6.5F) 122 | 123 | assert(featureMap("DoubleArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) 124 | assert(featureMap("DoubleArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== doubleArray.map(_.toFloat)) 125 | 126 | assert(featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) 127 | assert(featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== decimalArray.map(_.toFloat)) 128 | 129 | assert(featureMap("StrLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) 130 | assert(featureMap("StrLabel").getBytesList.getValue(0).toStringUtf8 == "r1") 131 | 132 | assert(featureMap("StrArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) 133 | assert(featureMap("StrArrayLabel").getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("r2", "r3")) 134 | 135 | assert(featureMap("BinaryLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) 136 | assert(featureMap("BinaryLabel").getBytesList.getValue(0).toByteArray.sameElements(byteArray)) 137 | 138 | assert(featureMap("BinaryArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) 139 | val binaryArrayValue = featureMap("BinaryArrayLabel").getBytesList.getValueList.asScala.map((byteArray) => byteArray.asScala.toArray.map(_.toByte)) 140 | binaryArrayValue.toArray should equal(Array(byteArray, byteArray1)) 141 | } 142 | 143 | "Serialize internalRow to tfrecord sequenceExample" in { 144 | 145 | val schemaStructType = StructType(Array( 146 | StructField("IntegerLabel", IntegerType), 147 | StructField("StringArrayLabel", ArrayType(StringType)), 148 | StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), 149 | StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))) , 150 | StructField("DoubleArrayOfArrayLabel", ArrayType(ArrayType(DoubleType))), 151 | 
StructField("DecimalArrayOfArrayLabel", ArrayType(ArrayType(DataTypes.createDecimalType()))), 152 | StructField("StringArrayOfArrayLabel", ArrayType(ArrayType(StringType))), 153 | StructField("BinaryArrayOfArrayLabel", ArrayType(ArrayType(BinaryType))) 154 | )) 155 | 156 | val stringList = Array(UTF8String.fromString("r1"), UTF8String.fromString("r2"), UTF8String.fromString(("r3"))) 157 | val longListOfLists = Array(Array(3L, 5L), Array(-8L, 0L)) 158 | val floatListOfLists = Array(Array(1.5F, -6.5F), Array(-8.2F, 0F)) 159 | val doubleListOfLists = Array(Array(3.0), Array(6.0, 9.0)) 160 | val decimalListOfLists = Array(Array(Decimal(2.0), Decimal(4.0)), Array(Decimal(6.0))) 161 | val stringListOfLists = Array(Array(UTF8String.fromString("r1")), 162 | Array(UTF8String.fromString("r2"), UTF8String.fromString("r3")), 163 | Array(UTF8String.fromString("r4"))) 164 | val binaryListOfLists = stringListOfLists.map(stringList => stringList.map(_.getBytes)) 165 | 166 | val rowArray = Array[Any](10, 167 | createArray(UTF8String.fromString("r1"), UTF8String.fromString("r2"), UTF8String.fromString(("r3"))), 168 | createArray( 169 | createArray(3L, 5L), 170 | createArray(-8L, 0L) 171 | ), 172 | createArray( 173 | createArray(1.5F, -6.5F), 174 | createArray(-8.2F, 0F) 175 | ), 176 | createArray( 177 | createArray(3.0), 178 | createArray(6.0, 9.0) 179 | ), 180 | createArray( 181 | createArray(Decimal(2.0), Decimal(4.0)), 182 | createArray(Decimal(6.0)) 183 | ), 184 | createArray( 185 | createArray(UTF8String.fromString("r1")), 186 | createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")), 187 | createArray(UTF8String.fromString("r4")) 188 | ), 189 | createArray(createArray("r1".getBytes()), 190 | createArray("r2".getBytes(), "r3".getBytes), 191 | createArray("r4".getBytes()) 192 | ) 193 | ) 194 | 195 | val internalRow = InternalRow.fromSeq(rowArray) 196 | 197 | val serializer = new TFRecordSerializer(schemaStructType) 198 | val tfexample = serializer.serializeSequenceExample(internalRow) 199 | 200 | //Verify each Datatype converted to TensorFlow datatypes 201 | val featureMap = tfexample.getContext.getFeatureMap.asScala 202 | val featureListMap = tfexample.getFeatureLists.getFeatureListMap.asScala 203 | 204 | assert(featureMap.size == 2) 205 | assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) 206 | assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 10) 207 | assert(featureMap("StringArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) 208 | assert(featureMap("StringArrayLabel").getBytesList.getValueList.asScala.map(x => UTF8String.fromString(x.toStringUtf8)) === stringList) 209 | 210 | assert(featureListMap.size == 6) 211 | assert(featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala.map( 212 | _.getInt64List.getValueList.asScala.toSeq) === longListOfLists) 213 | 214 | assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.toSeq.map( 215 | _.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists.map{arr => arr.toSeq}.toSeq) 216 | assert(featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.toSeq.map( 217 | _.getFloatList.getValueList.asScala.map(_.toDouble).toSeq) ~== doubleListOfLists.map{arr => arr.toSeq}.toSeq) 218 | 219 | assert(featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.toSeq.map( 220 | _.getFloatList.getValueList.asScala.map(x => Decimal(x.toDouble)).toSeq) ~== decimalListOfLists.map{arr => arr.toSeq}.toSeq) 221 | 222 | 
assert(featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map( 223 | _.getBytesList.getValueList.asScala.map(x => UTF8String.fromString(x.toStringUtf8)).toSeq) === stringListOfLists) 224 | 225 | assert(featureListMap("BinaryArrayOfArrayLabel").getFeatureList.asScala.map( 226 | _.getBytesList.getValueList.asScala.map(byteList => byteList.asScala.toSeq)) === binaryListOfLists.map(_.map(_.toSeq))) 227 | } 228 | 229 | "Throw an exception for non-nullable data types" in { 230 | val schemaStructType = StructType(Array( 231 | StructField("NonNullLabel", ArrayType(FloatType), nullable = false) 232 | )) 233 | 234 | val internalRow = InternalRow.fromSeq(Array[Any](null)) 235 | 236 | val serializer = new TFRecordSerializer(schemaStructType) 237 | 238 | intercept[NullPointerException]{ 239 | serializer.serializeExample(internalRow) 240 | } 241 | 242 | intercept[NullPointerException]{ 243 | serializer.serializeSequenceExample(internalRow) 244 | } 245 | } 246 | 247 | "Omit null fields from Example for nullable data types" in { 248 | val schemaStructType = StructType(Array( 249 | StructField("NullLabel", ArrayType(FloatType), nullable = true), 250 | StructField("FloatArrayLabel", ArrayType(FloatType)) 251 | )) 252 | 253 | val floatArray = Array(2.5F, 5.0F) 254 | val internalRow = InternalRow.fromSeq( 255 | Array[Any](null, createArray(2.5F, 5.0F)) 256 | ) 257 | 258 | val serializer = new TFRecordSerializer(schemaStructType) 259 | val tfexample = serializer.serializeExample(internalRow) 260 | 261 | //Verify each Datatype converted to TensorFlow datatypes 262 | val featureMap = tfexample.getFeatures.getFeatureMap.asScala 263 | assert(featureMap.size == 1) 264 | assert(featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) 265 | assert(featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== floatArray.toSeq) 266 | } 267 | 268 | "Omit null fields from SequenceExample for nullable data types" in { 269 | val schemaStructType = StructType(Array( 270 | StructField("NullLabel", ArrayType(FloatType), nullable = true), 271 | StructField("FloatArrayLabel", ArrayType(FloatType)) 272 | )) 273 | 274 | val floatArray = Array(2.5F, 5.0F) 275 | val internalRow = InternalRow.fromSeq( 276 | Array[Any](null, createArray(2.5F, 5.0F))) 277 | 278 | val serializer = new TFRecordSerializer(schemaStructType) 279 | val tfSeqExample = serializer.serializeSequenceExample(internalRow) 280 | 281 | //Verify each Datatype converted to TensorFlow datatypes 282 | val featureMap = tfSeqExample.getContext.getFeatureMap.asScala 283 | val featureListMap = tfSeqExample.getFeatureLists.getFeatureListMap.asScala 284 | assert(featureMap.size == 1) 285 | assert(featureListMap.isEmpty) 286 | assert(featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) 287 | assert(featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== floatArray.toSeq) 288 | } 289 | 290 | "Throw an exception for unsupported data types" in { 291 | 292 | val schemaStructType = StructType(Array( 293 | StructField("TimestampLabel", TimestampType) 294 | )) 295 | 296 | intercept[RuntimeException]{ 297 | new TFRecordSerializer(schemaStructType) 298 | } 299 | } 300 | 301 | val schema = StructType(Array( 302 | StructField("bytesLabel", BinaryType)) 303 | ) 304 | val serializer = new TFRecordSerializer(schema) 305 | 306 | "Test Int64ListFeature" in { 307 | val longFeature = serializer.Int64ListFeature(Seq(10L)) 308 | val longListFeature = 
serializer.Int64ListFeature(Seq(3L,5L,6L))
309 | 
310 |       assert(longFeature.getInt64List.getValueList.asScala.toSeq === Seq(10L))
311 |       assert(longListFeature.getInt64List.getValueList.asScala.toSeq === Seq(3L, 5L, 6L))
312 |     }
313 | 
314 |     "Test floatListFeature" in {
315 |       val floatFeature = serializer.floatListFeature(Seq(10.1F))
316 |       val floatListFeature = serializer.floatListFeature(Seq(3.1F,5.1F,6.1F))
317 | 
318 |       assert(floatFeature.getFloatList.getValueList.asScala.toSeq === Seq(10.1F))
319 |       assert(floatListFeature.getFloatList.getValueList.asScala.toSeq === Seq(3.1F,5.1F,6.1F))
320 |     }
321 | 
322 |     "Test bytesListFeature" in {
323 |       val bytesFeature = serializer.bytesListFeature(Seq(Array(0xff.toByte, 0xd8.toByte)))
324 |       val bytesListFeature = serializer.bytesListFeature(Seq(
325 |         Array(0xff.toByte, 0xd8.toByte),
326 |         Array(0xff.toByte, 0xd9.toByte)))
327 | 
328 |       bytesFeature.getBytesList.getValueList.asScala.map(_.toByteArray).toArray should equal(
329 |         Array(Array(0xff.toByte, 0xd8.toByte)))
330 |       bytesListFeature.getBytesList.getValueList.asScala.map(_.toByteArray).toArray should equal(
331 |         Array(Array(0xff.toByte, 0xd8.toByte), Array(0xff.toByte, 0xd9.toByte)))
332 |     }
333 |   }
334 | }
335 | 
--------------------------------------------------------------------------------
/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 |  *
4 |  * Licensed under the Apache License, Version 2.0 (the "License");
5 |  * you may not use this file except in compliance with the License.
6 |  * You may obtain a copy of the License at
7 |  *
8 |  *       http://www.apache.org/licenses/LICENSE-2.0
9 |  *
10 |  * Unless required by applicable law or agreed to in writing, software
11 |  * distributed under the License is distributed on an "AS IS" BASIS,
12 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 |  * See the License for the specific language governing permissions and
14 |  * limitations under the License.
15 |  */
16 | package com.linkedin.spark.datasources.tfrecord
17 | 
18 | import org.apache.spark.sql.Row
19 | import org.apache.spark.sql.catalyst.InternalRow
20 | import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, SpecializedGetters}
21 | import org.apache.spark.sql.catalyst.util.ArrayData
22 | import org.apache.spark.sql.types._
23 | import org.scalatest.Matchers
24 | 
25 | object TestingUtils extends Matchers {
26 | 
27 |   /**
28 |    * Implicit class for comparing two float sequences using absolute tolerance.
29 |    */
30 |   implicit class FloatArrayWithAlmostEquals(val left: Seq[Float]) {
31 | 
32 |     /**
33 |      * When every pair of corresponding values is within eps, returns true; otherwise, returns false.
34 |      */
35 |     def ~==(right: Seq[Float], epsilon : Float = 1E-6F): Boolean = {
36 |       if (left.size === right.size) {
37 |         (left zip right) forall { case (a, b) => a === (b +- epsilon) }
38 |       }
39 |       else false
40 |     }
41 |   }
42 | 
43 |   /**
44 |    * Implicit class for comparing two double sequences using absolute tolerance.
45 |    */
46 |   implicit class DoubleArrayWithAlmostEquals(val left: Seq[Double]) {
47 | 
48 |     /**
49 |      * When every pair of corresponding values is within eps, returns true; otherwise, returns false.
50 |      */
51 |     def ~==(right: Seq[Double], epsilon : Double = 1E-6): Boolean = {
52 |       if (left.size === right.size) {
53 |         (left zip right) forall { case (a, b) => a === (b +- epsilon) }
54 |       }
55 |       else false
56 |     }
57 |   }
58 | 
59 |   /**
60 |    * Implicit class for comparing two decimal sequences using absolute tolerance.
61 |    */
62 |   implicit class DecimalArrayWithAlmostEquals(val left: Seq[Decimal]) {
63 | 
64 |     /**
65 |      * When every pair of corresponding values is within eps, returns true; otherwise, returns false.
66 |      */
67 |     def ~==(right: Seq[Decimal], epsilon : Double = 1E-6): Boolean = {
68 |       if (left.size === right.size) {
69 |         (left zip right) forall { case (a, b) => a.toDouble === (b.toDouble +- epsilon) }
70 |       }
71 |       else false
72 |     }
73 |   }
74 | 
75 |   /**
76 |    * Implicit class for comparing two float matrices using absolute tolerance.
77 |    */
78 |   implicit class FloatMatrixWithAlmostEquals(val left: Seq[Seq[Float]]) {
79 | 
80 |     /**
81 |      * When every pair of corresponding rows is element-wise within eps, returns true; otherwise, returns false.
82 |      */
83 |     def ~==(right: Seq[Seq[Float]], epsilon : Float = 1E-6F): Boolean = {
84 |       if (left.size === right.size) {
85 |         (left zip right) forall { case (a, b) => a ~== (b, epsilon) }
86 |       }
87 |       else false
88 |     }
89 |   }
90 | 
91 |   /**
92 |    * Implicit class for comparing two double matrices using absolute tolerance.
93 |    */
94 |   implicit class DoubleMatrixWithAlmostEquals(val left: Seq[Seq[Double]]) {
95 | 
96 |     /**
97 |      * When every pair of corresponding rows is element-wise within eps, returns true; otherwise, returns false.
98 |      */
99 |     def ~==(right: Seq[Seq[Double]], epsilon : Double = 1E-6): Boolean = {
100 |       if (left.size === right.size) {
101 |         (left zip right) forall { case (a, b) => a ~== (b, epsilon) }
102 |       }
103 |       else false
104 |     }
105 |   }
106 | 
107 |   /**
108 |    * Implicit class for comparing two decimal matrices using absolute tolerance.
109 |    */
110 |   implicit class DecimalMatrixWithAlmostEquals(val left: Seq[Seq[Decimal]]) {
111 | 
112 |     /**
113 |      * When every pair of corresponding rows is element-wise within eps, returns true; otherwise, returns false.
114 |      */
115 |     def ~==(right: Seq[Seq[Decimal]], epsilon : Double = 1E-6): Boolean = {
116 |       if (left.size === right.size) {
117 |         (left zip right) forall { case (a, b) => a ~== (b, epsilon) }
118 |       }
119 |       else false
120 |     }
121 |   }
122 | 
123 |   /**
124 |    * Implicit class for comparing two InternalRows using absolute tolerance.
125 |    */
126 |   implicit class InternalRowWithAlmostEquals(val left: InternalRow) {
127 | 
128 |     private type valueCompare = (SpecializedGetters, SpecializedGetters, Int) => Boolean
129 |     private def newValueCompare(
130 |       dataType: DataType,
131 |       epsilon : Float = 1E-6F): valueCompare = dataType match {
132 |       case NullType => (left, right, ordinal) =>
133 |         left.get(ordinal, null) == right.get(ordinal, null)
134 | 
135 |       case IntegerType => (left, right, ordinal) =>
136 |         left.getInt(ordinal) === right.getInt(ordinal)
137 | 
138 |       case LongType => (left, right, ordinal) =>
139 |         left.getLong(ordinal) === right.getLong(ordinal)
140 | 
141 |       case FloatType => (left, right, ordinal) =>
142 |         left.getFloat(ordinal) === (right.getFloat(ordinal) +- epsilon)
143 | 
144 |       case DoubleType => (left, right, ordinal) =>
145 |         left.getDouble(ordinal) === (right.getDouble(ordinal) +- epsilon)
146 | 
147 |       case DecimalType() => (left, right, ordinal) =>
148 |         left.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale).toDouble ===
149 |           (right.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale).toDouble
150 |             +- epsilon)
151 | 
152 |       case StringType => (left, right, ordinal) =>
153 |         left.getUTF8String(ordinal).getBytes === right.getUTF8String(ordinal).getBytes
154 | 
155 |       case BinaryType => (left, right, ordinal) =>
156 |         left.getBinary(ordinal) === right.getBinary(ordinal)
157 | 
158 |       case ArrayType(elementType, _) => (left, right, ordinal) =>
159 |         if (left.get(ordinal, null) == null || right.get(ordinal, null) == null) {
160 |           left.get(ordinal, null) == right.get(ordinal, null)
161 |         } else {
162 |           val leftArray = left.getArray(ordinal)
163 |           val rightArray = right.getArray(ordinal)
164 |           if (leftArray.numElements == rightArray.numElements) {
165 |             val len = leftArray.numElements()
166 |             val elementValueCompare = newValueCompare(elementType, epsilon)
167 |             var result = true
168 |             var idx: Int = 0
169 |             while (idx < len && result) {
170 |               result = elementValueCompare(leftArray, rightArray, idx)
171 |               idx += 1
172 |             }
173 |             result
174 |           } else false
175 |         }
176 |       case _ => throw new RuntimeException(s"Cannot compare fields of unsupported data type ${dataType}")
177 |     }
178 | 
179 |     /**
180 |      * When all fields of the two rows, interpreted with the given schema, are equal or within eps, returns true; otherwise, returns false.
181 |      */
182 |     def ~==(right: InternalRow, schema: StructType, epsilon : Float = 1E-6F): Boolean = {
183 |       if (schema != null && schema.fields.size == left.numFields && schema.fields.size == right.numFields) {
184 |         schema.fields.map(_.dataType).zipWithIndex.forall {
185 |           case (dataType, idx) =>
186 |             val valueCompare = newValueCompare(dataType, epsilon)
187 |             valueCompare(left, right, idx)
188 |         }
189 |       }
190 |       else false
191 |     }
192 |   }
193 | 
194 |   /**
195 |    * Implicit class for comparing two rows using absolute tolerance.
196 |    */
197 |   implicit class RowWithAlmostEquals(val left: Row) {
198 | 
199 |     /**
200 |      * When all fields of the two rows, interpreted with the given schema, are equal or within eps, returns true; otherwise, returns false.
201 | */ 202 | def ~==(right: Row, schema: StructType): Boolean = { 203 | if (schema != null && schema.fields.size == left.size && schema.fields.size == right.size) { 204 | val leftRowWithSchema = new GenericRowWithSchema(left.toSeq.toArray, schema) 205 | val rightRowWithSchema = new GenericRowWithSchema(right.toSeq.toArray, schema) 206 | leftRowWithSchema ~== rightRowWithSchema 207 | } 208 | else false 209 | } 210 | 211 | /** 212 | * When all fields in row are equal or are within eps, returns true; otherwise, returns false. 213 | */ 214 | def ~==(right: Row, epsilon : Float = 1E-6F): Boolean = { 215 | if (left.size === right.size) { 216 | val leftDataTypes = left.schema.fields.map(_.dataType) 217 | val rightDataTypes = right.schema.fields.map(_.dataType) 218 | 219 | (leftDataTypes zip rightDataTypes).zipWithIndex.forall { 220 | case (x, i) if left.get(i) == null || right.get(i) == null => 221 | left.get(i) == null && right.get(i) == null 222 | 223 | case ((FloatType, FloatType), i) => 224 | left.getFloat(i) === (right.getFloat(i) +- epsilon) 225 | 226 | case ((DoubleType, DoubleType), i) => 227 | left.getDouble(i) === (right.getDouble(i) +- epsilon) 228 | 229 | case ((BinaryType, BinaryType), i) => 230 | left.getAs[Array[Byte]](i).toSeq === right.getAs[Array[Byte]](i).toSeq 231 | 232 | case ((ArrayType(FloatType,_), ArrayType(FloatType,_)), i) => 233 | val leftArray = ArrayData.toArrayData(left.get(i)).toFloatArray().toSeq 234 | val rightArray = ArrayData.toArrayData(right.get(i)).toFloatArray().toSeq 235 | leftArray ~== (rightArray, epsilon) 236 | 237 | case ((ArrayType(DoubleType,_), ArrayType(DoubleType,_)), i) => 238 | val leftArray = ArrayData.toArrayData(left.get(i)).toDoubleArray().toSeq 239 | val rightArray = ArrayData.toArrayData(right.get(i)).toDoubleArray().toSeq 240 | leftArray ~== (rightArray, epsilon) 241 | 242 | case ((ArrayType(BinaryType,_), ArrayType(BinaryType,_)), i) => 243 | val leftArray = ArrayData.toArrayData(left.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq 244 | val rightArray = ArrayData.toArrayData(right.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq 245 | leftArray === rightArray 246 | 247 | case ((ArrayType(ArrayType(FloatType,_),_), ArrayType(ArrayType(FloatType,_),_)), i) => 248 | val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => 249 | ArrayData.toArrayData(arr).toFloatArray().toSeq 250 | } 251 | val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => 252 | ArrayData.toArrayData(arr).toFloatArray().toSeq 253 | } 254 | leftArrays ~== (rightArrays, epsilon) 255 | 256 | case ((ArrayType(ArrayType(DoubleType,_),_), ArrayType(ArrayType(DoubleType,_),_)), i) => 257 | val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => 258 | ArrayData.toArrayData(arr).toDoubleArray().toSeq 259 | } 260 | val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => 261 | ArrayData.toArrayData(arr).toDoubleArray().toSeq 262 | } 263 | leftArrays ~== (rightArrays, epsilon) 264 | 265 | case ((ArrayType(ArrayType(BinaryType,_),_), ArrayType(ArrayType(BinaryType,_),_)), i) => 266 | val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => 267 | ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq 268 | } 269 | val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => 270 | ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq 271 | } 272 | leftArrays === rightArrays 273 | 274 | 
case ((_, _), i) => left.get(i) === right.get(i)
275 |         }
276 |       }
277 |       else false
278 |     }
279 |   }
280 | }
--------------------------------------------------------------------------------
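
The `TestingUtils` object above supplies implicit `~==` operators that compare sequences, matrices, and rows element-wise within an absolute tolerance. The following is a minimal, hypothetical sketch (not a file in this repository) of how those helpers might be used from a ScalaTest suite; the suite name and the sample values are illustrative only, and the WordSpec style is assumed to match the project's other tests.

```scala
package com.linkedin.spark.datasources.tfrecord

import org.scalatest.WordSpec

// Bring the implicit tolerance comparisons (FloatArrayWithAlmostEquals, etc.) into scope.
import TestingUtils._

class ToleranceUsageExample extends WordSpec {
  "the ~== helper" should {
    "treat float sequences that differ by less than the default epsilon as equal" in {
      val expected = Seq(1.5F, -6.5F)
      val actual = Seq(1.5000001F, -6.5000001F) // differs by well under the default 1e-6

      assert(actual ~== expected)

      // Sequences of different lengths are never considered equal.
      assert(!(Seq(1.5F) ~== expected))
    }
  }
}
```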