├── .gitignore
├── .travis.yml
├── CONTRIBUTING.md
├── LICENSE.md
├── NOTICE
├── README.md
├── pom.xml
├── project
│   ├── build.properties
│   └── plugins.sbt
└── src
    ├── main
    │   ├── resources
    │   │   └── META-INF
    │   │       └── services
    │   │           └── org.apache.spark.sql.sources.DataSourceRegister
    │   └── scala
    │       └── com
    │           └── linkedin
    │               └── spark
    │                   └── datasources
    │                       └── tfrecord
    │                           ├── DefaultSource.scala
    │                           ├── TFRecordDeserializer.scala
    │                           ├── TFRecordFileReader.scala
    │                           ├── TFRecordOutputWriter.scala
    │                           ├── TFRecordSerializer.scala
    │                           └── TensorFlowInferSchema.scala
    └── test
        └── scala
            └── com
                └── linkedin
                    └── spark
                        └── datasources
                            └── tfrecord
                                ├── InferSchemaSuite.scala
                                ├── SharedSparkSessionSuite.scala
                                ├── TFRecordDeserializerTest.scala
                                ├── TFRecordIOSuite.scala
                                ├── TFRecordSerializerTest.scala
                                └── TestingUtils.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | out
2 | *.egg
3 | *.egg-info/
4 | *.iml
5 | *.ipr
6 | *.iws
7 | *.pyc
8 | *.pyo
9 | *.sublime-*
10 | .*.swo
11 | .*.swp
12 | .cache/
13 | .coverage
14 | .direnv/
15 | .env
16 | .envrc
17 | .gradle/
18 | .idea/
19 | target/
20 | .tox*
21 | .venv*
22 | /*/*pinned.txt
23 | /*/MANIFEST
24 | /*/activate
25 | /*/build/
26 | /*/config
27 | /*/coverage.xml
28 | /*/dist/
29 | /*/htmlcov/
30 | /*/product-spec.json
31 | /build/
32 | /config/
33 | /dist/
34 | /ligradle/
35 | TEST-*.xml
36 | __pycache__/
37 | /*/build
38 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | dist: trusty
2 | language: scala
3 | git:
4 | depth: 3
5 | jdk:
6 | - oraclejdk8
7 |
8 | matrix:
9 | include:
10 | - scala: 2.12.12
11 | script: "mvn test -B -Pscala-2.12"
12 |
13 | - scala: 2.13.8
14 | script: "mvn test -B -Pscala-2.13"
15 |
16 | # safelist
17 | branches:
18 | only:
19 | - master
20 | - spark-2.3
21 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Contribution Agreement
2 | ======================
3 |
4 | As a contributor, you represent that the code you submit is your original work or
5 | that of your employer (in which case you represent you have the right to bind your
6 | employer). By submitting code, you (and, if applicable, your employer) are
7 | licensing the submitted code to LinkedIn and the open source community subject
8 | to the BSD 2-Clause license.
9 |
10 | Responsible Disclosure of Security Vulnerabilities
11 | ==================================================
12 |
13 | **Do not file an issue on Github for security issues.** Please review
14 | the [guidelines for disclosure][disclosure_guidelines]. Reports should
15 | be encrypted using PGP ([public key][pubkey]) and sent to
16 | [security@linkedin.com][disclosure_email] preferably with the title
17 | "Vulnerability in Github LinkedIn/spark-tfrecord - <short summary>".
18 |
19 | Tips for Getting Your Pull Request Accepted
20 | ===========================================
21 |
22 | 1. Make sure all new features are tested and the tests pass.
23 | 2. Bug fixes must include a test case demonstrating the error that it fixes.
24 | 3. Open an issue first and seek advice for your change before submitting
25 | a pull request. Large features which have never been discussed are
26 | unlikely to be accepted. **You have been warned.**
27 |
28 | [disclosure_guidelines]: https://www.linkedin.com/help/linkedin/answer/62924
29 | [pubkey]: https://www.linkedin.com/help/linkedin/answer/79676
30 | [disclosure_email]: mailto:security@linkedin.com?subject=Vulnerability%20in%20Github%20LinkedIn/spark-tfrecord%20-%20%3Csummary%3E
31 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | BSD 2-CLAUSE LICENSE
2 |
3 | Copyright 2020 LinkedIn Corporation
4 | All Rights Reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following
7 | conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
12 | disclaimer in the documentation and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
15 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
16 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2020 LinkedIn Corporation
2 | All Rights Reserved.
3 |
4 | Licensed under the BSD 2-Clause License (the "License").
5 | See LICENSE in the project root for license information.
6 |
7 | This product includes:
8 | - spark-tensorflow-connector: Apache 2.0 License, available at https://github.com/tensorflow/ecosystem/blob/master/LICENSE
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spark-TFRecord
2 |
3 | A library for reading and writing [Tensorflow TFRecord](https://www.tensorflow.org/how_tos/reading_data/) data from [Apache Spark](http://spark.apache.org/).
4 | The implementation is based on [Spark Tensorflow Connector](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector), but it is rewritten using Spark's FileFormat trait, which provides the partitioning function.
5 |
6 | ## Including the library
7 |
8 | The artifacts are published to [bintray](https://bintray.com/linkedin/maven/spark-tfrecord) and [maven central](https://search.maven.org/search?q=spark-tfrecord) repositories.
9 |
10 | - Version 0.1.x targets Spark 2.3 and Scala 2.11
11 | - Version 0.2.x targets Spark 2.4 and both Scala 2.11 and 2.12
12 | - Version 0.3.x targets Spark 3.0 and Scala 2.12
13 | - Version 0.4.x targets Spark 3.2 and Scala 2.12
14 | - Version 0.5.x targets Spark 3.2 and Scala 2.13
15 | - Version 0.6.x targets Spark 3.4 and both Scala 2.12 and 2.13
16 | - Version 0.7.x targets Spark 3.5 and both Scala 2.12 and 2.13
17 |
18 | To use the package, please include the dependency as follows
19 |
20 | ```xml
21 | <dependency>
22 |   <groupId>com.linkedin.sparktfrecord</groupId>
23 |   <artifactId>spark-tfrecord_2.12</artifactId>
24 |   <version>your.version</version>
25 | </dependency>
26 | ```
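For sbt users, an equivalent dependency line (assuming the same coordinates as the Maven snippet above) would look roughly like this:

```scala
// Hypothetical sbt dependency; substitute the Scala suffix and version you need.
libraryDependencies += "com.linkedin.sparktfrecord" % "spark-tfrecord_2.12" % "0.7.0"
```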
27 |
28 | ## Building the library
29 | The library can be built with Maven 3.3.9 or newer as shown below:
30 |
31 | ```sh
32 | # Build Spark-TFRecord
33 | git clone https://github.com/linkedin/spark-tfrecord.git
34 | cd spark-tfrecord
35 | mvn -Pscala-2.12 clean install
36 |
37 | # One can specify the spark version and tensorflow hadoop version, for example
38 | mvn -Pscala-2.12 clean install -Dspark.version=3.0.0 -Dtensorflow.hadoop.version=1.15.0
39 | ```
40 |
41 | ## Using Spark Shell
42 | Run this library in Spark using the `--jars` command line option in `spark-shell`, `pyspark` or `spark-submit`. For example:
43 |
44 | ```sh
45 | $SPARK_HOME/bin/spark-shell --jars target/spark-tfrecord_2.12-0.3.0.jar
46 | ```
47 |
48 | ## Features
49 | This library allows reading TensorFlow records from a local or distributed filesystem into [Spark DataFrames](https://spark.apache.org/docs/latest/sql-programming-guide.html).
50 | When reading TensorFlow records into a Spark DataFrame, the API accepts several options:
51 | * `load`: input path to TensorFlow records. As elsewhere in Spark, the path can contain standard Hadoop globbing expressions.
52 | * `schema`: schema of TensorFlow records. Optional schema defined using Spark StructType. If not provided, the schema is inferred from TensorFlow records.
53 | * `recordType`: input format of TensorFlow records. By default it is Example. Possible values are:
54 | * `Example`: TensorFlow [Example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
55 | * `SequenceExample`: TensorFlow [SequenceExample](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
56 | * `ByteArray`: `Array[Byte]` type in Scala.
57 |
58 | When writing a Spark DataFrame to TensorFlow records, the API accepts several options:
59 | * `save`: output path to TensorFlow records on a local or distributed filesystem.
60 | * `codec`: compression codec (e.g. `org.apache.hadoop.io.compress.GzipCodec`) used to write compressed TensorFlow records. When reading compressed TensorFlow records, the codec can be inferred automatically, so this option is not required for reading.
61 | * `recordType`: output format of TensorFlow records. By default it is Example. Possible values are:
62 | * `Example`: TensorFlow [Example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
63 | * `SequenceExample`: TensorFlow [SequenceExample](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) records
64 | * `ByteArray`: `Array[Byte]` type in Scala, for writing objects other than TensorFlow Example or SequenceExample. For example, [protos](https://developers.google.com/protocol-buffers) can be transformed to byte arrays using `.toByteArray` (see the sketch below).
65 |
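To make the write options above concrete, here is a minimal sketch (paths and column names are illustrative only, and `df` is assumed to be an existing DataFrame of supported types) that writes gzip-compressed Example records and round-trips raw byte arrays with `recordType` set to `ByteArray`:

```scala
// Write Example records with a compression codec; the codec option maps to the
// standard Hadoop output compression settings.
df.write.mode("overwrite")
  .format("tfrecord")
  .option("recordType", "Example")
  .option("codec", "org.apache.hadoop.io.compress.GzipCodec")
  .save("/tmp/example-tfrecord")

// Write raw byte arrays (e.g. serialized protos): the DataFrame is expected to
// have a single binary column, and each row becomes one TFRecord.
import spark.implicits._
val bytesDf = Seq(Array[Byte](1, 2, 3), Array[Byte](4, 5, 6)).toDF("byteArray")
bytesDf.write.mode("overwrite")
  .format("tfrecord")
  .option("recordType", "ByteArray")
  .save("/tmp/bytes-tfrecord")

// Read the byte arrays back; the inferred schema is a single BinaryType column.
val readBack = spark.read.format("tfrecord").option("recordType", "ByteArray").load("/tmp/bytes-tfrecord")
```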
66 | The writer supports the partitionBy operation, so the following command partitions the output by "partitionColumn":
67 | ```scala
68 | df.write.mode(SaveMode.Overwrite).partitionBy("partitionColumn").format("tfrecord").option("recordType", "Example").save(output_dir)
69 | ```
70 | Note that we use `format("tfrecord")` instead of `format("tfrecords")`, so if you migrate from Spark-Tensorflow-Connector, make sure this is changed accordingly.
71 |
72 | ## Schema inference
73 | This library supports automatic schema inference when reading TensorFlow records into Spark DataFrames.
74 | Schema inference is expensive since it requires an extra pass through the data.
75 |
76 | The schema inference rules are described in the table below:
77 |
78 | | TFRecordType | Feature Type | Inferred Spark Data Type |
79 | | ------------------------ |:--------------|:--------------------------|
80 | | Example, SequenceExample | Int64List | LongType if all lists have length=1, else ArrayType(LongType) |
81 | | Example, SequenceExample | FloatList | FloatType if all lists have length=1, else ArrayType(FloatType) |
82 | | Example, SequenceExample | BytesList | StringType if all lists have length=1, else ArrayType(StringType) |
83 | | SequenceExample | FeatureList of Int64List | ArrayType(ArrayType(LongType)) |
84 | | SequenceExample | FeatureList of FloatList | ArrayType(ArrayType(FloatType)) |
85 | | SequenceExample | FeatureList of BytesList | ArrayType(ArrayType(StringType)) |
86 |
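Because inference requires the extra pass, you can skip it by declaring the schema yourself. A minimal sketch, with placeholder column names and input path:

```scala
import org.apache.spark.sql.types._

// Declaring the schema up front skips the inference pass over the TFRecord files.
val explicitSchema = StructType(Seq(
  StructField("LongCol", LongType),
  StructField("FloatArrayCol", ArrayType(FloatType))))

val dfNoInference = spark.read
  .format("tfrecord")
  .option("recordType", "Example")
  .schema(explicitSchema)
  .load("input.tfrecord")
```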
87 | ## Supported data types
88 |
89 | The supported Spark data types are listed in the table below:
90 |
91 | | Type | Spark DataTypes |
92 | | --------------- |:------------------------------------------|
93 | | Scalar | IntegerType, LongType, FloatType, DoubleType, DecimalType, StringType, BinaryType |
94 | | Array | VectorType, ArrayType of IntegerType, LongType, FloatType, DoubleType, DecimalType, BinaryType, or StringType |
95 | | Array of Arrays | ArrayType of ArrayType of IntegerType, LongType, FloatType, DoubleType, DecimalType, BinaryType, or StringType |
96 |
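To make the table concrete, below is a hypothetical schema combining the supported shapes; note that the array-of-arrays column only round-trips when `recordType` is set to `SequenceExample`:

```scala
import org.apache.spark.sql.types._

val mixedSchema = StructType(Seq(
  StructField("id", LongType),                              // scalar -> Int64List
  StructField("label", StringType),                         // scalar -> BytesList
  StructField("scores", ArrayType(FloatType)),              // array -> FloatList
  StructField("frames", ArrayType(ArrayType(FloatType)))))  // array of arrays -> FeatureList of FloatLists
```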
97 | ## Usage Examples
98 |
99 | ### Python API
100 |
101 | #### TF record Import/export
102 |
103 | Run PySpark with the spark-tfrecord jar in the `--jars` argument as shown below:
104 |
105 | `$SPARK_HOME/bin/pyspark --jars target/spark-tfrecord_2.12-0.3.0.jar`
106 |
107 | The following Python code snippet demonstrates usage on test data.
108 |
109 | ```python
110 | from pyspark.sql.types import *
111 |
112 | path = "test-output.tfrecord"
113 |
114 | fields = [StructField("id", IntegerType()), StructField("IntegerCol", IntegerType()),
115 | StructField("LongCol", LongType()), StructField("FloatCol", FloatType()),
116 | StructField("DoubleCol", DoubleType()), StructField("VectorCol", ArrayType(DoubleType(), True)),
117 | StructField("StringCol", StringType())]
118 | schema = StructType(fields)
119 | test_rows = [[11, 1, 23, 10.0, 14.0, [1.0, 2.0], "r1"], [21, 2, 24, 12.0, 15.0, [2.0, 2.0], "r2"]]
120 | rdd = spark.sparkContext.parallelize(test_rows)
121 | df = spark.createDataFrame(rdd, schema)
122 | df.write.mode("overwrite").format("tfrecord").option("recordType", "Example").save(path)
123 | df = spark.read.format("tfrecord").option("recordType", "Example").load(path)
124 | df.show()
125 | ```
126 |
127 | ### Scala API
128 | Run the Spark shell with the spark-tfrecord jar in the `--jars` argument as shown below:
129 | ```sh
130 | $SPARK_HOME/bin/spark-shell --jars target/spark-tfrecord_2.12-0.3.0.jar
131 | ```
132 |
133 | The following Scala code snippet demonstrates usage on test data.
134 |
135 | ```scala
136 | import org.apache.commons.io.FileUtils
137 | import org.apache.spark.sql.{ DataFrame, Row }
138 | import org.apache.spark.sql.catalyst.expressions.GenericRow
139 | import org.apache.spark.sql.types._
140 |
141 | val path = "test-output.tfrecord"
142 | val testRows: Array[Row] = Array(
143 | new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
144 | new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
145 | val schema = StructType(List(StructField("id", IntegerType),
146 | StructField("IntegerCol", IntegerType),
147 | StructField("LongCol", LongType),
148 | StructField("FloatCol", FloatType),
149 | StructField("DoubleCol", DoubleType),
150 | StructField("VectorCol", ArrayType(DoubleType, true)),
151 | StructField("StringCol", StringType)))
152 |
153 | val rdd = spark.sparkContext.parallelize(testRows)
154 |
155 | //Save DataFrame as TFRecords
156 | val df: DataFrame = spark.createDataFrame(rdd, schema)
157 | df.write.format("tfrecord").option("recordType", "Example").save(path)
158 |
159 | //Read TFRecords into DataFrame.
160 | //The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
161 | val importedDf1: DataFrame = spark.read.format("tfrecord").option("recordType", "Example").load(path)
162 | importedDf1.show()
163 |
164 | //Read TFRecords into DataFrame using custom schema
165 | val importedDf2: DataFrame = spark.read.format("tfrecord").schema(schema).load(path)
166 | importedDf2.show()
167 | ```
168 |
169 | #### Use partitionBy
170 | The following example shows how to use partitionBy, which is not supported by [Spark Tensorflow Connector](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector).
171 |
172 | ```scala
173 |
174 | // launch spark-shell with the following command:
175 | // $SPARK_HOME/bin/spark-shell --jars target/spark-tfrecord_2.12-0.3.0.jar
176 |
177 | import org.apache.spark.sql.SaveMode
178 |
179 | val df = Seq((8, "bat"),(8, "abc"), (1, "xyz"), (2, "aaa")).toDF("number", "word")
180 | df.show
181 |
182 | // scala> df.show
183 | // +------+----+
184 | // |number|word|
185 | // +------+----+
186 | // | 8| bat|
187 | // | 8| abc|
188 | // | 1| xyz|
189 | // | 2| aaa|
190 | // +------+----+
191 |
192 | val tf_output_dir = "/tmp/tfrecord-test"
193 |
194 | // dump the tfrecords to files.
195 | df.repartition(3, col("number")).write.mode(SaveMode.Overwrite).partitionBy("number").format("tfrecord").option("recordType", "Example").save(tf_output_dir)
196 |
197 | // ls /tmp/tfrecord-test
198 | // _SUCCESS number=1 number=2 number=8
199 |
200 | // read back the tfrecords from files.
201 | val new_df = spark.read.format("tfrecord").option("recordType", "Example").load(tf_output_dir)
202 | new_df.show
203 |
204 | // scala> new_df.show
205 | // +----+------+
206 | // |word|number|
207 | // +----+------+
208 | // | bat| 8|
209 | // | abc| 8|
210 | // | xyz| 1|
211 | // | aaa| 2|
212 | // +----+------+
213 | ```
213 | ## Contributing
214 |
215 | Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduct, and the process for submitting pull requests to us.
216 |
217 | ## License
218 |
219 | This project is licensed under the BSD 2-CLAUSE LICENSE - see the [LICENSE.md](LICENSE.md) file for details.
220 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |   <modelVersion>4.0.0</modelVersion>
6 |   <groupId>com.linkedin.sparktfrecord</groupId>
7 |   <artifactId>spark-tfrecord_${scala.binary.version}</artifactId>
8 |   <packaging>jar</packaging>
9 |   <version>0.7.0</version>
10 |   <name>spark-tfrecord</name>
11 |   <url>https://github.com/linkedin/spark-tfrecord</url>
12 |   <description>TensorFlow TFRecord data source for Apache Spark</description>
13 |
14 |   <licenses>
15 |     <license>
16 |       <name>BSD 2-CLAUSE LICENSE</name>
17 |       <url>https://github.com/linkedin/spark-tfrecord/blob/master/LICENSE.md</url>
18 |       <distribution>repo</distribution>
19 |     </license>
20 |   </licenses>
21 |
22 |   <scm>
23 |     <url>https://github.com/linkedin/spark-tfrecord.git</url>
24 |     <developerConnection>git@github.com:linkedin/spark-tfrecord.git</developerConnection>
25 |     <connection>scm:git:https://github.com/linkedin/spark-tfrecord.git</connection>
26 |   </scm>
27 |
28 |   <properties>
29 |     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
30 |     <scala.maven.version>3.2.2</scala.maven.version>
31 |     <scalatest.maven.version>1.0</scalatest.maven.version>
32 |     <scala.test.version>3.0.8</scala.test.version>
33 |     <spark.version>3.5.1</spark.version>
34 |     <maven.compiler.version>3.0</maven.compiler.version>
35 |     <java.version>1.8</java.version>
36 |     <junit.version>4.13.1</junit.version>
37 |     <tensorflow.hadoop.version>1.15.0</tensorflow.hadoop.version>
38 |   </properties>
39 |
40 |
41 |
42 |
43 |
44 | true
45 | net.alchim31.maven
46 | scala-maven-plugin
47 | ${scala.maven.version}
48 |
49 |
50 | compile
51 |
52 | add-source
53 | compile
54 |
55 |
56 |
57 | -Xms256m
58 | -Xmx512m
59 |
60 |
61 | -g:vars
62 | -deprecation
63 | -feature
64 | -unchecked
65 | -Xfatal-warnings
66 | -language:implicitConversions
67 | -language:existentials
68 |
69 |
70 |
71 |
72 | test
73 |
74 | add-source
75 | testCompile
76 |
77 |
78 |
79 | attach-javadocs
80 |
81 | doc-jar
82 |
83 |
84 |
85 |
86 | incremental
87 | true
88 | ${scala.version}
89 | false
90 |
91 |
92 |
93 | true
94 | org.scalatest
95 | scalatest-maven-plugin
96 | ${scalatest.maven.version}
97 |
98 |
99 | scalaTest
100 | test
101 |
102 | test
103 |
104 |
105 |
106 |
107 |
108 |
109 | maven-shade-plugin
110 | 3.1.0
111 |
112 |
113 | package
114 |
115 | shade
116 |
117 |
118 | true
119 |
120 |
121 | com.google.protobuf:protobuf-java
122 | org.tensorflow:tensorflow-hadoop
123 | org.tensorflow:proto
124 |
125 |
126 |
127 |
128 |
129 | com.google.protobuf:protobuf-java
130 |
131 | **/*.java
132 |
133 |
134 |
135 |
136 |
137 | com.google.protobuf
138 |
139 | com.linkedin.spark.shaded.com.google.protobuf
140 |
141 |
142 |
143 | org.tensorflow.hadoop
144 |
145 | com.linkedin.spark.shaded.org.tensorflow.hadoop
146 |
147 |
148 |
149 | org.tensorflow.example
150 |
151 | com.linkedin.spark.shaded.org.tensorflow.example
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 | org.apache.maven.plugins
162 | maven-gpg-plugin
163 | 1.5
164 |
165 |
166 | sign-artifacts
167 | verify
168 |
169 | sign
170 |
171 |
172 |
173 |
174 |
175 |
176 | org.spurint.maven.plugins
177 | scala-cross-maven-plugin
178 | 0.2.1
179 |
180 |
181 | rewrite-pom
182 |
183 | rewrite-pom
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 | net.alchim31.maven
193 | scala-maven-plugin
194 |
195 |
196 | org.apache.maven.plugins
197 | maven-shade-plugin
198 |
199 |
200 | org.scalatest
201 | scalatest-maven-plugin
202 |
203 |
204 | org.apache.maven.plugins
205 | maven-compiler-plugin
206 | ${maven.compiler.version}
207 |
208 | ${java.version}
209 | ${java.version}
210 |
211 |
212 |
213 | org.apache.maven.plugins
214 | maven-source-plugin
215 | 2.2.1
216 |
217 |
218 | attach-sources
219 |
220 | jar-no-fork
221 |
222 |
223 |
224 |
225 |
226 | org.apache.maven.plugins
227 | maven-javadoc-plugin
228 | 2.9.1
229 |
230 |
231 | attach-javadocs
232 |
233 | jar
234 |
235 |
236 |
237 |
238 |
239 | org.spurint.maven.plugins
240 | scala-cross-maven-plugin
241 |
242 |
243 |
244 |
245 |
246 |
247 | apache.snapshots
248 | Apache Development Snapshot Repository
249 | https://repository.apache.org/content/repositories/snapshots/
250 |
251 | false
252 |
253 |
254 | true
255 |
256 |
257 |
258 |
259 |
260 |
261 | test
262 |
263 | true
264 |
265 | !NEVERSETME
266 |
267 |
268 |
269 |
270 |
271 | net.alchim31.maven
272 | scala-maven-plugin
273 |
274 |
275 |
276 |
277 |
278 |
279 | org.scalatest
280 | scalatest_${scala.binary.version}
281 | ${scala.test.version}
282 | test
283 |
284 |
285 |
286 |
287 |
288 | org.scalatest
289 | scalatest_${scala.binary.version}
290 | test
291 |
292 |
293 |
294 |
295 |
298 |
299 | ossrh
300 |
301 |
302 |
303 | ossrh
304 | https://oss.sonatype.org/content/repositories/snapshots
305 |
306 |
307 | ossrh
308 | https://oss.sonatype.org/service/local/staging/deploy/maven2/
309 |
310 |
311 |
312 |
313 |
314 | org.apache.maven.plugins
315 | maven-gpg-plugin
316 |
317 |
318 |
319 |
320 |
321 |
322 | bintray
323 |
324 |
325 |
326 | bintray-linkedin-maven
327 | linkedin-maven
328 | https://api.bintray.com/maven/linkedin/maven/spark-tfrecord/;publish=1
329 |
330 |
331 |
332 |
333 |
334 | org.apache.maven.plugins
335 | maven-gpg-plugin
336 |
337 |
338 |
339 |
340 |
341 |
342 | scala-2.12
343 |
344 | 2.12
345 | 2.12.12
346 |
347 |
348 |
349 |
353 |
354 | scala-2.13
355 |
356 | 2.13
357 | 2.13.8
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 | spark-tfrecord developers
366 | LinkedIn
367 | http://www.linkedin.com
368 |
369 |
370 |
371 |   <dependencies>
372 |     <dependency>
373 |       <groupId>org.tensorflow</groupId>
374 |       <artifactId>tensorflow-hadoop</artifactId>
375 |       <version>${tensorflow.hadoop.version}</version>
376 |     </dependency>
377 |     <dependency>
378 |       <groupId>org.apache.spark</groupId>
379 |       <artifactId>spark-core_${scala.binary.version}</artifactId>
380 |       <version>${spark.version}</version>
381 |       <scope>provided</scope>
382 |     </dependency>
383 |     <dependency>
384 |       <groupId>org.apache.spark</groupId>
385 |       <artifactId>spark-sql_${scala.binary.version}</artifactId>
386 |       <version>${spark.version}</version>
387 |       <scope>provided</scope>
388 |     </dependency>
389 |     <dependency>
390 |       <groupId>org.apache.spark</groupId>
391 |       <artifactId>spark-mllib_${scala.binary.version}</artifactId>
392 |       <version>${spark.version}</version>
393 |       <scope>provided</scope>
394 |     </dependency>
395 |     <dependency>
396 |       <groupId>org.apache.spark</groupId>
397 |       <artifactId>spark-mllib_${scala.binary.version}</artifactId>
398 |       <version>${spark.version}</version>
399 |       <type>test-jar</type>
400 |       <scope>test</scope>
401 |     </dependency>
402 |     <dependency>
403 |       <groupId>junit</groupId>
404 |       <artifactId>junit</artifactId>
405 |       <version>${junit.version}</version>
406 |       <scope>test</scope>
407 |     </dependency>
408 |   </dependencies>
409 | </project>
410 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=0.13.13
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/"
2 |
3 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3")
4 |
5 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.5")
6 |
--------------------------------------------------------------------------------
/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister:
--------------------------------------------------------------------------------
1 | com.linkedin.spark.datasources.tfrecord.DefaultSource
--------------------------------------------------------------------------------
/src/main/scala/com/linkedin/spark/datasources/tfrecord/DefaultSource.scala:
--------------------------------------------------------------------------------
1 | package com.linkedin.spark.datasources.tfrecord
2 |
3 | import java.io.{DataInputStream, DataOutputStream, IOException, ObjectInputStream, ObjectOutputStream, Serializable}
4 |
5 | import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
6 | import com.esotericsoftware.kryo.io.{Input, Output}
7 | import org.apache.hadoop.conf.Configuration
8 | import org.apache.hadoop.fs.{FileStatus, Path}
9 | import org.apache.hadoop.io.SequenceFile.CompressionType
10 | import org.apache.hadoop.io.{BytesWritable, NullWritable}
11 | import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
12 | import org.apache.spark.sql.SparkSession
13 | import org.apache.spark.sql.catalyst.InternalRow
14 | import org.apache.spark.sql.execution.datasources._
15 | import org.apache.spark.sql.sources._
16 | import org.apache.spark.sql.types._
17 | import org.slf4j.LoggerFactory
18 | import org.tensorflow.example.{Example, SequenceExample}
19 | import org.tensorflow.hadoop.io.TFRecordFileInputFormat
20 |
21 | import scala.util.control.NonFatal
22 |
23 | class DefaultSource extends FileFormat with DataSourceRegister {
24 | override val shortName: String = "tfrecord"
25 |
26 | override def isSplitable(
27 | sparkSession: SparkSession,
28 | options: Map[String, String],
29 | path: Path): Boolean = false
30 |
31 | override def inferSchema(
32 | sparkSession: SparkSession,
33 | options: Map[String, String],
34 | files: Seq[FileStatus]): Option[StructType] = {
35 | val recordType = options.getOrElse("recordType", "Example")
36 | files.collectFirst {
37 | case f if hasSchema(sparkSession, f, recordType) => getSchemaFromFile(sparkSession, f, recordType)
38 | }
39 | }
40 |
41 | /**
42 | * Get schema from a file
43 | * @param sparkSession A spark session.
44 | * @param file The file where schema to be extracted.
45 | * @param recordType Example or SequenceExample
46 | * @return the extracted schema (a StructType).
47 | */
48 | private def getSchemaFromFile(
49 | sparkSession: SparkSession,
50 | file: FileStatus,
51 | recordType: String): StructType = {
52 | val rdd = sparkSession.sparkContext.newAPIHadoopFile(file.getPath.toString,
53 | classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable])
54 | recordType match {
55 | case "ByteArray" =>
56 | TensorFlowInferSchema.getSchemaForByteArray()
57 | case "Example" =>
58 | val exampleRdd = rdd.map{case (bytesWritable, nullWritable) =>
59 | Example.parseFrom(bytesWritable.getBytes)
60 | }
61 | TensorFlowInferSchema(exampleRdd)
62 | case "SequenceExample" =>
63 | val sequenceExampleRdd = rdd.map{case (bytesWritable, nullWritable) =>
64 | SequenceExample.parseFrom(bytesWritable.getBytes)
65 | }
66 | TensorFlowInferSchema(sequenceExampleRdd)
67 | case _ =>
68 | throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be ByteArray, Example or SequenceExample")
69 | }
70 | }
71 |
72 | /**
73 | * Check if a non-empty schema can be extracted from a file.
74 | * The schema is empty if one of the following is true:
75 | * 1. The file size is zero.
76 | * 2. The file size is non-zero, but the schema is empty (e.g. empty .gz file)
77 | * @param sparkSession A spark session.
78 | * @param file The file where schema to be extracted.
79 | * @param recordType Example or SequenceExample
80 | * @return true if schema is non-empty.
81 | */
82 | private def hasSchema(
83 | sparkSession: SparkSession,
84 | file: FileStatus,
85 | recordType: String): Boolean = {
86 | (file.getLen > 0) && (getSchemaFromFile(sparkSession, file, recordType).length > 0)
87 | }
88 |
89 | override def prepareWrite(
90 | sparkSession: SparkSession,
91 | job: Job,
92 | options: Map[String, String],
93 | dataSchema: StructType): OutputWriterFactory = {
94 | val conf = job.getConfiguration
95 | val codec = options.getOrElse("codec", "")
96 | if (!codec.isEmpty) {
97 | conf.set("mapreduce.output.fileoutputformat.compress", "true")
98 | conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString)
99 | conf.set("mapreduce.output.fileoutputformat.compress.codec", codec)
100 | conf.set("mapreduce.map.output.compress", "true")
101 | conf.set("mapreduce.map.output.compress.codec", codec)
102 | }
103 |
104 | new OutputWriterFactory {
105 | override def newInstance(
106 | path: String,
107 | dataSchema: StructType,
108 | context: TaskAttemptContext): OutputWriter = {
109 | new TFRecordOutputWriter(path, options, dataSchema, context)
110 | }
111 |
112 | override def getFileExtension(context: TaskAttemptContext): String = {
113 | ".tfrecord" + CodecStreams.getCompressionExtension(context)
114 | }
115 | }
116 | }
117 |
118 | override def buildReader(
119 | sparkSession: SparkSession,
120 | dataSchema: StructType,
121 | partitionSchema: StructType,
122 | requiredSchema: StructType,
123 | filters: Seq[Filter],
124 | options: Map[String, String],
125 | hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
126 | val broadcastedHadoopConf =
127 | sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
128 |
129 | (file: PartitionedFile) => {
130 | TFRecordFileReader.readFile(
131 | broadcastedHadoopConf.value.value,
132 | options,
133 | file,
134 | requiredSchema)
135 | }
136 | }
137 |
138 | override def toString: String = "TFRECORD"
139 |
140 | override def hashCode(): Int = getClass.hashCode()
141 |
142 | override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource]
143 | }
144 |
145 | private [tfrecord] class SerializableConfiguration(@transient var value: Configuration)
146 | extends Serializable with KryoSerializable {
147 | @transient private[tfrecord] lazy val log = LoggerFactory.getLogger(getClass)
148 |
149 | private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
150 | out.defaultWriteObject()
151 | value.write(out)
152 | }
153 |
154 | private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
155 | value = new Configuration(false)
156 | value.readFields(in)
157 | }
158 |
159 | private def tryOrIOException[T](block: => T): T = {
160 | try {
161 | block
162 | } catch {
163 | case e: IOException =>
164 | log.error("Exception encountered", e)
165 | throw e
166 | case NonFatal(e) =>
167 | log.error("Exception encountered", e)
168 | throw new IOException(e)
169 | }
170 | }
171 |
172 | def write(kryo: Kryo, out: Output): Unit = {
173 | val dos = new DataOutputStream(out)
174 | value.write(dos)
175 | dos.flush()
176 | }
177 |
178 | def read(kryo: Kryo, in: Input): Unit = {
179 | value = new Configuration(false)
180 | value.readFields(new DataInputStream(in))
181 | }
182 | }
183 |
--------------------------------------------------------------------------------
/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala:
--------------------------------------------------------------------------------
1 | package com.linkedin.spark.datasources.tfrecord
2 |
3 | import org.apache.spark.sql.catalyst.InternalRow
4 | import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow, UnsafeArrayData}
5 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
6 | import org.apache.spark.sql.types.{DecimalType, DoubleType, _}
7 | import org.apache.spark.unsafe.types.UTF8String
8 | import org.tensorflow.example._
9 |
10 | import scala.jdk.CollectionConverters._
11 |
12 | /**
13 | * Creates a TFRecord deserializer to deserialize Tfrecord example or sequenceExample to Spark InternalRow
14 | */
15 | class TFRecordDeserializer(dataSchema: StructType) {
16 |
17 | def deserializeByteArray(byteArray: Array[Byte]): InternalRow = {
18 | InternalRow(byteArray)
19 | }
20 |
21 | def deserializeExample(example: Example): InternalRow = {
22 | val featureMap = example.getFeatures.getFeatureMap.asScala
23 | val resultRow = new SpecificInternalRow(dataSchema.map(_.dataType))
24 | dataSchema.zipWithIndex.foreach {
25 | case (field, index) =>
26 | val feature = featureMap.get(field.name)
27 | feature match {
28 | case Some(ft) =>
29 | val featureWriter = newFeatureWriter(field.dataType, new RowUpdater(resultRow))
30 | featureWriter(index, ft)
31 | case None => if (!field.nullable) throw new NullPointerException(s"Field ${field.name} does not allow null values")
32 | }
33 | }
34 | resultRow
35 | }
36 |
37 | def deserializeSequenceExample(sequenceExample: SequenceExample): InternalRow = {
38 |
39 | val featureMap = sequenceExample.getContext.getFeatureMap.asScala
40 | val featureListMap = sequenceExample.getFeatureLists.getFeatureListMap.asScala
41 | val resultRow = new SpecificInternalRow(dataSchema.map(_.dataType))
42 |
43 | dataSchema.zipWithIndex.foreach {
44 | case (field, index) =>
45 | val feature = featureMap.get(field.name)
46 | feature match {
47 | case Some(ft) =>
48 | val featureWriter = newFeatureWriter(field.dataType, new RowUpdater(resultRow))
49 | featureWriter(index, ft)
50 | case None =>
51 | val featureList = featureListMap.get(field.name)
52 | featureList match {
53 | case Some(ftList) =>
54 | val featureListWriter = newFeatureListWriter(field.dataType, new RowUpdater(resultRow))
55 | featureListWriter(index, ftList)
56 | case None => if (!field.nullable) throw new NullPointerException(s"Field ${field.name} does not allow null values")
57 | }
58 | }
59 | }
60 | resultRow
61 | }
62 |
63 | private type arrayElementConverter = (SpecializedGetters, Int) => Any
64 |
65 | /**
66 | * Creates a writer to write Tfrecord Feature values to Catalyst data structure at the given ordinal.
67 | */
68 | private def newFeatureWriter(
69 | dataType: DataType, updater: CatalystDataUpdater): (Int, Feature) => Unit =
70 | dataType match {
71 | case NullType => (ordinal, _) =>
72 | updater.setNullAt(ordinal)
73 |
74 | case IntegerType => (ordinal, feature) =>
75 | updater.setInt(ordinal, Int64ListFeature2SeqLong(feature).head.toInt)
76 |
77 | case LongType => (ordinal, feature) =>
78 | updater.setLong(ordinal, Int64ListFeature2SeqLong(feature).head)
79 |
80 | case FloatType => (ordinal, feature) =>
81 | updater.setFloat(ordinal, floatListFeature2SeqFloat(feature).head.toFloat)
82 |
83 | case DoubleType => (ordinal, feature) =>
84 | updater.setDouble(ordinal, floatListFeature2SeqFloat(feature).head.toDouble)
85 |
86 | case DecimalType() => (ordinal, feature) =>
87 | updater.setDecimal(ordinal, Decimal(floatListFeature2SeqFloat(feature).head.toDouble))
88 |
89 | case StringType => (ordinal, feature) =>
90 | val value = bytesListFeature2SeqString(feature).head
91 | updater.set(ordinal, UTF8String.fromString(value))
92 |
93 | case BinaryType => (ordinal, feature) =>
94 | val value = bytesListFeature2SeqArrayByte(feature).head
95 | updater.set(ordinal, value)
96 |
97 | case ArrayType(elementType, _) => (ordinal, feature) =>
98 |
99 | elementType match {
100 | case IntegerType | LongType | FloatType | DoubleType | DecimalType() | StringType | BinaryType =>
101 | val valueList = elementType match {
102 | case IntegerType => Int64ListFeature2SeqLong(feature).map(_.toInt)
103 | case LongType => Int64ListFeature2SeqLong(feature)
104 | case FloatType => floatListFeature2SeqFloat(feature).map(_.toFloat)
105 | case DoubleType => floatListFeature2SeqFloat(feature).map(_.toDouble)
106 | case DecimalType() => floatListFeature2SeqFloat(feature).map(x => Decimal(x.toDouble))
107 | case StringType => bytesListFeature2SeqString(feature)
108 | case BinaryType => bytesListFeature2SeqArrayByte(feature)
109 | }
110 | val len = valueList.length
111 | val result = createArrayData(elementType, len)
112 | val elementUpdater = new ArrayDataUpdater(result)
113 | val elementConverter = newArrayElementWriter(elementType, elementUpdater)
114 | for (idx <- 0 until len) {
115 | elementConverter(idx, valueList(idx))
116 | }
117 | updater.set(ordinal, result)
118 |
119 | case _ => throw new scala.RuntimeException(s"Cannot convert Array type to unsupported data type ${elementType}")
120 | }
121 |
122 | case _ =>
123 | throw new UnsupportedOperationException(s"$dataType is not supported yet.")
124 | }
125 |
126 | /**
127 | * Creates a writer to write Tfrecord FeatureList values to Catalyst data structure at the given ordinal.
128 | */
129 | private def newFeatureListWriter(
130 | dataType: DataType, updater: CatalystDataUpdater): (Int, FeatureList) => Unit =
131 | dataType match {
132 | case ArrayType(elementType, _) => (ordinal, featureList) =>
133 | val ftList = featureList.getFeatureList.asScala
134 | val len = ftList.length
135 | val resultArray = createArrayData(elementType, len)
136 | val elementUpdater = new ArrayDataUpdater(resultArray)
137 | val elementConverter = newFeatureWriter(elementType, elementUpdater)
138 | for (idx <- 0 until len) {
139 | elementConverter(idx, ftList(idx))
140 | }
141 | updater.set(ordinal, resultArray)
142 | case _ => throw new scala.RuntimeException(s"Cannot convert FeatureList to unsupported data type ${dataType}")
143 | }
144 |
145 | /**
146 | * Creates a writer to write Tfrecord Feature array element to Catalyst data structure at the given ordinal.
147 | */
148 | private def newArrayElementWriter(
149 | dataType: DataType, updater: CatalystDataUpdater): (Int, Any) => Unit =
150 | dataType match {
151 | case NullType => null
152 |
153 | case IntegerType => (ordinal, value) =>
154 | updater.setInt(ordinal, value.asInstanceOf[Int])
155 |
156 | case LongType => (ordinal, value) =>
157 | updater.setLong(ordinal, value.asInstanceOf[Long])
158 |
159 | case FloatType => (ordinal, value) =>
160 | updater.setFloat(ordinal, value.asInstanceOf[Float])
161 |
162 | case DoubleType => (ordinal, value) =>
163 | updater.setDouble(ordinal, value.asInstanceOf[Double])
164 |
165 | case DecimalType() => (ordinal, value) =>
166 | updater.setDecimal(ordinal, value.asInstanceOf[Decimal])
167 |
168 | case StringType => (ordinal, value) =>
169 | updater.set(ordinal, UTF8String.fromString(value.asInstanceOf[String]))
170 |
171 | case BinaryType => (ordinal, value) =>
172 | updater.set(ordinal, value.asInstanceOf[Array[Byte]])
173 |
174 | case _ => throw new RuntimeException(s"Cannot convert array element to unsupported data type ${dataType}")
175 | }
176 |
177 | def Int64ListFeature2SeqLong(feature: Feature): Seq[Long] = {
178 | require(feature != null && feature.getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER, "Feature must be of type Int64List")
179 | try {
180 | feature.getInt64List.getValueList.asScala.toSeq.map(_.toLong)
181 | }
182 | catch {
183 | case ex: Exception =>
184 | throw new RuntimeException(s"Cannot convert feature to long.", ex)
185 | }
186 | }
187 |
188 | def floatListFeature2SeqFloat(feature: Feature): Seq[java.lang.Float] = {
189 | require(feature != null && feature.getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER, "Feature must be of type FloatList")
190 | try {
191 | val array = feature.getFloatList.getValueList.asScala.toSeq
192 | array
193 | }
194 | catch {
195 | case ex: Exception =>
196 | throw new RuntimeException(s"Cannot convert feature to Float.", ex)
197 | }
198 | }
199 |
200 | def bytesListFeature2SeqArrayByte(feature: Feature): Seq[Array[Byte]] = {
201 | require(feature != null && feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList")
202 | try {
203 | feature.getBytesList.getValueList.asScala.toSeq.map((byteArray) => byteArray.asScala.toArray.map(_.toByte))
204 | }
205 | catch {
206 | case ex: Exception =>
207 | throw new RuntimeException(s"Cannot convert feature to byte array.", ex)
208 | }
209 | }
210 |
211 | def bytesListFeature2SeqString(feature: Feature): Seq[String] = {
212 | require(feature != null && feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList")
213 | try {
214 | val array = feature.getBytesList.getValueList.asScala.toSeq
215 | array.map(_.toStringUtf8)
216 | }
217 | catch {
218 | case ex: Exception =>
219 | throw new RuntimeException(s"Cannot convert feature to String array.", ex)
220 | }
221 | }
222 |
223 | private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
224 | case BooleanType => new GenericArrayData(new Array[Boolean](length))
225 | case ByteType => new GenericArrayData(new Array[Byte](length))
226 | case ShortType => new GenericArrayData(new Array[Short](length))
227 | case IntegerType => new GenericArrayData(new Array[Int](length))
228 | case LongType => new GenericArrayData(new Array[Long](length))
229 | case FloatType => new GenericArrayData(new Array[Float](length))
230 | case DoubleType => new GenericArrayData(new Array[Double](length))
231 | case _ => new GenericArrayData(new Array[Any](length))
232 | }
233 |
234 | /**
235 | * A base interface for updating values inside catalyst data structure like `InternalRow` and
236 | * `ArrayData`.
237 | */
238 | sealed trait CatalystDataUpdater {
239 | def set(ordinal: Int, value: Any): Unit
240 | def setNullAt(ordinal: Int): Unit = set(ordinal, null)
241 | def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
242 | def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
243 | def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
244 | def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
245 | def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
246 | def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
247 | def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
248 | def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
249 | }
250 |
251 | final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
252 | override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
253 | override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)
254 | override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
255 | override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
256 | override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
257 | override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
258 | override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
259 | override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
260 | override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
261 | override def setDecimal(ordinal: Int, value: Decimal): Unit =
262 | row.setDecimal(ordinal, value, value.precision)
263 | }
264 |
265 | final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
266 | override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
267 | override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)
268 | override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
269 | override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
270 | override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
271 | override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
272 | override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
273 | override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
274 | override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
275 | override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
276 | }
277 | }
278 |
--------------------------------------------------------------------------------
/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordFileReader.scala:
--------------------------------------------------------------------------------
1 | package com.linkedin.spark.datasources.tfrecord
2 |
3 | import org.apache.hadoop.conf.Configuration
4 | import org.apache.hadoop.fs.Path
5 | import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
6 | import org.apache.hadoop.mapreduce.lib.input.FileSplit
7 | import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
8 | import org.apache.spark.sql.catalyst.InternalRow
9 | import org.apache.spark.sql.execution.datasources.PartitionedFile
10 | import org.apache.spark.TaskContext
11 | import org.apache.spark.sql.types.StructType
12 | import org.tensorflow.example.{Example, SequenceExample}
13 | import org.tensorflow.hadoop.io.TFRecordFileInputFormat
14 |
15 | object TFRecordFileReader {
16 | def readFile(
17 | conf: Configuration,
18 | options: Map[String, String],
19 | file: PartitionedFile,
20 | schema: StructType): Iterator[InternalRow] = {
21 |
22 | val recordType = options.getOrElse("recordType", "Example")
23 |
24 | val inputSplit = new FileSplit(
25 | file.toPath,
26 | file.start,
27 | file.length,
28 | // The locality is decided by `getPreferredLocations` in `FileScanRDD`.
29 | Array.empty)
30 | val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
31 | val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
32 | val recordReader = new TFRecordFileInputFormat().createRecordReader(inputSplit, hadoopAttemptContext)
33 |
34 | // Ensure that the reader is closed even if the task fails or doesn't consume the entire
35 | // iterator of records.
36 | Option(TaskContext.get()).foreach { taskContext =>
37 | taskContext.addTaskCompletionListener[Unit]((_: TaskContext) =>
38 | recordReader.close()
39 | )
40 | }
41 |
42 | recordReader.initialize(inputSplit, hadoopAttemptContext)
43 |
44 | val deserializer = new TFRecordDeserializer(schema)
45 |
46 | new Iterator[InternalRow] {
47 | private[this] var havePair = false
48 | private[this] var finished = false
49 | override def hasNext: Boolean = {
50 | if (!finished && !havePair) {
51 | finished = !recordReader.nextKeyValue
52 | if (finished) {
53 | // Close and release the reader here; close() will also be called when the task
54 | // completes, but for tasks that read from many files, it helps to release the
55 | // resources early.
56 | recordReader.close()
57 | }
58 | havePair = !finished
59 | }
60 | !finished
61 | }
62 |
63 | override def next(): InternalRow = {
64 | if (!hasNext) {
65 | throw new java.util.NoSuchElementException("End of stream")
66 | }
67 | havePair = false
68 | val bytesWritable = recordReader.getCurrentKey
69 | recordType match {
70 | case "ByteArray" =>
71 | deserializer.deserializeByteArray(bytesWritable.getBytes)
72 | case "Example" =>
73 | val example = Example.parseFrom(bytesWritable.getBytes)
74 | deserializer.deserializeExample(example)
75 | case "SequenceExample" =>
76 | val sequenceExample = SequenceExample.parseFrom(bytesWritable.getBytes)
77 | deserializer.deserializeSequenceExample(sequenceExample)
78 | case _ =>
79 | throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be ByteArray, Example or SequenceExample")
80 | }
81 | }
82 | }
83 | }
84 | }
--------------------------------------------------------------------------------
/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordOutputWriter.scala:
--------------------------------------------------------------------------------
1 | package com.linkedin.spark.datasources.tfrecord
2 |
3 | import java.io.DataOutputStream
4 | import org.apache.hadoop.fs.Path
5 | import org.apache.hadoop.mapreduce.TaskAttemptContext
6 | import org.apache.spark.sql.catalyst.InternalRow
7 | import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter}
8 | import org.apache.spark.sql.types.StructType
9 | import org.tensorflow.hadoop.util.TFRecordWriter
10 |
11 |
12 | class TFRecordOutputWriter(
13 | val path: String,
14 | options: Map[String, String],
15 | dataSchema: StructType,
16 | context: TaskAttemptContext)
17 | extends OutputWriter{
18 |
19 | private val outputStream = CodecStreams.createOutputStream(context, new Path(path))
20 | private val dataOutputStream = new DataOutputStream(outputStream)
21 | private val writer = new TFRecordWriter(dataOutputStream)
22 | private val recordType = options.getOrElse("recordType", "Example")
23 |
24 | private[this] val serializer = new TFRecordSerializer(dataSchema)
25 |
26 | override def write(row: InternalRow): Unit = {
27 | val record = recordType match {
28 | case "ByteArray" =>
29 | serializer.serializeByteArray(row)
30 | case "Example" =>
31 | serializer.serializeExample(row).toByteArray
32 | case "SequenceExample" =>
33 | serializer.serializeSequenceExample(row).toByteArray
34 | case _ =>
35 |         throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be ByteArray, Example or SequenceExample")
36 | }
37 | writer.write(record)
38 | }
39 |
40 | override def close(): Unit = {
41 | dataOutputStream.close()
42 | outputStream.close()
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala:
--------------------------------------------------------------------------------
1 | package com.linkedin.spark.datasources.tfrecord
2 |
3 | import org.apache.spark.sql.catalyst.InternalRow
4 | import org.apache.spark.sql.types.{DecimalType, DoubleType, _}
5 | import org.tensorflow.example._
6 | import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
7 | import com.google.protobuf.ByteString
8 |
9 | /**
10 | * Creates a TFRecord serializer to serialize Spark InternalRow to Tfrecord example or sequenceExample
11 | */
12 | class TFRecordSerializer(dataSchema: StructType) {
13 |
14 | private val featureConverters = dataSchema.map(_.dataType).map(newFeatureConverter(_)).toArray
15 |
16 | def serializeByteArray(row: InternalRow): Array[Byte] = {
17 | row.getBinary(0)
18 | }
19 |
20 | def serializeExample(row: InternalRow): Example = {
21 | val features = Features.newBuilder()
22 | val example = Example.newBuilder()
23 | for (idx <- featureConverters.indices) {
24 | val structField = dataSchema(idx)
25 | if (!row.isNullAt(idx)) {
26 | val feature = featureConverters(idx)(row, idx).asInstanceOf[Feature]
27 | features.putFeature(structField.name, feature)
28 | }
29 | else if (!dataSchema(idx).nullable) {
30 | throw new NullPointerException(s"${structField.name} does not allow null values")
31 | }
32 | }
33 | example.setFeatures(features.build())
34 | example.build()
35 | }
36 |
37 | def serializeSequenceExample(row: InternalRow): SequenceExample = {
38 | val features = Features.newBuilder()
39 | val featureLists = FeatureLists.newBuilder()
40 | val sequenceExample = SequenceExample.newBuilder()
41 | for (idx <- featureConverters.indices) {
42 | val structField = dataSchema(idx)
43 | if (!row.isNullAt(idx)) {
44 | structField.dataType match {
45 | case ArrayType(ArrayType(_, _), _) =>
46 | val featureList = featureConverters(idx)(row, idx).asInstanceOf[FeatureList]
47 | featureLists.putFeatureList(structField.name, featureList)
48 | case _ =>
49 | val feature = featureConverters(idx)(row, idx).asInstanceOf[Feature]
50 | features.putFeature(structField.name, feature)
51 | }
52 | }
53 | else if (!dataSchema(idx).nullable) {
54 | throw new NullPointerException(s"${structField.name} does not allow null values")
55 | }
56 | }
57 | sequenceExample.setContext(features.build())
58 | sequenceExample.setFeatureLists(featureLists.build())
59 | sequenceExample.build()
60 | }
61 |
62 | private type FeatureConverter = (SpecializedGetters, Int) => Any
63 | private type arrayElementConverter = (SpecializedGetters, Int) => Any
64 |
65 | /**
66 | * Creates a converter to convert Catalyst data at the given ordinal to TFrecord Feature.
67 | */
68 | private def newFeatureConverter(
69 | dataType: DataType): FeatureConverter = dataType match {
70 | case NullType => (getter, ordinal) => null
71 |
72 | case IntegerType => (getter, ordinal) =>
73 | val value = getter.getInt(ordinal)
74 | Int64ListFeature(Seq(value.toLong))
75 |
76 | case LongType => (getter, ordinal) =>
77 | val value = getter.getLong(ordinal)
78 | Int64ListFeature(Seq(value))
79 |
80 | case FloatType => (getter, ordinal) =>
81 | val value = getter.getFloat(ordinal)
82 | floatListFeature(Seq(value))
83 |
84 | case DoubleType => (getter, ordinal) =>
85 | val value = getter.getDouble(ordinal)
86 | floatListFeature(Seq(value.toFloat))
87 |
88 | case DecimalType() => (getter, ordinal) =>
89 | val value = getter.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale)
90 | floatListFeature(Seq(value.toFloat))
91 |
92 | case StringType => (getter, ordinal) =>
93 | val value = getter.getUTF8String(ordinal).getBytes
94 | bytesListFeature(Seq(value))
95 |
96 | case BinaryType => (getter, ordinal) =>
97 | val value = getter.getBinary(ordinal)
98 | bytesListFeature(Seq(value))
99 |
100 | case ArrayType(elementType, containsNull) => (getter, ordinal) =>
101 | val arrayData = getter.getArray(ordinal)
102 | val featureOrFeatureList = elementType match {
103 | case IntegerType =>
104 | Int64ListFeature(arrayData.toIntArray().toSeq.map(_.toLong))
105 |
106 | case LongType =>
107 | Int64ListFeature(arrayData.toLongArray().toSeq)
108 |
109 | case FloatType =>
110 | floatListFeature(arrayData.toFloatArray().toSeq)
111 |
112 | case DoubleType =>
113 | floatListFeature(arrayData.toDoubleArray().toSeq.map(_.toFloat))
114 |
115 | case DecimalType() =>
116 | val elementConverter = arrayElementConverter(elementType)
117 | val len = arrayData.numElements()
118 | val result = new Array[Decimal](len)
119 | for (idx <- 0 until len) {
120 | if (containsNull && arrayData.isNullAt(idx)) {
121 | result(idx) = null
122 | } else result(idx) = elementConverter(arrayData, idx).asInstanceOf[Decimal]
123 | }
124 | floatListFeature(result.toSeq.map(_.toFloat))
125 |
126 | case StringType | BinaryType =>
127 | val elementConverter = arrayElementConverter(elementType)
128 | val len = arrayData.numElements()
129 | val result = new Array[Array[Byte]](len)
130 | for (idx <- 0 until len) {
131 | if (containsNull && arrayData.isNullAt(idx)) {
132 | result(idx) = null
133 | } else result(idx) = elementConverter(arrayData, idx).asInstanceOf[Array[Byte]]
134 | }
135 | bytesListFeature(result.toSeq)
136 |
137 | // 2-dimensional array to TensorFlow "FeatureList"
138 | case ArrayType(_, _) =>
139 | val elementConverter = newFeatureConverter(elementType)
140 | val featureList = FeatureList.newBuilder()
141 | for (idx <- 0 until arrayData.numElements()) {
142 | val feature = elementConverter(arrayData, idx).asInstanceOf[Feature]
143 | featureList.addFeature(feature)
144 | }
145 | featureList.build()
146 |
147 | case _ => throw new RuntimeException(s"Array element data type ${dataType} is unsupported")
148 | }
149 | featureOrFeatureList
150 |
151 | case _ => throw new RuntimeException(s"Cannot convert field to unsupported data type ${dataType}")
152 | }
153 |
154 | private def arrayElementConverter(
155 | dataType: DataType): arrayElementConverter = dataType match {
156 | case NullType => null
157 |
158 | case IntegerType => (getter, ordinal) =>
159 | getter.getInt(ordinal)
160 |
161 | case LongType => (getter, ordinal) =>
162 | getter.getLong(ordinal)
163 |
164 | case FloatType => (getter, ordinal) =>
165 | getter.getFloat(ordinal)
166 |
167 | case DoubleType => (getter, ordinal) =>
168 | getter.getDouble(ordinal)
169 |
170 | case DecimalType() => (getter, ordinal) =>
171 | getter.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale)
172 |
173 | case StringType => (getter, ordinal) =>
174 | getter.getUTF8String(ordinal).getBytes
175 |
176 | case BinaryType => (getter, ordinal) =>
177 | getter.getBinary(ordinal)
178 |
179 | case _ => throw new RuntimeException(s"Cannot convert field to unsupported data type ${dataType}")
180 | }
181 |
182 | def Int64ListFeature(value: Seq[Long]): Feature = {
183 | val intListBuilder = Int64List.newBuilder()
184 | value.foreach {x =>
185 | intListBuilder.addValue(x)
186 | }
187 | val int64List = intListBuilder.build()
188 | Feature.newBuilder().setInt64List(int64List).build()
189 | }
190 |
191 | def floatListFeature(value: Seq[Float]): Feature = {
192 | val floatListBuilder = FloatList.newBuilder()
193 | value.foreach {x =>
194 | floatListBuilder.addValue(x)
195 | }
196 | val floatList = floatListBuilder.build()
197 | Feature.newBuilder().setFloatList(floatList).build()
198 | }
199 |
200 | def bytesListFeature(value: Seq[Array[Byte]]): Feature = {
201 | val bytesListBuilder = BytesList.newBuilder()
202 | value.foreach {x =>
203 | bytesListBuilder.addValue(ByteString.copyFrom(x))
204 | }
205 | val bytesList = bytesListBuilder.build()
206 | Feature.newBuilder().setBytesList(bytesList).build()
207 | }
208 | }
209 |
--------------------------------------------------------------------------------
/src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.linkedin.spark.datasources.tfrecord
17 |
18 | import org.apache.spark.rdd.RDD
19 | import org.apache.spark.sql.types._
20 | import org.tensorflow.example.{FeatureList, SequenceExample, Example, Feature}
21 |
22 | import scala.collection.mutable
23 | import scala.jdk.CollectionConverters._
24 | import scala.reflect.runtime.universe._
25 |
26 | object TensorFlowInferSchema {
27 |
28 | /**
29 | * Similar to the JSON schema inference.
30 | * [[org.apache.spark.sql.execution.datasources.json.InferSchema]]
31 | * 1. Infer type of each row
32 | * 2. Merge row types to find common type
33 |    * 3. Keep NullType for fields whose type cannot be inferred (e.g. features with no values)
34 | */
35 | def apply[T : TypeTag](rdd: RDD[T]): StructType = {
36 | val startType: mutable.Map[String, DataType] = mutable.Map.empty[String, DataType]
37 |
38 | val rootTypes: mutable.Map[String, DataType] = typeOf[T] match {
39 | case t if t =:= typeOf[Example] => {
40 | rdd.asInstanceOf[RDD[Example]].aggregate(startType)(inferExampleRowType, mergeFieldTypes)
41 | }
42 | case t if t =:= typeOf[SequenceExample] => {
43 | rdd.asInstanceOf[RDD[SequenceExample]].aggregate(startType)(inferSequenceExampleRowType, mergeFieldTypes)
44 | }
45 | case _ => throw new IllegalArgumentException(s"Unsupported recordType: recordType can be Example or SequenceExample")
46 | }
47 |
48 | val columnsList = rootTypes.map {
49 | case (featureName, featureType) =>
50 | if (featureType == null) {
51 | StructField(featureName, NullType)
52 | }
53 | else {
54 | StructField(featureName, featureType)
55 | }
56 | }
57 | StructType(columnsList.toSeq)
58 | }
59 |
60 | def getSchemaForByteArray() : StructType = {
61 | StructType(Array(
62 | StructField("byteArray", BinaryType)
63 | ))
64 | }
65 |
66 | private def inferSequenceExampleRowType(schemaSoFar: mutable.Map[String, DataType],
67 | next: SequenceExample): mutable.Map[String, DataType] = {
68 | val featureMap = next.getContext.getFeatureMap.asScala
69 | val updatedSchema = inferFeatureTypes(schemaSoFar, featureMap)
70 |
71 | val featureListMap = next.getFeatureLists.getFeatureListMap.asScala
72 | inferFeatureListTypes(updatedSchema, featureListMap)
73 | }
74 |
75 | private def inferExampleRowType(schemaSoFar: mutable.Map[String, DataType],
76 | next: Example): mutable.Map[String, DataType] = {
77 | val featureMap = next.getFeatures.getFeatureMap.asScala
78 | inferFeatureTypes(schemaSoFar, featureMap)
79 | }
80 |
81 | private def inferFeatureTypes(schemaSoFar: mutable.Map[String, DataType],
82 | featureMap: mutable.Map[String, Feature]): mutable.Map[String, DataType] = {
83 | featureMap.foreach {
84 | case (featureName, feature) => {
85 | val currentType = inferField(feature)
86 | if (schemaSoFar.contains(featureName)) {
87 | val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType)
88 | schemaSoFar(featureName) = updatedType.orNull
89 | }
90 | else {
91 | schemaSoFar += (featureName -> currentType)
92 | }
93 | }
94 | }
95 | schemaSoFar
96 | }
97 |
98 | def inferFeatureListTypes(schemaSoFar: mutable.Map[String, DataType],
99 | featureListMap: mutable.Map[String, FeatureList]): mutable.Map[String, DataType] = {
100 | featureListMap.foreach {
101 | case (featureName, featureList) => {
102 | val featureType = featureList.getFeatureList.asScala.map(f => inferField(f))
103 | .reduceLeft((a, b) => findTightestCommonType(a, b).orNull)
104 | val currentType = featureType match {
105 | case ArrayType(_, _) => ArrayType(featureType)
106 | case _ => ArrayType(ArrayType(featureType))
107 | }
108 | if (schemaSoFar.contains(featureName)) {
109 | val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType)
110 | schemaSoFar(featureName) = updatedType.orNull
111 | }
112 | else {
113 | schemaSoFar += (featureName -> currentType)
114 | }
115 | }
116 | }
117 | schemaSoFar
118 | }
119 |
120 | private def mergeFieldTypes(first: mutable.Map[String, DataType],
121 | second: mutable.Map[String, DataType]): mutable.Map[String, DataType] = {
122 |     //Merge the two maps, resolving each key to the tightest common type.
123 | val mutMap = collection.mutable.Map[String, DataType]((first.keySet ++ second.keySet)
124 | .map(key => (key, findTightestCommonType(first.getOrElse(key, null), second.getOrElse(key, null)).get))
125 | .toSeq: _*)
126 | mutMap
127 | }
128 |
129 | /**
130 | * Infer Feature datatype based on field number
131 | */
132 | private def inferField(feature: Feature): DataType = {
133 | feature.getKindCase.getNumber match {
134 | case Feature.BYTES_LIST_FIELD_NUMBER => {
135 | parseBytesList(feature)
136 | }
137 | case Feature.INT64_LIST_FIELD_NUMBER => {
138 | parseInt64List(feature)
139 | }
140 | case Feature.FLOAT_LIST_FIELD_NUMBER => {
141 | parseFloatList(feature)
142 | }
143 |       case _ => throw new RuntimeException(s"Unsupported feature kind: ${feature.getKindCase}")
144 | }
145 | }
146 |
147 | private def parseBytesList(feature: Feature): DataType = {
148 | val length = feature.getBytesList.getValueCount
149 |
150 | if (length == 0) {
151 | null
152 | }
153 | else if (length > 1) {
154 | ArrayType(StringType)
155 | }
156 | else {
157 | StringType
158 | }
159 | }
160 |
161 | private def parseInt64List(feature: Feature): DataType = {
162 | val int64List = feature.getInt64List.getValueList.asScala.toArray
163 | val length = int64List.length
164 |
165 | if (length == 0) {
166 | null
167 | }
168 | else if (length > 1) {
169 | ArrayType(LongType)
170 | }
171 | else {
172 | LongType
173 | }
174 | }
175 |
176 | private def parseFloatList(feature: Feature): DataType = {
177 | val floatList = feature.getFloatList.getValueList.asScala.toArray
178 | val length = floatList.length
179 | if (length == 0) {
180 | null
181 | }
182 | else if (length > 1) {
183 | ArrayType(FloatType)
184 | }
185 | else {
186 | FloatType
187 | }
188 | }
189 |
190 | /**
191 | * Copied from internal Spark api
192 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
193 | */
194 | private def getNumericPrecedence(dataType: DataType): Int = {
195 | dataType match {
196 | case LongType => 1
197 | case FloatType => 2
198 | case StringType => 3
199 | case ArrayType(LongType, _) => 4
200 | case ArrayType(FloatType, _) => 5
201 | case ArrayType(StringType, _) => 6
202 | case ArrayType(ArrayType(LongType, _), _) => 7
203 | case ArrayType(ArrayType(FloatType, _), _) => 8
204 | case ArrayType(ArrayType(StringType, _), _) => 9
205 | case _ => throw new RuntimeException("Unable to get the precedence for given datatype...")
206 | }
207 | }
208 |
209 | /**
210 | * Copied from internal Spark api
211 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
212 | */
213 | private def findTightestCommonType(tt1: DataType, tt2: DataType) : Option[DataType] = {
214 | val currType = (tt1, tt2) match {
215 | case (t1, t2) if t1 == t2 => Some(t1)
216 | case (null, t2) => Some(t2)
217 | case (t1, null) => Some(t1)
218 |
219 | // Promote types based on numeric precedence
220 | case (t1, t2) =>
221 | val t1Precedence = getNumericPrecedence(t1)
222 | val t2Precedence = getNumericPrecedence(t2)
223 | val newType = if (t1Precedence > t2Precedence) t1 else t2
224 | Some(newType)
225 | case _ => None
226 | }
227 | currType
228 | }
229 | }
230 |
231 |
--------------------------------------------------------------------------------
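
A minimal, illustrative sketch (not part of the source tree) of how TensorFlowInferSchema above is applied and how the precedence table resolves conflicting types. It assumes an existing SparkSession named `spark`; the same pattern is exercised in InferSchemaSuite below.

import org.apache.spark.rdd.RDD
import org.tensorflow.example._

val longFeature  = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(7L)).build()
val floatFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(1.5F)).build()

// The same feature name carries an Int64List in one record and a FloatList in another.
val ex1 = Example.newBuilder().setFeatures(Features.newBuilder().putFeature("x", longFeature)).build()
val ex2 = Example.newBuilder().setFeatures(Features.newBuilder().putFeature("x", floatFeature)).build()

val rdd: RDD[Example] = spark.sparkContext.parallelize(Seq(ex1, ex2))
// FloatType outranks LongType in the precedence table, so "x" is inferred as FloatType.
val schema = TensorFlowInferSchema(rdd)
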
/src/test/scala/com/linkedin/spark/datasources/tfrecord/InferSchemaSuite.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.linkedin.spark.datasources.tfrecord
17 |
18 | import org.apache.spark.rdd.RDD
19 | import org.apache.spark.sql.types._
20 | import org.tensorflow.example._
21 | import com.google.protobuf.ByteString
22 |
23 | class InferSchemaSuite extends SharedSparkSessionSuite {
24 |
25 | val longFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(Int.MaxValue + 10L)).build()
26 | val floatFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(10.0F).build()).build()
27 | val strFeature = Feature.newBuilder().setBytesList(
28 | BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes))).build()
29 |
30 | val longList = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(-2L).addValue(20L).build()).build()
31 | val floatList = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(2.5F).addValue(7F).build()).build()
32 | val strList = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes))
33 | .addValue(ByteString.copyFrom("r2".getBytes)).build()).build()
34 |
35 | val emptyFloatList = Feature.newBuilder().setFloatList(FloatList.newBuilder().build()).build()
36 |
37 | "InferSchema" should {
38 |
39 | "Infer schema from Example records" in {
40 | //Build example1
41 | val features1 = Features.newBuilder()
42 | .putFeature("LongFeature", longFeature)
43 | .putFeature("FloatFeature", floatFeature)
44 | .putFeature("StrFeature", strFeature)
45 | .putFeature("LongList", longFeature)
46 | .putFeature("FloatList", floatFeature)
47 | .putFeature("StrList", strFeature)
48 | .putFeature("MixedTypeList", longList)
49 | .build()
50 | val example1 = Example.newBuilder()
51 | .setFeatures(features1)
52 | .build()
53 |
54 |       //Example2 contains a subset of the features in example1 to test behavior with missing features
55 | val features2 = Features.newBuilder()
56 | .putFeature("StrFeature", strFeature)
57 | .putFeature("LongList", longList)
58 | .putFeature("FloatList", floatList)
59 | .putFeature("StrList", strList)
60 | .putFeature("MixedTypeList", floatList)
61 | .build()
62 | val example2 = Example.newBuilder()
63 | .setFeatures(features2)
64 | .build()
65 |
66 | val exampleRdd: RDD[Example] = spark.sparkContext.parallelize(List(example1, example2))
67 | val inferredSchema = TensorFlowInferSchema(exampleRdd)
68 |
69 |       //Verify each TensorFlow data type is inferred as the expected Spark data type
70 | assert(inferredSchema.fields.length == 7)
71 | val schemaMap = inferredSchema.map(f => (f.name, f.dataType)).toMap
72 | assert(schemaMap("LongFeature") === LongType)
73 | assert(schemaMap("FloatFeature") === FloatType)
74 | assert(schemaMap("StrFeature") === StringType)
75 | assert(schemaMap("LongList") === ArrayType(LongType))
76 | assert(schemaMap("FloatList") === ArrayType(FloatType))
77 | assert(schemaMap("StrList") === ArrayType(StringType))
78 | assert(schemaMap("MixedTypeList") === ArrayType(FloatType))
79 | }
80 |
81 | "Infer schema from SequenceExample records" in {
82 |
83 | //Build sequence example1
84 | val features1 = Features.newBuilder()
85 | .putFeature("FloatFeature", floatFeature)
86 |
87 | val longFeatureList1 = FeatureList.newBuilder().addFeature(longFeature).addFeature(longList).build()
88 | val floatFeatureList1 = FeatureList.newBuilder().addFeature(floatFeature).addFeature(floatList).build()
89 | val strFeatureList1 = FeatureList.newBuilder().addFeature(strFeature).build()
90 | val mixedFeatureList1 = FeatureList.newBuilder().addFeature(floatFeature).addFeature(strList).build()
91 |
92 | val featureLists1 = FeatureLists.newBuilder()
93 | .putFeatureList("LongListOfLists", longFeatureList1)
94 | .putFeatureList("FloatListOfLists", floatFeatureList1)
95 | .putFeatureList("StringListOfLists", strFeatureList1)
96 | .putFeatureList("MixedListOfLists", mixedFeatureList1)
97 | .build()
98 |
99 | val seqExample1 = SequenceExample.newBuilder()
100 | .setContext(features1)
101 | .setFeatureLists(featureLists1)
102 | .build()
103 |
104 |       //SequenceExample2 contains a subset of the features in seqExample1 to test behavior with missing features
105 | val longFeatureList2 = FeatureList.newBuilder().addFeature(longList).build()
106 | val floatFeatureList2 = FeatureList.newBuilder().addFeature(floatFeature).build()
107 | val strFeatureList2 = FeatureList.newBuilder().addFeature(strFeature).build() //test both string lists of length=1
108 | val mixedFeatureList2 = FeatureList.newBuilder().addFeature(longFeature).addFeature(strFeature).build()
109 |
110 | val featureLists2 = FeatureLists.newBuilder()
111 | .putFeatureList("LongListOfLists", longFeatureList2)
112 | .putFeatureList("FloatListOfLists", floatFeatureList2)
113 | .putFeatureList("StringListOfLists", strFeatureList2)
114 | .putFeatureList("MixedListOfLists", mixedFeatureList2)
115 | .build()
116 |
117 | val seqExample2 = SequenceExample.newBuilder()
118 | .setFeatureLists(featureLists2)
119 | .build()
120 |
121 | val seqExampleRdd: RDD[SequenceExample] = spark.sparkContext.parallelize(List(seqExample1, seqExample2))
122 | val inferredSchema = TensorFlowInferSchema(seqExampleRdd)
123 |
124 |       //Verify each TensorFlow data type is inferred as the expected Spark data type
125 | assert(inferredSchema.fields.length == 5)
126 | val schemaMap = inferredSchema.map(f => (f.name, f.dataType)).toMap
127 | assert(schemaMap("FloatFeature") === FloatType)
128 | assert(schemaMap("LongListOfLists") === ArrayType(ArrayType(LongType)))
129 | assert(schemaMap("FloatListOfLists") === ArrayType(ArrayType(FloatType)))
130 | assert(schemaMap("StringListOfLists") === ArrayType(ArrayType(StringType)))
131 | assert(schemaMap("MixedListOfLists") === ArrayType(ArrayType(StringType)))
132 | }
133 | }
134 |
135 | "Throw an exception for unsupported record types" in {
136 | intercept[Exception] {
137 | val rdd: RDD[Long] = spark.sparkContext.parallelize(List(5L, 6L))
138 | TensorFlowInferSchema(rdd)
139 | }
140 | }
141 |
142 | "Should have a nullType if there are no elements" in {
143 | //Build sequence example1
144 | val features1 = Features.newBuilder().putFeature("emptyFloatFeature", emptyFloatList)
145 |
146 | val seqExample1 = SequenceExample.newBuilder()
147 | .setContext(features1)
148 | .build()
149 |
150 | val seqExampleRdd: RDD[SequenceExample] = spark.sparkContext.parallelize(List(seqExample1))
151 | val inferredSchema = TensorFlowInferSchema(seqExampleRdd)
152 | assert(inferredSchema.fields.length == 1)
153 | val schemaMap = inferredSchema.map(f => (f.name, f.dataType)).toMap
154 | assert(schemaMap("emptyFloatFeature") === NullType)
155 | }
156 | }
157 |
158 |
--------------------------------------------------------------------------------
/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.linkedin.spark.datasources.tfrecord
17 |
18 | import java.io.File
19 |
20 | import org.apache.commons.io.FileUtils
21 | import org.apache.spark.SharedSparkSession
22 | import org.junit.{After, Before}
23 | import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}
24 |
25 |
26 | trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll
27 |
28 | class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite {
29 | val TF_SANDBOX_DIR = "tf-sandbox"
30 | val file = new File(TF_SANDBOX_DIR)
31 |
32 | @Before
33 | override def beforeAll() = {
34 | super.setUp()
35 | FileUtils.deleteQuietly(file)
36 | file.mkdirs()
37 | }
38 |
39 | @After
40 | override def afterAll() = {
41 | FileUtils.deleteQuietly(file)
42 | super.tearDown()
43 | }
44 | }
45 |
46 |
--------------------------------------------------------------------------------
/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.linkedin.spark.datasources.tfrecord
17 |
18 | import com.google.protobuf.ByteString
19 | import org.apache.spark.sql.catalyst.InternalRow
20 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
21 | import org.apache.spark.sql.types._
22 | import org.apache.spark.unsafe.types.UTF8String
23 | import org.scalatest.{Matchers, WordSpec}
24 | import org.tensorflow.example._
25 | import TestingUtils._
26 |
27 |
28 | class TFRecordDeserializerTest extends WordSpec with Matchers {
29 | val intFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(1)).build()
30 | val longFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(23L)).build()
31 | val floatFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(10.0F)).build()
32 | val doubleFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(14.0F)).build()
33 | val decimalFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(2.5F)).build()
34 | val longArrFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(-2L).addValue(7L).build()).build()
35 | val doubleArrFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(1F).addValue(2F).build()).build()
36 | val decimalArrFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(3F).addValue(5F).build()).build()
37 | val strFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()).build()
38 |   val strListFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes))
39 | .addValue(ByteString.copyFrom("r3".getBytes)).build()).build()
40 | val binaryFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r4".getBytes))).build()
41 | val binaryListFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r5".getBytes))
42 | .addValue(ByteString.copyFrom("r6".getBytes)).build()).build()
43 |
44 | private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
45 |
46 | "Deserialize tfrecord to spark internalRow" should {
47 |
48 | "Deserialize ByteArray to internalRow" in {
49 | val schema = StructType(Array(
50 | StructField("ByteArray", BinaryType)
51 | ))
52 |
53 | val byteArray = Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte)
54 | val expectedInternalRow = InternalRow(byteArray)
55 | val deserializer = new TFRecordDeserializer(schema)
56 | val actualInternalRow = deserializer.deserializeByteArray(byteArray)
57 |
58 | assert(actualInternalRow ~== (expectedInternalRow,schema))
59 | }
60 |
61 |     "Deserialize tfrecord example to spark internalRow" in {
62 | val schema = StructType(List(
63 | StructField("IntegerLabel", IntegerType),
64 | StructField("LongLabel", LongType),
65 | StructField("FloatLabel", FloatType),
66 | StructField("DoubleLabel", DoubleType),
67 | StructField("DecimalLabel", DataTypes.createDecimalType()),
68 | StructField("LongArrayLabel", ArrayType(LongType)),
69 | StructField("DoubleArrayLabel", ArrayType(DoubleType)),
70 | StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())),
71 | StructField("StrLabel", StringType),
72 | StructField("StrArrayLabel", ArrayType(StringType)),
73 | StructField("BinaryTypeLabel", BinaryType),
74 | StructField("BinaryTypeArrayLabel", ArrayType(BinaryType))
75 | ))
76 |
77 | val expectedInternalRow = InternalRow.fromSeq(
78 | Array[Any](1, 23L, 10.0F, 14.0, Decimal(2.5d),
79 | createArray(-2L,7L),
80 | createArray(1.0, 2.0),
81 | createArray(Decimal(3.0), Decimal(5.0)),
82 | UTF8String.fromString("r1"),
83 | createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")),
84 | "r4".getBytes,
85 | createArray("r5".getBytes(), "r6".getBytes())
86 | )
87 | )
88 |
89 | //Build example
90 | val features = Features.newBuilder()
91 | .putFeature("IntegerLabel", intFeature)
92 | .putFeature("LongLabel", longFeature)
93 | .putFeature("FloatLabel", floatFeature)
94 | .putFeature("DoubleLabel", doubleFeature)
95 | .putFeature("DecimalLabel", decimalFeature)
96 | .putFeature("LongArrayLabel", longArrFeature)
97 | .putFeature("DoubleArrayLabel", doubleArrFeature)
98 | .putFeature("DecimalArrayLabel", decimalArrFeature)
99 | .putFeature("StrLabel", strFeature)
100 | .putFeature("StrArrayLabel", strListFeature)
101 | .putFeature("BinaryTypeLabel", binaryFeature)
102 | .putFeature("BinaryTypeArrayLabel", binaryListFeature)
103 | .build()
104 | val example = Example.newBuilder()
105 | .setFeatures(features)
106 | .build()
107 | val deserializer = new TFRecordDeserializer(schema)
108 | val actualInternalRow = deserializer.deserializeExample(example)
109 |
110 | assert(actualInternalRow ~== (expectedInternalRow,schema))
111 | }
112 |
113 |     "Deserialize tfrecord sequenceExample to spark internalRow" in {
114 |
115 | val schema = StructType(List(
116 | StructField("FloatLabel", FloatType),
117 | StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))),
118 | StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))),
119 | StructField("DecimalArrayOfArrayLabel", ArrayType(ArrayType(DataTypes.createDecimalType()))),
120 | StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType))),
121 | StructField("ByteArrayOfArrayLabel", ArrayType(ArrayType(BinaryType)))
122 | ))
123 |
124 | val expectedInternalRow = InternalRow.fromSeq(
125 | Array[Any](10.0F,
126 | createArray(createArray(-2L, 7L)),
127 | createArray(createArray(10.0F), createArray(1.0F, 2.0F)),
128 | createArray(createArray(Decimal(3.0), Decimal(5.0))),
129 | createArray(createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")),
130 | createArray(UTF8String.fromString("r1"))),
131 | createArray(createArray("r5".getBytes, "r6".getBytes), createArray("r4".getBytes))
132 | )
133 | )
134 |
135 | //Build sequence example
136 | val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build()
137 | val floatFeatureList = FeatureList.newBuilder().addFeature(floatFeature).addFeature(doubleArrFeature).build()
138 | val decimalFeatureList = FeatureList.newBuilder().addFeature(decimalArrFeature).build()
139 | val stringFeatureList = FeatureList.newBuilder().addFeature(strListFeature).addFeature(strFeature).build()
140 | val binaryFeatureList = FeatureList.newBuilder().addFeature(binaryListFeature).addFeature(binaryFeature).build()
141 |
142 |
143 | val features = Features.newBuilder()
144 | .putFeature("FloatLabel", floatFeature)
145 |
146 | val featureLists = FeatureLists.newBuilder()
147 | .putFeatureList("LongArrayOfArrayLabel", int64FeatureList)
148 | .putFeatureList("FloatArrayOfArrayLabel", floatFeatureList)
149 | .putFeatureList("DecimalArrayOfArrayLabel", decimalFeatureList)
150 | .putFeatureList("StrArrayOfArrayLabel", stringFeatureList)
151 | .putFeatureList("ByteArrayOfArrayLabel", binaryFeatureList)
152 | .build()
153 |
154 | val seqExample = SequenceExample.newBuilder()
155 | .setContext(features)
156 | .setFeatureLists(featureLists)
157 | .build()
158 |
159 | val deserializer = new TFRecordDeserializer(schema)
160 | val actualInternalRow = deserializer.deserializeSequenceExample(seqExample)
161 | assert(actualInternalRow ~== (expectedInternalRow, schema))
162 | }
163 |
164 | "Throw an exception for unsupported data types" in {
165 |
166 | val features = Features.newBuilder().putFeature("MapLabel1", intFeature)
167 | val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build()
168 | val featureLists = FeatureLists.newBuilder().putFeatureList("MapLabel2", int64FeatureList)
169 |
170 | intercept[RuntimeException] {
171 | val example = Example.newBuilder()
172 | .setFeatures(features)
173 | .build()
174 | val schema = StructType(List(StructField("MapLabel1", TimestampType)))
175 | val deserializer = new TFRecordDeserializer(schema)
176 | deserializer.deserializeExample(example)
177 | }
178 |
179 | intercept[RuntimeException] {
180 | val seqExample = SequenceExample.newBuilder()
181 | .setContext(features)
182 | .setFeatureLists(featureLists)
183 | .build()
184 | val schema = StructType(List(StructField("MapLabel2", TimestampType)))
185 | val deserializer = new TFRecordDeserializer(schema)
186 | deserializer.deserializeSequenceExample(seqExample)
187 | }
188 | }
189 |
190 | "Throw an exception for non-nullable data types" in {
191 | val features = Features.newBuilder().putFeature("FloatLabel", floatFeature)
192 | val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build()
193 | val featureLists = FeatureLists.newBuilder().putFeatureList("LongArrayOfArrayLabel", int64FeatureList)
194 |
195 | intercept[NullPointerException] {
196 | val example = Example.newBuilder()
197 | .setFeatures(features)
198 | .build()
199 | val schema = StructType(List(StructField("MissingLabel", FloatType, nullable = false)))
200 | val deserializer = new TFRecordDeserializer(schema)
201 | deserializer.deserializeExample(example)
202 | }
203 |
204 | intercept[NullPointerException] {
205 | val seqExample = SequenceExample.newBuilder()
206 | .setContext(features)
207 | .setFeatureLists(featureLists)
208 | .build()
209 | val schema = StructType(List(StructField("MissingLabel", ArrayType(ArrayType(LongType)), nullable = false)))
210 | val deserializer = new TFRecordDeserializer(schema)
211 | deserializer.deserializeSequenceExample(seqExample)
212 | }
213 | }
214 |
215 |
216 | "Return null fields for nullable data types" in {
217 | val features = Features.newBuilder().putFeature("FloatLabel", floatFeature)
218 | val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build()
219 | val featureLists = FeatureLists.newBuilder().putFeatureList("LongArrayOfArrayLabel", int64FeatureList)
220 |
221 | // Deserialize Example
222 | val schema1 = StructType(List(
223 | StructField("FloatLabel", FloatType),
224 | StructField("MissingLabel", FloatType, nullable = true))
225 | )
226 | val expectedInternalRow1 = InternalRow.fromSeq(
227 | Array[Any](10.0F, null)
228 | )
229 | val example = Example.newBuilder()
230 | .setFeatures(features)
231 | .build()
232 | val deserializer1 = new TFRecordDeserializer(schema1)
233 | val actualInternalRow1 = deserializer1.deserializeExample(example)
234 | assert(actualInternalRow1 ~== (expectedInternalRow1, schema1))
235 |
236 | // Deserialize SequenceExample
237 | val schema2 = StructType(List(
238 | StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))),
239 | StructField("MissingLabel", ArrayType(ArrayType(LongType)), nullable = true))
240 | )
241 | val expectedInternalRow2 = InternalRow.fromSeq(
242 | Array[Any](
243 | createArray(createArray(-2L, 7L)), null)
244 | )
245 | val seqExample = SequenceExample.newBuilder()
246 | .setContext(features)
247 | .setFeatureLists(featureLists)
248 | .build()
249 | val deserializer2 = new TFRecordDeserializer(schema2)
250 | val actualInternalRow2 = deserializer2.deserializeSequenceExample(seqExample)
251 | assert(actualInternalRow2 ~== (expectedInternalRow2, schema2))
252 |
253 | }
254 |
255 | val schema = StructType(Array(
256 | StructField("LongLabel", LongType))
257 | )
258 | val deserializer = new TFRecordDeserializer(schema)
259 |
260 | "Test Int64ListFeature2SeqLong" in {
261 | val int64List = Int64List.newBuilder().addValue(5L).build()
262 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
263 | assert(deserializer.Int64ListFeature2SeqLong(intFeature).head === 5L)
264 |
265 | // Throw exception if type doesn't match
266 | intercept[RuntimeException] {
267 | val floatList = FloatList.newBuilder().addValue(2.5F).build()
268 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
269 | deserializer.Int64ListFeature2SeqLong(floatFeature)
270 | }
271 | }
272 |
273 | "Test floatListFeature2SeqFloat" in {
274 | val floatList = FloatList.newBuilder().addValue(2.5F).build()
275 | val floatFeature = Feature.newBuilder().setFloatList(floatList).build()
276 | assert(deserializer.floatListFeature2SeqFloat(floatFeature).head === 2.5F)
277 |
278 | // Throw exception if type doesn't match
279 | intercept[RuntimeException] {
280 | val int64List = Int64List.newBuilder().addValue(5L).build()
281 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
282 | deserializer.floatListFeature2SeqFloat(intFeature)
283 | }
284 | }
285 |
286 | "Test bytesListFeature2SeqArrayByte" in {
287 | val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
288 | val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
289 | assert(deserializer.bytesListFeature2SeqArrayByte(bytesFeature).head.sameElements("str-input".getBytes))
290 |
291 | // Throw exception if type doesn't match
292 | intercept[RuntimeException] {
293 | val int64List = Int64List.newBuilder().addValue(5L).build()
294 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
295 | deserializer.bytesListFeature2SeqArrayByte(intFeature)
296 | }
297 | }
298 |
299 | "Test bytesListFeature2SeqString" in {
300 | val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("alice".getBytes))
301 | .addValue(ByteString.copyFrom("bob".getBytes)).build()
302 | val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
303 | assert(deserializer.bytesListFeature2SeqString(bytesFeature) === Seq("alice", "bob"))
304 |
305 | // Throw exception if type doesn't match
306 | intercept[RuntimeException] {
307 | val int64List = Int64List.newBuilder().addValue(5L).build()
308 | val intFeature = Feature.newBuilder().setInt64List(int64List).build()
309 |         deserializer.bytesListFeature2SeqString(intFeature)
310 | }
311 | }
312 |
313 |     "Deserialize rows with different features without inheriting features from previous rows" in {
314 | // given features
315 | val floatFeatures = Features.newBuilder().putFeature("FloatLabel", floatFeature)
316 | val int64Features = Features.newBuilder().putFeature("IntLabel", intFeature)
317 |
318 | // Define common deserializer and schema
319 | val schema = StructType(List(
320 | StructField("FloatLabel", FloatType),
321 | StructField("IntLabel", IntegerType),
322 | StructField("MissingLabel", FloatType, nullable = true))
323 | )
324 | val deserializer = new TFRecordDeserializer(schema)
325 |
326 | // given rows with different features - row 1 has only FloatLabel feature
327 | val expectedInternalRow1 = InternalRow.fromSeq(
328 | Array[Any](10.0F, null, null)
329 | )
330 | val example1 = Example.newBuilder()
331 | .setFeatures(floatFeatures)
332 | .build()
333 | val actualInternalRow1 = deserializer.deserializeExample(example1)
334 | assert(actualInternalRow1 ~== (expectedInternalRow1, schema))
335 |
336 | // .. and the second row has only IntLabel feature
337 | val expectedInternalRow2 = InternalRow.fromSeq(
338 | Array[Any](null, 1, null)
339 | )
340 | val example2 = Example.newBuilder()
341 | .setFeatures(int64Features)
342 | .build()
343 | val actualInternalRow2 = deserializer.deserializeExample(example2)
344 | assert(actualInternalRow2 ~== (expectedInternalRow2, schema))
345 |
346 | }
347 | }
348 | }
349 |
--------------------------------------------------------------------------------
/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordIOSuite.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.linkedin.spark.datasources.tfrecord
17 |
18 | import org.apache.hadoop.fs.Path
19 | import org.apache.spark.sql.catalyst.expressions.GenericRow
20 | import org.apache.spark.sql.types._
21 | import org.apache.spark.sql.{DataFrame, Row, SaveMode}
22 |
23 | import TestingUtils._
24 |
25 | class TFRecordIOSuite extends SharedSparkSessionSuite {
26 |
27 | val exampleSchema = StructType(List(
28 | StructField("id", IntegerType),
29 | StructField("IntegerLabel", IntegerType),
30 | StructField("LongLabel", LongType),
31 | StructField("FloatLabel", FloatType),
32 | StructField("DoubleLabel", DoubleType),
33 | StructField("DecimalLabel", DataTypes.createDecimalType()),
34 | StructField("StrLabel", StringType),
35 | StructField("BinaryLabel", BinaryType),
36 | StructField("IntegerArrayLabel", ArrayType(IntegerType)),
37 | StructField("LongArrayLabel", ArrayType(LongType)),
38 | StructField("FloatArrayLabel", ArrayType(FloatType)),
39 | StructField("DoubleArrayLabel", ArrayType(DoubleType, true)),
40 | StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())),
41 | StructField("StrArrayLabel", ArrayType(StringType, true)),
42 | StructField("BinaryArrayLabel", ArrayType(BinaryType), true))
43 | )
44 |
45 | val exampleTestRows: Array[Row] = Array(
46 | new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, Decimal(1.1), "r1", Array[Byte](0xff.toByte, 0xf0.toByte),
47 | Seq(1, 2),
48 | Seq(11L, 12L),
49 | Seq(1.2F, 2.1F),
50 | Seq(1.1, 2.2),
51 | Seq(Decimal(1.1), Decimal(2.2)),
52 | Seq("str1", "str2"),
53 | Seq(Array[Byte](0xfa.toByte, 0xfb.toByte), Array[Byte](0xfa.toByte)))),
54 | new GenericRow(Array[Any](11, 1, 24L, 11.0F, 15.0, Decimal(2.1), "r2", Array[Byte](0xfa.toByte, 0xfb.toByte),
55 | Seq(3, 4),
56 | Seq(110L, 120L),
57 | Seq(1.2F, 2.1F),
58 | Seq(1.1, 2.2),
59 | Seq(Decimal(2.1), Decimal(3.2)),
60 | Seq("str3", "str4"),
61 | Seq(Array[Byte](0xf1.toByte, 0xf2.toByte), Array[Byte](0xfa.toByte)))),
62 | new GenericRow(Array[Any](21, 1, 23L, 10.0F, 14.0, Decimal(3.1), "r3", Array[Byte](0xfc.toByte, 0xfd.toByte),
63 | Seq(5, 6),
64 | Seq(111L, 112L),
65 | Seq(1.22F, 2.11F),
66 | Seq(11.1, 12.2),
67 | Seq(Decimal(3.1), Decimal(4.2)),
68 | Seq("str5", "str6"),
69 | Seq(Array[Byte](0xf4.toByte, 0xf2.toByte), Array[Byte](0xfa.toByte)))))
70 |
71 | val sequenceExampleTestRows: Array[Row] = Array(
72 | new GenericRow(Array[Any](23L, Seq(Seq(2, 4)), Seq(Seq(-1.1F, 0.1F)), Seq(Seq("r1", "r2")))),
73 | new GenericRow(Array[Any](24L, Seq(Seq(-1, 0)), Seq(Seq(-1.1F, 0.2F)), Seq(Seq("r3")))))
74 |
75 | val sequenceExampleSchema = StructType(List(
76 | StructField("id",LongType),
77 | StructField("IntegerArrayOfArrayLabel", ArrayType(ArrayType(IntegerType))),
78 | StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))),
79 | StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType)))
80 | ))
81 |
82 | val byteArrayTestRows: Array[Row] = Array(
83 | new GenericRow(Array[Any](Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte)))
84 | )
85 |
86 | private def createDataFrameForExampleTFRecord() : DataFrame = {
87 | val rdd = spark.sparkContext.parallelize(exampleTestRows)
88 | spark.createDataFrame(rdd, exampleSchema)
89 | }
90 |
91 | private def createDataFrameForSequenceExampleTFRecords() : DataFrame = {
92 | val rdd = spark.sparkContext.parallelize(sequenceExampleTestRows)
93 | spark.createDataFrame(rdd, sequenceExampleSchema)
94 | }
95 |
96 | private def createDataFrameForByteArrayTFRecords() : DataFrame = {
97 | val rdd = spark.sparkContext.parallelize(byteArrayTestRows)
98 | spark.createDataFrame(rdd, TensorFlowInferSchema.getSchemaForByteArray())
99 | }
100 |
101 | private def pathExists(pathStr: String): Boolean = {
102 | val hadoopConf = spark.sparkContext.hadoopConfiguration
103 | val outputPath = new Path(pathStr)
104 | val fs = outputPath.getFileSystem(hadoopConf)
105 | val qualifiedPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
106 | fs.exists(qualifiedPath)
107 | }
108 |
109 | private def getFileCount(pathStr: String): Long = {
110 | val hadoopConf = spark.sparkContext.hadoopConfiguration
111 | val outputPath = new Path(pathStr)
112 | val fs = outputPath.getFileSystem(hadoopConf)
113 | val qualifiedPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
114 | fs.getContentSummary(qualifiedPath).getFileCount
115 | }
116 |
117 | "Spark tfrecord IO" should {
118 | "Test tfrecord example Read/write " in {
119 |
120 | val path = s"$TF_SANDBOX_DIR/example.tfrecord"
121 |
122 | val df: DataFrame = createDataFrameForExampleTFRecord()
123 | df.write.format("tfrecord").option("recordType", "Example").save(path)
124 |
125 |       //If no schema is provided, it is inferred automatically; here the schema is passed explicitly
126 | val importedDf: DataFrame = spark.read.format("tfrecord").option("recordType", "Example").schema(exampleSchema).load(path)
127 |
128 | val actualDf = importedDf.select("id", "IntegerLabel", "LongLabel", "FloatLabel",
129 | "DoubleLabel", "DecimalLabel", "StrLabel", "BinaryLabel", "IntegerArrayLabel", "LongArrayLabel",
130 | "FloatArrayLabel", "DoubleArrayLabel", "DecimalArrayLabel", "StrArrayLabel", "BinaryArrayLabel").sort("StrLabel")
131 |
132 | val expectedRows = df.collect()
133 | val actualRows = actualDf.collect()
134 |
135 | expectedRows.zip(actualRows).foreach { case (expected: Row, actual: Row) =>
136 | assert(expected ~== actual, exampleSchema)
137 | }
138 | }
139 |
140 | "Test tfrecord partition by id" in {
141 | val output = s"$TF_SANDBOX_DIR/example-partition-by-id.tfrecord"
142 | val df: DataFrame = createDataFrameForExampleTFRecord()
143 | df.write.format("tfrecord").partitionBy("id").option("recordType", "Example").save(output)
144 | assert(pathExists(output))
145 | val partition1Path = s"$output/id=11"
146 | val partition2Path = s"$output/id=21"
147 | assert(pathExists(partition1Path))
148 | assert(pathExists(partition2Path))
149 | assert(getFileCount(partition1Path) == 2)
150 | assert(getFileCount(partition2Path) == 1)
151 | }
152 |
153 | "Test tfrecord read/write SequenceExample" in {
154 |
155 | val path = s"$TF_SANDBOX_DIR/sequenceExample.tfrecord"
156 |
157 | val df: DataFrame = createDataFrameForSequenceExampleTFRecords()
158 | df.write.format("tfrecord").option("recordType", "SequenceExample").save(path)
159 |
160 | val importedDf: DataFrame = spark.read.format("tfrecord").option("recordType", "SequenceExample").schema(sequenceExampleSchema).load(path)
161 | val actualDf = importedDf.select("id", "IntegerArrayOfArrayLabel", "FloatArrayOfArrayLabel", "StrArrayOfArrayLabel").sort("id")
162 |
163 | val expectedRows = df.collect()
164 | val actualRows = actualDf.collect()
165 |
166 | assert(expectedRows === actualRows)
167 | }
168 |
169 | "Test tfrecord read/write ByteArray" in {
170 |
171 | val path = s"$TF_SANDBOX_DIR/byteArray.tfrecord"
172 |
173 | val df: DataFrame = createDataFrameForByteArrayTFRecords()
174 | df.write.format("tfrecord").option("recordType", "ByteArray").save(path)
175 |
176 | val importedDf: DataFrame = spark.read.format("tfrecord").option("recordType", "ByteArray").load(path)
177 |
178 | val expectedRows = df.collect()
179 | val actualRows = importedDf.collect()
180 |
181 | assert(expectedRows === actualRows)
182 | }
183 |
184 | "Test tfrecord write overwrite mode " in {
185 |
186 | val path = s"$TF_SANDBOX_DIR/example_overwrite.tfrecord"
187 |
188 | val df: DataFrame = createDataFrameForExampleTFRecord()
189 | df.write.format("tfrecord").option("recordType", "Example").save(path)
190 |
191 | df.write.format("tfrecord").mode(SaveMode.Overwrite).option("recordType", "Example").save(path)
192 |
193 |       //If no schema is provided, it is inferred automatically; here the schema is passed explicitly
194 | val importedDf: DataFrame = spark.read.format("tfrecord").option("recordType", "Example").schema(exampleSchema).load(path)
195 |
196 | val actualDf = importedDf.select("id", "IntegerLabel", "LongLabel", "FloatLabel",
197 | "DoubleLabel", "DecimalLabel", "StrLabel", "BinaryLabel", "IntegerArrayLabel", "LongArrayLabel",
198 | "FloatArrayLabel", "DoubleArrayLabel", "DecimalArrayLabel", "StrArrayLabel", "BinaryArrayLabel").sort("StrLabel")
199 |
200 | val expectedRows = df.collect()
201 | val actualRows = actualDf.collect()
202 |
203 | expectedRows.zip(actualRows).foreach { case (expected: Row, actual: Row) =>
204 | assert(expected ~== actual, exampleSchema)
205 | }
206 | }
207 |
208 | "Test tfrecord write append mode" in {
209 |
210 | val path = s"$TF_SANDBOX_DIR/example_append.tfrecord"
211 |
212 | val df: DataFrame = createDataFrameForExampleTFRecord()
213 | df.write.format("tfrecord").option("recordType", "Example").save(path)
214 | df.write.format("tfrecord").mode(SaveMode.Append).option("recordType", "Example").save(path)
215 | }
216 |
217 | "Test tfrecord write ignore mode" in {
218 |
219 | val path = s"$TF_SANDBOX_DIR/example_ignore.tfrecord"
220 |
221 | val hadoopConf = spark.sparkContext.hadoopConfiguration
222 | val outputPath = new Path(path)
223 | val fs = outputPath.getFileSystem(hadoopConf)
224 | val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
225 |
226 | val df: DataFrame = createDataFrameForExampleTFRecord()
227 | df.write.format("tfrecord").mode(SaveMode.Ignore).option("recordType", "Example").save(path)
228 |
229 | assert(fs.exists(qualifiedOutputPath))
230 | val timestamp1 = fs.getFileStatus(qualifiedOutputPath).getModificationTime
231 |
232 | df.write.format("tfrecord").mode(SaveMode.Ignore).option("recordType", "Example").save(path)
233 |
234 | val timestamp2 = fs.getFileStatus(qualifiedOutputPath).getModificationTime
235 |
236 | assert(timestamp1 == timestamp2, "SaveMode.Ignore Error: File was overwritten. Timestamps do not match")
237 | }
238 | }
239 | }
240 |
--------------------------------------------------------------------------------
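
A minimal, illustrative usage sketch (not part of the source tree) of the DataFrame-level API exercised by TFRecordIOSuite above. It assumes an existing SparkSession `spark`, an existing DataFrame `df`, and a hypothetical output path.

import org.apache.spark.sql.SaveMode

// Write Example records; "recordType" may also be "SequenceExample" or "ByteArray".
df.write
  .format("tfrecord")
  .mode(SaveMode.Overwrite)
  .option("recordType", "Example")
  .save("/tmp/example.tfrecord")

// Read them back; pass a schema explicitly or let it be inferred from the records.
val readBack = spark.read
  .format("tfrecord")
  .option("recordType", "Example")
  .load("/tmp/example.tfrecord")
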
/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.linkedin.spark.datasources.tfrecord
17 |
18 | import org.apache.spark.sql.catalyst.InternalRow
19 | import org.tensorflow.example._
20 | import org.apache.spark.sql.types.{StructField, _}
21 | import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
22 | import org.apache.spark.unsafe.types.UTF8String
23 | import org.scalatest.{Matchers, WordSpec}
24 |
25 | import scala.collection.JavaConverters._
26 | import TestingUtils._
27 |
28 | class TFRecordSerializerTest extends WordSpec with Matchers {
29 |
30 | private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
31 |
32 | "Serialize spark internalRow to tfrecord" should {
33 |
34 | "Serialize ByteArray internalRow to ByteArray" in {
35 | val serializer = new TFRecordSerializer(new StructType())
36 |
37 | val byteArray = Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte)
38 | val internalRow = InternalRow(byteArray)
39 | val serializedByteArray = serializer.serializeByteArray(internalRow)
40 |
41 |       //The two byte arrays should be identical, since serialization simply extracts the byte array from the internal row
42 | assert(byteArray.length == serializedByteArray.length)
43 | assert(byteArray.sameElements(serializedByteArray))
44 | }
45 |
46 | "Serialize decimal internalRow to tfrecord example" in {
47 | val schemaStructType = StructType(Array(
48 | StructField("DecimalLabel", DataTypes.createDecimalType()),
49 | StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType()))
50 | ))
51 | val serializer = new TFRecordSerializer(schemaStructType)
52 |
53 | val decimalArray = Array(Decimal(4.0), Decimal(8.0))
54 | val rowArray = Array[Any](Decimal(6.5), ArrayData.toArrayData(decimalArray))
55 | val internalRow = InternalRow.fromSeq(rowArray)
56 |
57 | //Encode Sql InternalRow to TensorFlow example
58 | val example = serializer.serializeExample(internalRow)
59 |
60 |       //Verify each Spark data type is converted to the expected TensorFlow data type
61 | val featureMap = example.getFeatures.getFeatureMap.asScala
62 | assert(featureMap.size == rowArray.length)
63 |
64 | assert(featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
65 | assert(featureMap("DecimalLabel").getFloatList.getValue(0) == 6.5F)
66 |
67 | assert(featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
68 | assert(featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== decimalArray.map(_.toFloat))
69 | }
70 |
71 | "Serialize complex internalRow to tfrecord example" in {
72 | val schemaStructType = StructType(Array(
73 | StructField("IntegerLabel", IntegerType),
74 | StructField("LongLabel", LongType),
75 | StructField("FloatLabel", FloatType),
76 | StructField("DoubleLabel", DoubleType),
77 | StructField("DecimalLabel", DataTypes.createDecimalType()),
78 | StructField("DoubleArrayLabel", ArrayType(DoubleType)),
79 | StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())),
80 | StructField("StrLabel", StringType),
81 | StructField("StrArrayLabel", ArrayType(StringType)),
82 | StructField("BinaryLabel", BinaryType),
83 | StructField("BinaryArrayLabel", ArrayType(BinaryType))
84 | ))
85 | val doubleArray = Array(1.1, 111.1, 11111.1)
86 | val decimalArray = Array(Decimal(4.0), Decimal(8.0))
87 | val byteArray = Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte)
88 | val byteArray1 = Array[Byte](-128, 23, 127)
89 |
90 | val rowArray = Array[Any](1, 23L, 10.0F, 14.0, Decimal(6.5),
91 | ArrayData.toArrayData(doubleArray),
92 | ArrayData.toArrayData(decimalArray),
93 | UTF8String.fromString("r1"),
94 | ArrayData.toArrayData(Array(UTF8String.fromString("r2"), UTF8String.fromString("r3"))),
95 | byteArray,
96 | ArrayData.toArrayData(Array(byteArray, byteArray1))
97 | )
98 |
99 | val internalRow = InternalRow.fromSeq(rowArray)
100 |
101 | val serializer = new TFRecordSerializer(schemaStructType)
102 | val example = serializer.serializeExample(internalRow)
103 |
104 |       //Verify each Spark data type is converted to the expected TensorFlow data type
105 | val featureMap = example.getFeatures.getFeatureMap.asScala
106 | assert(featureMap.size == rowArray.length)
107 |
108 | assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER)
109 | assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 1)
110 |
111 | assert(featureMap("LongLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER)
112 | assert(featureMap("LongLabel").getInt64List.getValue(0).toInt == 23)
113 |
114 | assert(featureMap("FloatLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
115 | assert(featureMap("FloatLabel").getFloatList.getValue(0) == 10.0F)
116 |
117 | assert(featureMap("DoubleLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
118 | assert(featureMap("DoubleLabel").getFloatList.getValue(0) == 14.0F)
119 |
120 | assert(featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
121 | assert(featureMap("DecimalLabel").getFloatList.getValue(0) == 6.5F)
122 |
123 | assert(featureMap("DoubleArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
124 | assert(featureMap("DoubleArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== doubleArray.map(_.toFloat))
125 |
126 | assert(featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
127 | assert(featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== decimalArray.map(_.toFloat))
128 |
129 | assert(featureMap("StrLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
130 | assert(featureMap("StrLabel").getBytesList.getValue(0).toStringUtf8 == "r1")
131 |
132 | assert(featureMap("StrArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
133 | assert(featureMap("StrArrayLabel").getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("r2", "r3"))
134 |
135 | assert(featureMap("BinaryLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
136 | assert(featureMap("BinaryLabel").getBytesList.getValue(0).toByteArray.sameElements(byteArray))
137 |
138 | assert(featureMap("BinaryArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
139 | val binaryArrayValue = featureMap("BinaryArrayLabel").getBytesList.getValueList.asScala.map((byteArray) => byteArray.asScala.toArray.map(_.toByte))
140 | binaryArrayValue.toArray should equal(Array(byteArray, byteArray1))
141 | }
142 |
143 | "Serialize internalRow to tfrecord sequenceExample" in {
144 |
145 | val schemaStructType = StructType(Array(
146 | StructField("IntegerLabel", IntegerType),
147 | StructField("StringArrayLabel", ArrayType(StringType)),
148 | StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))),
149 | StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))) ,
150 | StructField("DoubleArrayOfArrayLabel", ArrayType(ArrayType(DoubleType))),
151 | StructField("DecimalArrayOfArrayLabel", ArrayType(ArrayType(DataTypes.createDecimalType()))),
152 | StructField("StringArrayOfArrayLabel", ArrayType(ArrayType(StringType))),
153 | StructField("BinaryArrayOfArrayLabel", ArrayType(ArrayType(BinaryType)))
154 | ))
155 |
156 | val stringList = Array(UTF8String.fromString("r1"), UTF8String.fromString("r2"), UTF8String.fromString(("r3")))
157 | val longListOfLists = Array(Array(3L, 5L), Array(-8L, 0L))
158 | val floatListOfLists = Array(Array(1.5F, -6.5F), Array(-8.2F, 0F))
159 | val doubleListOfLists = Array(Array(3.0), Array(6.0, 9.0))
160 | val decimalListOfLists = Array(Array(Decimal(2.0), Decimal(4.0)), Array(Decimal(6.0)))
161 | val stringListOfLists = Array(Array(UTF8String.fromString("r1")),
162 | Array(UTF8String.fromString("r2"), UTF8String.fromString("r3")),
163 | Array(UTF8String.fromString("r4")))
164 | val binaryListOfLists = stringListOfLists.map(stringList => stringList.map(_.getBytes))
165 |
166 | val rowArray = Array[Any](10,
167 | createArray(UTF8String.fromString("r1"), UTF8String.fromString("r2"), UTF8String.fromString(("r3"))),
168 | createArray(
169 | createArray(3L, 5L),
170 | createArray(-8L, 0L)
171 | ),
172 | createArray(
173 | createArray(1.5F, -6.5F),
174 | createArray(-8.2F, 0F)
175 | ),
176 | createArray(
177 | createArray(3.0),
178 | createArray(6.0, 9.0)
179 | ),
180 | createArray(
181 | createArray(Decimal(2.0), Decimal(4.0)),
182 | createArray(Decimal(6.0))
183 | ),
184 | createArray(
185 | createArray(UTF8String.fromString("r1")),
186 | createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")),
187 | createArray(UTF8String.fromString("r4"))
188 | ),
189 | createArray(createArray("r1".getBytes()),
190 | createArray("r2".getBytes(), "r3".getBytes),
191 | createArray("r4".getBytes())
192 | )
193 | )
194 |
195 | val internalRow = InternalRow.fromSeq(rowArray)
196 |
197 | val serializer = new TFRecordSerializer(schemaStructType)
198 | val tfexample = serializer.serializeSequenceExample(internalRow)
199 |
200 |       //Verify each Spark data type is converted to the expected TensorFlow data type
201 | val featureMap = tfexample.getContext.getFeatureMap.asScala
202 | val featureListMap = tfexample.getFeatureLists.getFeatureListMap.asScala
203 |
204 | assert(featureMap.size == 2)
205 | assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER)
206 | assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 10)
207 | assert(featureMap("StringArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
208 | assert(featureMap("StringArrayLabel").getBytesList.getValueList.asScala.map(x => UTF8String.fromString(x.toStringUtf8)) === stringList)
209 |
210 | assert(featureListMap.size == 6)
211 | assert(featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala.map(
212 | _.getInt64List.getValueList.asScala.toSeq) === longListOfLists)
213 |
214 | assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.toSeq.map(
215 | _.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists.map{arr => arr.toSeq}.toSeq)
216 | assert(featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.toSeq.map(
217 | _.getFloatList.getValueList.asScala.map(_.toDouble).toSeq) ~== doubleListOfLists.map{arr => arr.toSeq}.toSeq)
218 |
219 | assert(featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.toSeq.map(
220 | _.getFloatList.getValueList.asScala.map(x => Decimal(x.toDouble)).toSeq) ~== decimalListOfLists.map{arr => arr.toSeq}.toSeq)
221 |
222 | assert(featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map(
223 | _.getBytesList.getValueList.asScala.map(x => UTF8String.fromString(x.toStringUtf8)).toSeq) === stringListOfLists)
224 |
225 | assert(featureListMap("BinaryArrayOfArrayLabel").getFeatureList.asScala.map(
226 | _.getBytesList.getValueList.asScala.map(byteList => byteList.asScala.toSeq)) === binaryListOfLists.map(_.map(_.toSeq)))
227 | }
228 |
229 | "Throw an exception for non-nullable data types" in {
230 | val schemaStructType = StructType(Array(
231 | StructField("NonNullLabel", ArrayType(FloatType), nullable = false)
232 | ))
233 |
234 | val internalRow = InternalRow.fromSeq(Array[Any](null))
235 |
236 | val serializer = new TFRecordSerializer(schemaStructType)
237 |
238 | intercept[NullPointerException]{
239 | serializer.serializeExample(internalRow)
240 | }
241 |
242 | intercept[NullPointerException]{
243 | serializer.serializeSequenceExample(internalRow)
244 | }
245 | }
246 |
247 | "Omit null fields from Example for nullable data types" in {
248 | val schemaStructType = StructType(Array(
249 | StructField("NullLabel", ArrayType(FloatType), nullable = true),
250 | StructField("FloatArrayLabel", ArrayType(FloatType))
251 | ))
252 |
253 | val floatArray = Array(2.5F, 5.0F)
254 | val internalRow = InternalRow.fromSeq(
255 | Array[Any](null, createArray(2.5F, 5.0F))
256 | )
257 |
258 | val serializer = new TFRecordSerializer(schemaStructType)
259 | val tfexample = serializer.serializeExample(internalRow)
260 |
261 |       //Verify each Spark data type is converted to the expected TensorFlow data type
262 | val featureMap = tfexample.getFeatures.getFeatureMap.asScala
263 | assert(featureMap.size == 1)
264 | assert(featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
265 | assert(featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== floatArray.toSeq)
266 | }
267 |
268 | "Omit null fields from SequenceExample for nullable data types" in {
269 | val schemaStructType = StructType(Array(
270 | StructField("NullLabel", ArrayType(FloatType), nullable = true),
271 | StructField("FloatArrayLabel", ArrayType(FloatType))
272 | ))
273 |
274 | val floatArray = Array(2.5F, 5.0F)
275 | val internalRow = InternalRow.fromSeq(
276 | Array[Any](null, createArray(2.5F, 5.0F)))
277 |
278 | val serializer = new TFRecordSerializer(schemaStructType)
279 | val tfSeqExample = serializer.serializeSequenceExample(internalRow)
280 |
281 |     // Verify that each data type is converted to its TensorFlow counterpart
282 | val featureMap = tfSeqExample.getContext.getFeatureMap.asScala
283 | val featureListMap = tfSeqExample.getFeatureLists.getFeatureListMap.asScala
284 | assert(featureMap.size == 1)
285 | assert(featureListMap.isEmpty)
286 | assert(featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER)
287 | assert(featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== floatArray.toSeq)
288 | }
289 |
290 | "Throw an exception for unsupported data types" in {
291 |
292 | val schemaStructType = StructType(Array(
293 | StructField("TimestampLabel", TimestampType)
294 | ))
295 |
296 | intercept[RuntimeException]{
297 | new TFRecordSerializer(schemaStructType)
298 | }
299 | }
300 |
301 | val schema = StructType(Array(
302 | StructField("bytesLabel", BinaryType))
303 | )
304 | val serializer = new TFRecordSerializer(schema)
305 |
306 | "Test Int64ListFeature" in {
307 | val longFeature = serializer.Int64ListFeature(Seq(10L))
308 | val longListFeature = serializer.Int64ListFeature(Seq(3L,5L,6L))
309 |
310 | assert(longFeature.getInt64List.getValueList.asScala.toSeq === Seq(10L))
311 | assert(longListFeature.getInt64List.getValueList.asScala.toSeq === Seq(3L, 5L, 6L))
312 | }
313 |
314 | "Test floatListFeature" in {
315 | val floatFeature = serializer.floatListFeature(Seq(10.1F))
316 | val floatListFeature = serializer.floatListFeature(Seq(3.1F,5.1F,6.1F))
317 |
318 | assert(floatFeature.getFloatList.getValueList.asScala.toSeq === Seq(10.1F))
319 | assert(floatListFeature.getFloatList.getValueList.asScala.toSeq === Seq(3.1F,5.1F,6.1F))
320 | }
321 |
322 | "Test bytesListFeature" in {
323 | val bytesFeature = serializer.bytesListFeature(Seq(Array(0xff.toByte, 0xd8.toByte)))
324 | val bytesListFeature = serializer.bytesListFeature(Seq(
325 | Array(0xff.toByte, 0xd8.toByte),
326 | Array(0xff.toByte, 0xd9.toByte)))
327 |
328 | bytesFeature.getBytesList.getValueList.asScala.map(_.toByteArray).toArray should equal(
329 | Array(Array(0xff.toByte, 0xd8.toByte)))
330 | bytesListFeature.getBytesList.getValueList.asScala.map(_.toByteArray).toArray should equal(
331 | Array(Array(0xff.toByte, 0xd8.toByte), Array(0xff.toByte, 0xd9.toByte)))
332 | }
333 | }
334 | }
335 |
--------------------------------------------------------------------------------
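A usage note on TFRecordSerializerTest.scala above: the tests drive TFRecordSerializer directly with Catalyst InternalRows. Below is a minimal sketch (not part of the repository) of the same API used outside a test, assuming TFRecordSerializer is accessible from the calling code; the schema, field names, and values are hypothetical.

import com.linkedin.spark.datasources.tfrecord.TFRecordSerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical two-column schema: integer columns map to Int64List features,
// string columns map to BytesList features (as the assertions above check).
val schema = StructType(Array(
  StructField("IntegerLabel", IntegerType),
  StructField("StringLabel", StringType)
))
val row = InternalRow.fromSeq(Seq(42, UTF8String.fromString("some-label")))

val serializer = new TFRecordSerializer(schema)
val example = serializer.serializeExample(row)
// Features are keyed by column name, e.g.
// example.getFeatures.getFeatureMap.get("IntegerLabel").getInt64List.getValue(0) == 42L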
/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package com.linkedin.spark.datasources.tfrecord
17 |
18 | import org.apache.spark.sql.Row
19 | import org.apache.spark.sql.catalyst.InternalRow
20 | import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, SpecializedGetters}
21 | import org.apache.spark.sql.catalyst.util.ArrayData
22 | import org.apache.spark.sql.types._
23 | import org.scalatest.Matchers
24 |
25 | object TestingUtils extends Matchers {
26 |
27 | /**
28 |    * Implicit class for comparing two float sequences using absolute tolerance.
29 | */
30 | implicit class FloatArrayWithAlmostEquals(val left: Seq[Float]) {
31 |
32 | /**
33 |      * Returns true when the two sequences have the same length and corresponding elements differ by at most eps; otherwise returns false.
34 | */
35 | def ~==(right: Seq[Float], epsilon : Float = 1E-6F): Boolean = {
36 | if (left.size === right.size) {
37 | (left zip right) forall { case (a, b) => a === (b +- epsilon) }
38 | }
39 | else false
40 | }
41 | }
42 |
43 | /**
44 |    * Implicit class for comparing two double sequences using absolute tolerance.
45 | */
46 | implicit class DoubleArrayWithAlmostEquals(val left: Seq[Double]) {
47 |
48 | /**
49 |      * Returns true when the two sequences have the same length and corresponding elements differ by at most eps; otherwise returns false.
50 | */
51 | def ~==(right: Seq[Double], epsilon : Double = 1E-6): Boolean = {
52 | if (left.size === right.size) {
53 | (left zip right) forall { case (a, b) => a === (b +- epsilon) }
54 | }
55 | else false
56 | }
57 | }
58 |
59 | /**
60 |    * Implicit class for comparing two decimal sequences using absolute tolerance.
61 | */
62 | implicit class DecimalArrayWithAlmostEquals(val left: Seq[Decimal]) {
63 |
64 | /**
65 |      * Returns true when the two sequences have the same length and corresponding elements differ by at most eps; otherwise returns false.
66 | */
67 | def ~==(right: Seq[Decimal], epsilon : Double = 1E-6): Boolean = {
68 | if (left.size === right.size) {
69 | (left zip right) forall { case (a, b) => a.toDouble === (b.toDouble +- epsilon) }
70 | }
71 | else false
72 | }
73 | }
74 |
75 | /**
76 |    * Implicit class for comparing two float matrices (nested sequences) using absolute tolerance.
77 | */
78 | implicit class FloatMatrixWithAlmostEquals(val left: Seq[Seq[Float]]) {
79 |
80 | /**
81 |      * Returns true when the two nested sequences have the same shape and corresponding elements differ by at most eps; otherwise returns false.
82 | */
83 | def ~==(right: Seq[Seq[Float]], epsilon : Float = 1E-6F): Boolean = {
84 | if (left.size === right.size) {
85 | (left zip right) forall { case (a, b) => a ~== (b, epsilon) }
86 | }
87 | else false
88 | }
89 | }
90 |
91 | /**
92 |    * Implicit class for comparing two double matrices (nested sequences) using absolute tolerance.
93 | */
94 | implicit class DoubleMatrixWithAlmostEquals(val left: Seq[Seq[Double]]) {
95 |
96 | /**
97 |      * Returns true when the two nested sequences have the same shape and corresponding elements differ by at most eps; otherwise returns false.
98 | */
99 | def ~==(right: Seq[Seq[Double]], epsilon : Double = 1E-6): Boolean = {
100 | if (left.size === right.size) {
101 | (left zip right) forall { case (a, b) => a ~== (b, epsilon) }
102 | }
103 | else false
104 | }
105 | }
106 |
107 | /**
108 |    * Implicit class for comparing two decimal matrices (nested sequences) using absolute tolerance.
109 | */
110 | implicit class DecimalMatrixWithAlmostEquals(val left: Seq[Seq[Decimal]]) {
111 |
112 | /**
113 |      * Returns true when the two nested sequences have the same shape and corresponding elements differ by at most eps; otherwise returns false.
114 | */
115 | def ~==(right: Seq[Seq[Decimal]], epsilon : Double = 1E-6): Boolean = {
116 | if (left.size === right.size) {
117 | (left zip right) forall { case (a, b) => a ~== (b, epsilon) }
118 | }
119 | else false
120 | }
121 | }
122 |
123 | /**
124 |    * Implicit class for comparing two InternalRows using absolute tolerance.
125 | */
126 | implicit class InternalRowWithAlmostEquals(val left: InternalRow) {
127 |
128 | private type valueCompare = (SpecializedGetters, SpecializedGetters, Int) => Boolean
129 | private def newValueCompare(
130 | dataType: DataType,
131 | epsilon : Float = 1E-6F): valueCompare = dataType match {
132 | case NullType => (left, right, ordinal) =>
133 | left.get(ordinal, null) == right.get(ordinal, null)
134 |
135 | case IntegerType => (left, right, ordinal) =>
136 | left.getInt(ordinal) === right.getInt(ordinal)
137 |
138 | case LongType => (left, right, ordinal) =>
139 | left.getLong(ordinal) === right.getLong(ordinal)
140 |
141 | case FloatType => (left, right, ordinal) =>
142 | left.getFloat(ordinal) === (right.getFloat(ordinal) +- epsilon)
143 |
144 | case DoubleType => (left, right, ordinal) =>
145 | left.getDouble(ordinal) === (right.getDouble(ordinal) +- epsilon)
146 |
147 | case DecimalType() => (left, right, ordinal) =>
148 | left.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale).toDouble ===
149 | (right.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale).toDouble
150 | +- epsilon)
151 |
152 | case StringType => (left, right, ordinal) =>
153 | left.getUTF8String(ordinal).getBytes === right.getUTF8String(ordinal).getBytes
154 |
155 | case BinaryType => (left, right, ordinal) =>
156 | left.getBinary(ordinal) === right.getBinary(ordinal)
157 |
158 | case ArrayType(elementType, _) => (left, right, ordinal) =>
159 | if (left.get(ordinal, null) == null || right.get(ordinal, null) == null ){
160 | left.get(ordinal, null) == right.get(ordinal, null)
161 | } else {
162 | val leftArray = left.getArray(ordinal)
163 | val rightArray = right.getArray(ordinal)
164 | if (leftArray.numElements == rightArray.numElements) {
165 | val len = leftArray.numElements()
166 |             val elementValueCompare = newValueCompare(elementType, epsilon)
167 | var result = true
168 | var idx: Int = 0
169 | while (idx < len && result) {
170 | result = elementValueCompare(leftArray, rightArray, idx)
171 | idx += 1
172 | }
173 | result
174 | } else false
175 | }
176 |       case _ => throw new RuntimeException(s"Unsupported data type for comparison: ${dataType}")
177 | }
178 |
179 | /**
180 |      * Returns true when both rows match the given schema and all corresponding fields are equal or within eps; otherwise returns false.
181 | */
182 | def ~==(right: InternalRow, schema: StructType, epsilon : Float = 1E-6F): Boolean = {
183 | if (schema != null && schema.fields.size == left.numFields && schema.fields.size == right.numFields) {
184 | schema.fields.map(_.dataType).zipWithIndex.forall {
185 | case (dataType, idx) =>
186 |           val valueCompare = newValueCompare(dataType, epsilon)
187 | valueCompare(left, right, idx)
188 | }
189 | }
190 | else false
191 | }
192 | }
193 |
194 | /**
195 | * Implicit class for comparing two rows using absolute tolerance.
196 | */
197 | implicit class RowWithAlmostEquals(val left: Row) {
198 |
199 | /**
200 |      * Returns true when both rows match the given schema and all corresponding fields are equal or within the default tolerance; otherwise returns false.
201 | */
202 | def ~==(right: Row, schema: StructType): Boolean = {
203 | if (schema != null && schema.fields.size == left.size && schema.fields.size == right.size) {
204 | val leftRowWithSchema = new GenericRowWithSchema(left.toSeq.toArray, schema)
205 | val rightRowWithSchema = new GenericRowWithSchema(right.toSeq.toArray, schema)
206 | leftRowWithSchema ~== rightRowWithSchema
207 | }
208 | else false
209 | }
210 |
211 | /**
212 |      * Returns true when the two rows have the same number of fields and all corresponding fields are equal or within eps; otherwise returns false.
213 | */
214 | def ~==(right: Row, epsilon : Float = 1E-6F): Boolean = {
215 | if (left.size === right.size) {
216 | val leftDataTypes = left.schema.fields.map(_.dataType)
217 | val rightDataTypes = right.schema.fields.map(_.dataType)
218 |
219 | (leftDataTypes zip rightDataTypes).zipWithIndex.forall {
220 | case (x, i) if left.get(i) == null || right.get(i) == null =>
221 | left.get(i) == null && right.get(i) == null
222 |
223 | case ((FloatType, FloatType), i) =>
224 | left.getFloat(i) === (right.getFloat(i) +- epsilon)
225 |
226 | case ((DoubleType, DoubleType), i) =>
227 | left.getDouble(i) === (right.getDouble(i) +- epsilon)
228 |
229 | case ((BinaryType, BinaryType), i) =>
230 | left.getAs[Array[Byte]](i).toSeq === right.getAs[Array[Byte]](i).toSeq
231 |
232 | case ((ArrayType(FloatType,_), ArrayType(FloatType,_)), i) =>
233 | val leftArray = ArrayData.toArrayData(left.get(i)).toFloatArray().toSeq
234 | val rightArray = ArrayData.toArrayData(right.get(i)).toFloatArray().toSeq
235 | leftArray ~== (rightArray, epsilon)
236 |
237 | case ((ArrayType(DoubleType,_), ArrayType(DoubleType,_)), i) =>
238 | val leftArray = ArrayData.toArrayData(left.get(i)).toDoubleArray().toSeq
239 | val rightArray = ArrayData.toArrayData(right.get(i)).toDoubleArray().toSeq
240 | leftArray ~== (rightArray, epsilon)
241 |
242 | case ((ArrayType(BinaryType,_), ArrayType(BinaryType,_)), i) =>
243 | val leftArray = ArrayData.toArrayData(left.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq
244 | val rightArray = ArrayData.toArrayData(right.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq
245 | leftArray === rightArray
246 |
247 | case ((ArrayType(ArrayType(FloatType,_),_), ArrayType(ArrayType(FloatType,_),_)), i) =>
248 | val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr =>
249 | ArrayData.toArrayData(arr).toFloatArray().toSeq
250 | }
251 | val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr =>
252 | ArrayData.toArrayData(arr).toFloatArray().toSeq
253 | }
254 | leftArrays ~== (rightArrays, epsilon)
255 |
256 | case ((ArrayType(ArrayType(DoubleType,_),_), ArrayType(ArrayType(DoubleType,_),_)), i) =>
257 | val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr =>
258 | ArrayData.toArrayData(arr).toDoubleArray().toSeq
259 | }
260 | val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr =>
261 | ArrayData.toArrayData(arr).toDoubleArray().toSeq
262 | }
263 | leftArrays ~== (rightArrays, epsilon)
264 |
265 | case ((ArrayType(ArrayType(BinaryType,_),_), ArrayType(ArrayType(BinaryType,_),_)), i) =>
266 | val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr =>
267 | ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq
268 | }
269 | val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr =>
270 | ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq
271 | }
272 | leftArrays === rightArrays
273 |
274 |           case (_, i) => left.get(i) === right.get(i)
275 | }
276 | }
277 | else false
278 | }
279 | }
280 | }
--------------------------------------------------------------------------------
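A usage note on TestingUtils.scala above: the implicit ~== comparators become available by importing TestingUtils._ and compare sequences element-wise within an absolute tolerance (defaulting to 1e-6). Below is a minimal, hypothetical sketch of how a test might use them; the values are illustrative only.

import com.linkedin.spark.datasources.tfrecord.TestingUtils._

val expected = Seq(1.0F, 2.0F, 3.0F)
val actual = Seq(1.0000001F, 2.0F, 3.0F)

// Same length and every element within the default 1e-6 tolerance
assert(actual ~== expected)

// Nested sequences (matrices) compared with an explicit epsilon
assert(Seq(Seq(0.1, 0.2)) ~== (Seq(Seq(0.1, 0.2 + 1e-8)), 1e-6))

// A length mismatch short-circuits to false
assert(!(Seq(1.0F) ~== Seq(1.0F, 2.0F)))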