├── .github └── workflows │ └── spark-tensorflow-distributor.yml ├── LICENSE ├── README.md ├── data_service ├── Dockerfile.tf_std_data_server ├── README.md ├── data_service.yaml.jinja ├── data_service_interfaces.yaml.jinja └── tf_std_data_server.py ├── distribution_strategy ├── Dockerfile.keras_model_to_estimator ├── Dockerfile.tf_std_server ├── README.md ├── keras_model_to_estimator.py ├── keras_model_to_estimator_client.py ├── render_template.py ├── template.yaml.jinja └── tf_std_server.py ├── docker ├── Dockerfile ├── README.md └── mnist.py ├── hadoop ├── README.md ├── pom.xml └── src │ ├── main │ └── java │ │ └── org │ │ └── tensorflow │ │ └── hadoop │ │ ├── example │ │ └── TFRecordFileMRExample.java │ │ ├── io │ │ ├── TFRecordFileInputFormat.java │ │ ├── TFRecordFileOutputFormat.java │ │ ├── TFRecordFileOutputFormatV1.java │ │ └── TFRecordIOConf.java │ │ └── util │ │ ├── Crc32C.java │ │ ├── TFRecordReader.java │ │ └── TFRecordWriter.java │ └── test │ ├── java │ └── org │ │ └── tensorflow │ │ └── hadoop │ │ ├── io │ │ └── TFRecordFileTest.java │ │ └── util │ │ └── TFRecordTest.java │ └── resources │ ├── tf-records │ └── part-m-00000 │ └── zipped-tf-records │ └── part-m-00000.gz ├── images ├── between-graph_replication.png └── chief_worker_stdout.png ├── kubernetes ├── README.md └── template.yaml.jinja ├── marathon ├── README.md └── template.json.jinja ├── render_template.py ├── spark ├── spark-tensorflow-connector │ ├── README.md │ ├── pom.xml │ ├── project │ │ ├── build.properties │ │ └── plugins.sbt │ └── src │ │ ├── main │ │ ├── resources │ │ │ └── META-INF │ │ │ │ └── services │ │ │ │ └── org.apache.spark.sql.sources.DataSourceRegister │ │ └── scala │ │ │ └── org │ │ │ └── tensorflow │ │ │ └── spark │ │ │ └── datasources │ │ │ └── tfrecords │ │ │ ├── DefaultSource.scala │ │ │ ├── TensorFlowInferSchema.scala │ │ │ ├── TensorflowRelation.scala │ │ │ ├── serde │ │ │ ├── DefaultTfRecordRowDecoder.scala │ │ │ ├── DefaultTfRecordRowEncoder.scala │ │ │ ├── FeatureDecoder.scala │ │ │ ├── FeatureEncoder.scala │ │ │ ├── FeatureListDecoder.scala │ │ │ └── FeatureListEncoder.scala │ │ │ └── udf │ │ │ └── DataFrameTfrConverter.scala │ │ └── test │ │ └── scala │ │ └── org │ │ └── tensorflow │ │ └── spark │ │ └── datasources │ │ └── tfrecords │ │ ├── InferSchemaSuite.scala │ │ ├── LocalWriteSuite.scala │ │ ├── SharedSparkSessionSuite.scala │ │ ├── TensorFlowSuite.scala │ │ ├── TestingUtils.scala │ │ ├── serde │ │ ├── FeatureDecoderTest.scala │ │ ├── FeatureEncoderTest.scala │ │ ├── FeatureListDecoderTest.scala │ │ ├── FeatureListEncoderTest.scala │ │ ├── TfRecordRowDecoderTest.scala │ │ └── TfRecordRowEncoderTest.scala │ │ └── udf │ │ └── DataFrameTfrConverterTest.scala └── spark-tensorflow-distributor │ ├── .gitignore │ ├── .style.yapf │ ├── Dockerfile │ ├── README.md │ ├── docker-compose.yaml │ ├── examples │ └── simple │ │ └── example.py │ ├── pylintrc │ ├── requirements.txt │ ├── setup.py │ ├── spark_tensorflow_distributor │ ├── __init__.py │ └── mirrored_strategy_runner.py │ └── tests │ ├── integration │ ├── conftest.py │ ├── pytest.ini │ ├── restart_spark.sh │ ├── run.sh │ ├── run_tests.py │ ├── set_spark_conf.py │ ├── spark_conf │ │ ├── gpuDiscoveryScriptStub.sh │ │ └── spark-base.conf │ ├── start_master.sh │ ├── start_spark.py │ ├── start_worker.sh │ ├── stop_master.sh │ ├── stop_spark.py │ ├── stop_worker.sh │ └── test_mirrored_strategy_runner.py │ └── lint.sh └── swarm ├── README.md └── template.yaml.jinja /.github/workflows/spark-tensorflow-distributor.yml: 
-------------------------------------------------------------------------------- 1 | # This workflow will run tests when changes are made to the Spark TensorFlow Distributor project 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Spark TensorFlow Distributor 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | paths: 10 | - 'spark/spark-tensorflow-distributor/**' 11 | pull_request: 12 | branches: [ master ] 13 | paths: 14 | - 'spark/spark-tensorflow-distributor/**' 15 | 16 | jobs: 17 | build: 18 | runs-on: ubuntu-${{ matrix.ubuntu-version }} 19 | strategy: 20 | matrix: 21 | python-version: [3.7] 22 | ubuntu-version: [18.04] 23 | steps: 24 | - uses: actions/checkout@v2 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v1 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - name: Install Pip Dependencies 30 | run: | 31 | cd spark/spark-tensorflow-distributor/ 32 | python -m pip install --upgrade pip 33 | pip install -r requirements.txt 34 | - name: Lint 35 | run: | 36 | cd spark/spark-tensorflow-distributor/ 37 | chmod +x ./tests/lint.sh 38 | ./tests/lint.sh 39 | 40 | # Disable unit tests for now until we have unit tests to run! 41 | # - name: Run Unit Tests 42 | # run: | 43 | # cd spark/spark-tensorflow-distributor/ 44 | # python -m pytest tests/unit 45 | 46 | - name: Install docker-compose 47 | run: | 48 | sudo rm /usr/local/bin/docker-compose 49 | curl -L https://github.com/docker/compose/releases/download/1.22.0/docker-compose-`uname -s`-`uname -m` > docker-compose 50 | chmod +x docker-compose 51 | sudo mv docker-compose /usr/local/bin 52 | 53 | # Disable integration tests for now becuase GitHub workflows exhausts compute and loses executors 54 | # - name: Run Integration Tests 55 | # run: | 56 | # cd spark/spark-tensorflow-distributor/ 57 | # docker-compose pull 58 | # chmod +x ./tests/integration/run.sh 59 | # ./tests/integration/run.sh --num-workers 2 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Ecosystem 2 | 3 | This repository contains examples for integrating TensorFlow with other 4 | open-source frameworks. The examples are minimal and intended for use as 5 | templates. Users can tailor the templates for their own use-cases. 6 | 7 | If you have any additions or improvements, please create an issue or pull 8 | request. 9 | 10 | ## Contents 11 | 12 | - [docker](docker) - Docker configuration for running TensorFlow on 13 | cluster managers. 14 | - [kubeflow](https://github.com/kubeflow/kubeflow) - A Kubernetes native platform for ML 15 | * A K8s custom resource for running distributed [TensorFlow jobs](https://github.com/kubeflow/kubeflow/blob/master/user_guide.md#submitting-a-tensorflow-training-job) 16 | * Jupyter images for different versions of TensorFlow 17 | * [TFServing](https://github.com/kubeflow/kubeflow/blob/master/user_guide.md#serve-a-model-using-tensorflow-serving) Docker images and K8s templates 18 | - [kubernetes](kubernetes) - Templates for running distributed TensorFlow on 19 | Kubernetes. 20 | - [marathon](marathon) - Templates for running distributed TensorFlow using 21 | Marathon, deployed on top of Mesos. 22 | - [hadoop](hadoop) - TFRecord file InputFormat/OutputFormat for Hadoop MapReduce 23 | and Spark. 
24 | - [spark-tensorflow-connector](spark/spark-tensorflow-connector) - Spark TensorFlow Connector
25 | - [spark-tensorflow-distributor](spark/spark-tensorflow-distributor) - Python package that helps users do distributed training with TensorFlow on their Spark clusters.
26 | 
27 | ## Distributed TensorFlow
28 | 
29 | See the [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed)
30 | documentation for a description of how it works. The examples in this
31 | repository focus on the most common form of distributed training: between-graph
32 | replication with asynchronous updates.
33 | 
34 | ### Common Setup for distributed training
35 | 
36 | Every distributed training program has some common setup. First, define flags so
37 | that the worker knows about other workers and knows what role it plays in
38 | distributed training:
39 | 
40 | ```python
41 | # Flags for configuring the task
42 | flags.DEFINE_integer("task_index", None,
43 |                      "Worker task index, should be >= 0. task_index=0 is "
44 |                      "the master worker task that performs the variable "
45 |                      "initialization.")
46 | flags.DEFINE_string("ps_hosts", None,
47 |                     "Comma-separated list of hostname:port pairs")
48 | flags.DEFINE_string("worker_hosts", None,
49 |                     "Comma-separated list of hostname:port pairs")
50 | flags.DEFINE_string("job_name", None, "job name: worker or ps")
51 | ```
52 | 
53 | Then, start your server. Since workers and parameter servers (ps jobs) usually
54 | share a common program, parameter servers should stop at this point, so they
55 | simply join the server.
56 | 
57 | ```python
58 | # Construct the cluster and start the server
59 | ps_spec = FLAGS.ps_hosts.split(",")
60 | worker_spec = FLAGS.worker_hosts.split(",")
61 | 
62 | cluster = tf.train.ClusterSpec({
63 |     "ps": ps_spec,
64 |     "worker": worker_spec})
65 | 
66 | server = tf.train.Server(
67 |     cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
68 | 
69 | if FLAGS.job_name == "ps":
70 |   server.join()
71 | ```
72 | 
73 | Afterwards, your code varies depending on the form of distributed training you
74 | intend to do. The most common form is between-graph replication.
75 | 
76 | ### Between-graph Replication
77 | 
78 | In this mode, each worker separately constructs the exact same graph. Each
79 | worker then runs the graph in isolation, only sharing gradients with the
80 | parameter servers. This setup is illustrated by the following diagram. Please
81 | note that each dashed box indicates a task.
82 | ![Diagram for Between-graph replication](images/between-graph_replication.png "Between-graph Replication")
83 | 
84 | You must explicitly set the device before graph construction for this mode of
85 | training. The following code snippet from the
86 | [Distributed TensorFlow tutorial](https://www.tensorflow.org/deploy/distributed)
87 | demonstrates the setup:
88 | 
89 | ```python
90 | with tf.device(tf.train.replica_device_setter(
91 |     worker_device="/job:worker/task:%d" % FLAGS.task_index,
92 |     cluster=cluster)):
93 |   # Construct the TensorFlow graph.
94 | 
95 | # Run the TensorFlow graph.
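# For illustration only: a hypothetical sketch of the run step, assuming a
# `train_op` was created during graph construction above (as in docker/mnist.py).
# tf.train.MonitoredTrainingSession handles variable initialization,
# checkpointing, and recovery, with task 0 acting as the chief.
with tf.train.MonitoredTrainingSession(
    master=server.target,
    is_chief=(FLAGS.task_index == 0),
    checkpoint_dir="/tmp/train_logs") as sess:
  while not sess.should_stop():
    sess.run(train_op)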
96 | ``` 97 | 98 | ### Requirements To Run the Examples 99 | 100 | To run our examples, [Jinja templates](http://jinja.pocoo.org/) must be installed: 101 | 102 | ```sh 103 | # On Ubuntu 104 | sudo apt-get install python-jinja2 105 | 106 | # On most other platforms 107 | sudo pip install Jinja2 108 | ``` 109 | 110 | Jinja is used for template expansion. There are other framework-specific 111 | requirements, please refer to the README page of each framework. 112 | -------------------------------------------------------------------------------- /data_service/Dockerfile.tf_std_data_server: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:nightly 2 | 3 | COPY tf_std_data_server.py / 4 | ENTRYPOINT ["python", "-u", "/tf_std_data_server.py"] 5 | -------------------------------------------------------------------------------- /data_service/README.md: -------------------------------------------------------------------------------- 1 | # Distributed input processing with tf.data service. 2 | 3 | This directory provides an example of running the tf.data service to 4 | horizontally scale tf.data input processing. We use GKE 5 | (Google Kubernetes Engine) to manage the tf.data servers. 6 | 7 | This directory contains the following files: 8 | 9 | - `Dockerfile.tf_std_data_server`: A dockerfile to build a tf.data server image. 10 | - `data_service.yaml.jinja`: A Jinja-templated Kubernetes definition for running 11 | tf.data service servers 12 | - `data_service_interfaces.yaml.jinja`: A Jinja-templated Kubernetes definition 13 | for creating load balancers which expose the tf.data service endpoints 14 | outside the GKE cluster (but within the same VPC network). This is needed 15 | for TPUs to be able to connect to servers running in GKE. 16 | - `tf_std_data_server.py`: A basic tf.data server implementation. 17 | 18 | ## Run the tf.data service in GKE 19 | 20 | ### Start a GKE cluster 21 | 22 | If you don't already have a [GKE](https://cloud.google.com/kubernetes-engine) 23 | cluster, create one: 24 | 25 | Replace `${CLUSTER_NAME}` with a name of your choice. 26 | Replace `${NUM_NODES}` with the number of tf.data service machines to run, e.g. 27 | `8`. 28 | Replace `${MACHINE_TYPE}` with the machine type to use, e.g. `e2-standard-4` 29 | 30 | ``` 31 | gcloud container clusters create ${CLUSTER_NAME} --zone europe-west4-a \ 32 | --scopes=cloud-platform --enable-ip-alias --num-nodes=${NUM_NODES} \ 33 | --machine-type=${MACHINE_TYPE} 34 | ``` 35 | 36 | `--enable-ip-alias` is needed to be able to connect to the cluster from a TPU. 37 | 38 | ### Create service endpoints 39 | 40 | Set number of workers in `data_service_interfaces` 41 | Edit the variable at the start of `data_service_interfaces.yaml.jinja` to set the number of workers. 42 | {%- set workers = 8 -%} 43 | 44 | Create data service endpoints so that the data service can be accessed from outside GKE. 45 | This requires `jinja2`, install it if you don't have it already: `pip3 install jinja2`. 46 | 47 | ``` 48 | python3 ../render_template.py data_service_interfaces.yaml.jinja | kubectl apply -f - 49 | ``` 50 | 51 | ### Create tf.data server image 52 | 53 | ``` 54 | docker build --no-cache -t gcr.io/${PROJECT_ID}/tf_std_data_server:latest \ 55 | -f Dockerfile.tf_std_data_server . 
56 | docker push gcr.io/${PROJECT_ID}/tf_std_data_server:latest 57 | ``` 58 | 59 | ### Start tf.data servers 60 | 61 | Edit `data_service.yaml.jinja`, setting the image variable at the top of the 62 | file to the image created in the previous step, e.g. 63 | `"gcr.io/${PROJECT_ID}/tf_std_data_server:latest"` 64 | 65 | Wait for GKE to assign endpoints for all services created in the "Create service 66 | endpoints" step. This may 67 | take a few minutes. The below command will query all worker endpoints: 68 | 69 | ``` 70 | kubectl get services -o=jsonpath='{"\n"}{range .items[*]}"{.metadata.name}": "{.status.loadBalancer.ingress[*].ip}",{"\n"}{end}{"\n"}' | grep data-service-worker 71 | ``` 72 | 73 | Once the command shows non-empty addresses for all workers, copy the output 74 | of the command into the `ip_mapping` variable at the start of `data_service.yaml.jinja`. 75 | 76 | ``` 77 | {% set ip_mapping = { 78 | "data-service-worker-0": "10.164.0.40", 79 | "data-service-worker-1": "10.164.0.41", 80 | ... 81 | } %} 82 | ``` 83 | 84 | Now launch the tf.data servers: 85 | 86 | ``` 87 | python3 ../render_template.py data_service.yaml.jinja | kubectl apply -f - 88 | ``` 89 | 90 | The service is now ready to use. To find the service address, run 91 | 92 | ``` 93 | kubectl get services data-service-dispatcher 94 | ``` 95 | 96 | and examine the `EXTERNAL-IP` and `PORT(S)` columns. To access the cluster, 97 | you will use the string `'grpc://:'` 98 | 99 | ## Run ResNet using the tf.data service for input. 100 | 101 | The `classifier_trainer.py` script in the [TensorFlow Model 102 | Garden](https://github.com/tensorflow/models) supports using the tf.data service to 103 | get input data. 104 | 105 | To run the script, do the following: 106 | 107 | ``` 108 | git clone https://github.com/tensorflow/models.git 109 | cd models/official/vision/image_classification 110 | ``` 111 | 112 | Edit either `configs/examples/resnet/imagenet/gpu.yaml` or 113 | `configs/examples/resnet/imagenet/tpu.yaml`, 114 | depending on whether you want to run on GPU or TPU. Under the `train_dataset` 115 | and `validation_dataset` sections, update `builder` from `'tfds'` to 116 | `'records'`. Then under the `train_dataset` section, add `tf_data_service: 117 | 'grpc://:'`. 118 | 119 | Finally, run the ResNet model. 120 | 121 | ``` 122 | export PYTHONPATH=/path/to/models 123 | python3 classifier_trainer.py \ 124 | --mode=train_and_eval --model_type=resnet --dataset=imagenet --tpu=$TPU_NAME \ 125 | --model_dir=$MODEL_DIR --data_dir=gs://cloud-tpu-test-datasets/fake_imagenet \ 126 | --config_file=path/to/config 127 | ``` 128 | 129 | ## Restarting tf.data servers 130 | 131 | tf.data servers are meant to live for the duration of a single training job. 
132 | When starting a new job, you can use the following commands to stop the tf.data 133 | servers: 134 | 135 | ``` 136 | kubectl get rs --no-headers=true | grep "data-service-" | xargs kubectl delete rs 137 | ``` 138 | 139 | Then to start the servers again, run 140 | 141 | ``` 142 | python3 ../render_template.py data_service.yaml.jinja | kubectl apply -f - 143 | ``` 144 | -------------------------------------------------------------------------------- /data_service/data_service.yaml.jinja: -------------------------------------------------------------------------------- 1 | {%- set image = "gcr.io//tf_std_data_server:latest" -%} 2 | {%- set port = 5050 -%} 3 | {% set ip_mapping = { 4 | } %} 5 | 6 | kind: ReplicaSet 7 | apiVersion: apps/v1 8 | metadata: 9 | name: data-service-dispatcher 10 | spec: 11 | replicas: 1 12 | selector: 13 | matchLabels: 14 | name: data-service-dispatcher 15 | template: 16 | metadata: 17 | labels: 18 | name: data-service-dispatcher 19 | spec: 20 | containers: 21 | - name: tensorflow 22 | image: {{ image }} 23 | ports: 24 | - containerPort: {{ port }} 25 | args: 26 | - "--port={{ port }}" 27 | - "--is_dispatcher=true" 28 | --- 29 | 30 | {% for worker_name, worker_ip in ip_mapping.items() %} 31 | kind: ReplicaSet 32 | apiVersion: extensions/v1beta1 33 | metadata: 34 | name: {{ worker_name }} 35 | spec: 36 | replicas: 1 37 | selector: 38 | matchLabels: 39 | name: {{ worker_name }} 40 | template: 41 | metadata: 42 | labels: 43 | name: {{ worker_name }} 44 | spec: 45 | containers: 46 | - name: tensorflow 47 | image: {{ image }} 48 | ports: 49 | - containerPort: {{ port }} 50 | args: 51 | - "--port={{ port }}" 52 | - "--is_dispatcher=false" 53 | - "--dispatcher_address=data-service-dispatcher:{{ port }}" 54 | - "--worker_address={{ worker_ip }}:{{ port }}" 55 | --- 56 | {% endfor %} 57 | -------------------------------------------------------------------------------- /data_service/data_service_interfaces.yaml.jinja: -------------------------------------------------------------------------------- 1 | {%- set workers = 8 -%} 2 | {%- set port = 5050 -%} 3 | 4 | kind: Service 5 | apiVersion: v1 6 | metadata: 7 | name: data-service-dispatcher 8 | annotations: 9 | cloud.google.com/load-balancer-type: "Internal" 10 | spec: 11 | type: LoadBalancer 12 | selector: 13 | name: data-service-dispatcher 14 | ports: 15 | - port: {{ port }} 16 | targetPort: {{ port }} 17 | protocol: TCP 18 | --- 19 | {% for i in range(workers) %} 20 | kind: Service 21 | apiVersion: v1 22 | metadata: 23 | name: data-service-worker-{{ i }} 24 | annotations: 25 | cloud.google.com/load-balancer-type: "Internal" 26 | spec: 27 | type: LoadBalancer 28 | selector: 29 | name: data-service-worker-{{ i }} 30 | ports: 31 | - port: {{ port }} 32 | targetPort: {{ port }} 33 | protocol: TCP 34 | --- 35 | {% endfor %} 36 | -------------------------------------------------------------------------------- /data_service/tf_std_data_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Run a tf.data service server.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | flags = tf.compat.v1.app.flags 23 | 24 | flags.DEFINE_integer("port", 0, "Port to listen on") 25 | flags.DEFINE_bool("is_dispatcher", False, "Whether to start a dispatcher (as opposed to a worker server") 26 | flags.DEFINE_string("dispatcher_address", "", "The address of the dispatcher. This is only needed when starting a worker server.") 27 | flags.DEFINE_string("worker_address", "", "The address of the worker server. This is only needed when starting a worker server.") 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | 32 | def main(unused_argv): 33 | if FLAGS.is_dispatcher: 34 | print("Starting tf.data service dispatcher") 35 | server = tf.data.experimental.service.DispatchServer( 36 | tf.data.experimental.service.DispatcherConfig( 37 | port=FLAGS.port, 38 | protocol="grpc")) 39 | else: 40 | print("Starting tf.data service worker") 41 | server = tf.data.experimental.service.WorkerServer( 42 | tf.data.experimental.service.WorkerConfig( 43 | port=FLAGS.port, 44 | protocol="grpc", 45 | dispatcher_address=FLAGS.dispatcher_address, 46 | worker_address=FLAGS.worker_address)) 47 | server.join() 48 | 49 | 50 | if __name__ == "__main__": 51 | tf.compat.v1.app.run() 52 | -------------------------------------------------------------------------------- /distribution_strategy/Dockerfile.keras_model_to_estimator: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:nightly-gpu 2 | 3 | COPY keras_model_to_estimator.py / 4 | ENTRYPOINT ["python", "/keras_model_to_estimator.py"] 5 | -------------------------------------------------------------------------------- /distribution_strategy/Dockerfile.tf_std_server: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:nightly-gpu 2 | 3 | COPY tf_std_server.py / 4 | ENTRYPOINT ["python", "/tf_std_server.py"] 5 | -------------------------------------------------------------------------------- /distribution_strategy/README.md: -------------------------------------------------------------------------------- 1 | # Multi-worker Training Using Distribution Strategies 2 | 3 | This directory provides an example of running multi-worker training with 4 | Distribution Strategies. 5 | 6 | Please first read the 7 | [documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md#multi-worker-training) 8 | of Distribution Strategy for multi-worker training. We also assume that readers 9 | of this page have experience with [Google Cloud](https://cloud.google.com/) and 10 | its [Kubernetes Engine](https://cloud.google.com/kubernetes-engine/). 
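Each task in the cluster learns its role from the `TF_CONFIG` environment
variable, which the Kubernetes template in this directory injects into every
pod. For orientation, here is a rough, illustrative sketch of what that
variable contains for a two-worker setup with one evaluator (hostnames, port,
and index are placeholders; the rendered template generates the real value):

```python
import json
import os

# Illustrative only: the jinja template in this directory generates the real
# value for each pod; the names and port below mirror the template defaults.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["dist-strat-example-worker-0:5000",
                   "dist-strat-example-worker-1:5000"],
        "evaluator": ["dist-strat-example-evaluator-0:5000"],
    },
    "task": {"type": "worker", "index": 0},
})
```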
11 | 
12 | This directory contains the following files:
13 | 
14 | - template.yaml.jinja: a jinja template to be rendered into a Kubernetes yaml
15 | file
16 | - Dockerfile.keras_model_to_estimator: a Dockerfile to build the model image
17 | - Dockerfile.tf_std_server: a Dockerfile to build the standard TensorFlow
18 | server image
19 | - keras_model_to_estimator.py: model code to run multi-worker training
20 | - tf_std_server.py: a standard TensorFlow binary
21 | - keras_model_to_estimator_client.py: model code to run in standalone client
22 | mode
23 | 
24 | ## Prerequisite
25 | 
26 | 1. You first need to have a Google Cloud project, set up a
27 | [service account](https://cloud.google.com/compute/docs/access/service-accounts)
28 | and download its JSON file. Make sure this service account has access to
29 | [Google Cloud Storage](https://cloud.google.com/storage/).
30 | 2. Install the
31 | [gcloud command-line tools](https://cloud.google.com/functions/docs/quickstart)
32 | on your workstation, log in, and set your project and zone.
33 | 3. Install kubectl:
34 | 
35 | ```bash
36 | gcloud components install kubectl
37 | ```
38 | 
39 | 4. Start a Kubernetes cluster either with the `gcloud` command or with the
40 | [GKE](https://cloud.google.com/kubernetes-engine/) web UI. Optionally you
41 | can add GPUs to each node.
42 | 
43 | 5. Set the context for `kubectl` so that `kubectl` knows which cluster to use:
44 | 
45 | ```bash
46 | kubectl config use-context <your-cluster-context>
47 | ```
48 | 
49 | 6. Install CUDA drivers in your cluster:
50 | 
51 | ```bash
52 | kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/stable/nvidia-driver-installer/cos/daemonset-preloaded.yaml
53 | ```
54 | 
55 | 7. Create a Kubernetes secret for the JSON file of your service account:
56 | 
57 | ```bash
58 | kubectl create secret generic credential --from-file=key.json=<path-to-your-service-account-json>
59 | ```
60 | 
61 | ## How to run the example
62 | 
63 | 1. Let's first build the Docker image:
64 | 
65 | ```bash
66 | docker build --no-cache -t keras_model_to_estimator:v1 -f Dockerfile.keras_model_to_estimator .
67 | 
68 | ```
69 | 
70 | and push the image to the
71 | [Google Cloud Container Registry](https://cloud.google.com/container-registry/):
72 | 
73 | ```bash
74 | docker tag keras_model_to_estimator:v1 gcr.io/<project-id>/keras_model_to_estimator:v1
75 | docker push gcr.io/<project-id>/keras_model_to_estimator:v1
76 | ```
77 | 
78 | 2. Modify the header of the jinja template. You probably want to change `name`,
79 | `image`, `worker_replicas`, `num_gpus_per_worker`, `has_eval`,
80 | `has_tensorboard`, `script` and `cmdline_args`.
81 | 
82 | * `name`: name your cluster, e.g. "my-dist-strat-example".
83 | * `image`: the name of your Docker image.
84 | * `worker_replicas`: number of workers.
85 | * `num_gpus_per_worker`: number of GPUs per worker, also used for the
86 | "evaluator" job if it exists.
87 | * `has_eval`: whether to include an "evaluator" job. If this is False, no
88 | evaluation will be done even though `tf.estimator.train_and_evaluate` is
89 | used.
90 | * `has_tensorboard`: whether to run TensorBoard in the cluster.
91 | * `train_dir`: the model directory.
92 | * `script`: the script in the Docker image to run.
93 | * `cmdline_args`: the command-line arguments passed to the `script`,
94 | delimited by spaces.
95 | * `credential_secret_json`: the filename of the JSON file for your service
96 | account.
97 | * `credential_secret_key`: the name of the Kubernetes secret storing the
98 | credential of your service account.
99 | * `port`: the port for all tasks, including TensorBoard.
100 | 
101 | 3. Start the training cluster:
102 | 
103 | ```bash
104 | python ../render_template.py template.yaml.jinja | kubectl create -f -
105 | ```
106 | 
107 | You'll see that your cluster has started training. You can inspect the logs of
108 | the workers or use TensorBoard to watch your model train.
109 | 
110 | ## How to run with standalone client mode
111 | 
112 | Please refer to the
113 | [documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md#standalone-client-mode)
114 | of Distribution Strategy for the details of multi-worker training with
115 | standalone client mode. It basically consists of a cluster of standard
116 | TensorFlow servers and a model running on your workstation that connects to the
117 | cluster to request and coordinate training. All the training is controlled
118 | by the model running on your workstation.
119 | 
120 | 1. First install the Kubernetes Python client:
121 | 
122 | ```bash
123 | pip install kubernetes
124 | ```
125 | 
126 | 2. Build a Docker image for the standard TensorFlow server:
127 | 
128 | ```bash
129 | docker build --no-cache -t tf_std_server:v1 -f Dockerfile.tf_std_server .
130 | ```
131 | 
132 | and push it to the container registry as well.
133 | 
134 | 
135 | 3. Modify the header of the jinja template: set `image`, set `script` to
136 | `/tf_std_server.py`, and set `cmdline_args` to empty to run this standard
137 | TensorFlow server on each Kubernetes pod.
138 | 
139 | 4. Start the cluster of standard TensorFlow servers:
140 | 
141 | ```bash
142 | python ../render_template.py template.yaml.jinja | kubectl create -f -
143 | ```
144 | 
145 | 5. Run the model binary on your workstation:
146 | 
147 | ```bash
148 | python keras_model_to_estimator_client.py gs://<your-gcs-bucket>
149 | ```
150 | 
151 | You'll find that your model starts training and logs are printed to your
152 | terminal.
153 | 
154 | If you see an authentication issue, it is probably because your workstation
155 | doesn't have access to the GCS bucket. In this case you can set the
156 | credential to point to the JSON file of your service account before you run
157 | the model binary:
158 | 
159 | ```bash
160 | export GOOGLE_APPLICATION_CREDENTIALS="<path-to-your-service-account-json>"
161 | ```
162 | 
--------------------------------------------------------------------------------
/distribution_strategy/keras_model_to_estimator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | """An example of training Keras model with multi-worker strategies.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import sys 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | 26 | def input_fn(): 27 | x = np.random.random((1024, 10)) 28 | y = np.random.randint(2, size=(1024, 1)) 29 | x = tf.cast(x, tf.float32) 30 | dataset = tf.data.Dataset.from_tensor_slices((x, y)) 31 | dataset = dataset.repeat(100) 32 | dataset = dataset.batch(32) 33 | return dataset 34 | 35 | 36 | def main(args): 37 | if len(args) < 2: 38 | print('You must specify model_dir for checkpoints such as' 39 | ' /tmp/tfkeras_example/.') 40 | return 41 | 42 | model_dir = args[1] 43 | print('Using %s to store checkpoints.' % model_dir) 44 | 45 | # Define a Keras Model. 46 | model = tf.keras.Sequential() 47 | model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) 48 | model.add(tf.keras.layers.Dense(1, activation='sigmoid')) 49 | 50 | # Compile the model. 51 | optimizer = tf.train.GradientDescentOptimizer(0.2) 52 | model.compile(loss='binary_crossentropy', optimizer=optimizer) 53 | model.summary() 54 | tf.keras.backend.set_learning_phase(True) 55 | 56 | # Define DistributionStrategies and convert the Keras Model to an 57 | # Estimator that utilizes these DistributionStrateges. 58 | # Evaluator is a single worker, so using MirroredStrategy. 59 | config = tf.estimator.RunConfig( 60 | experimental_distribute=tf.contrib.distribute.DistributeConfig( 61 | train_distribute=tf.contrib.distribute.CollectiveAllReduceStrategy( 62 | num_gpus_per_worker=2), 63 | eval_distribute=tf.contrib.distribute.MirroredStrategy( 64 | num_gpus_per_worker=2))) 65 | keras_estimator = tf.keras.estimator.model_to_estimator( 66 | keras_model=model, config=config, model_dir=model_dir) 67 | 68 | # Train and evaluate the model. Evaluation will be skipped if there is not an 69 | # "evaluator" job in the cluster. 70 | tf.estimator.train_and_evaluate( 71 | keras_estimator, 72 | train_spec=tf.estimator.TrainSpec(input_fn=input_fn), 73 | eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)) 74 | 75 | 76 | if __name__ == '__main__': 77 | tf.logging.set_verbosity(tf.logging.INFO) 78 | tf.app.run(argv=sys.argv) 79 | -------------------------------------------------------------------------------- /distribution_strategy/keras_model_to_estimator_client.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """An example of training Keras model with multi-worker strategies.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import sys 21 | import kubernetes 22 | 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | 27 | def _parse_task_name_fn(load_balancer_name): 28 | """Parses task type and id from a service name.""" 29 | splits = load_balancer_name.split('-') 30 | if len(splits) <= 3: 31 | raise ValueError( 32 | "Could recognize load_balancer's name: %r" % load_balancer_name) 33 | task_type = splits[-2] 34 | if task_type not in ['chief', 'worker', 'ps', 'evaluator']: 35 | return None, None 36 | task_id = int(splits[-1]) 37 | assert task_id >= 0 38 | return task_type, task_id 39 | 40 | 41 | def resolve_cluster(port=5000, parse_task_name_fn=_parse_task_name_fn): 42 | """Queries Kubernetes cluster and gets cluster_spec.""" 43 | kubernetes.config.load_kube_config() 44 | v1 = kubernetes.client.CoreV1Api() 45 | ret = v1.list_service_for_all_namespaces() 46 | cluster_spec = {} 47 | for item in ret.items: 48 | if item.status.load_balancer and item.status.load_balancer.ingress: 49 | task_type, task_id = parse_task_name_fn(item.metadata.name) 50 | if not task_type: 51 | continue 52 | if task_type not in cluster_spec: 53 | cluster_spec[task_type] = [] 54 | while len(cluster_spec[task_type]) <= task_id: 55 | cluster_spec[task_type].append(None) 56 | cluster_spec[task_type][task_id] = '%s:%d' % ( 57 | item.status.load_balancer.ingress[0].ip, port) 58 | 59 | if not cluster_spec: 60 | raise ValueError( 61 | "Cannot get cluster_spec. It's possible the cluster is not ready.") 62 | for task_type, targets in cluster_spec.items(): 63 | for target in targets: 64 | if target is None: 65 | raise ValueError( 66 | 'Not all %s tasks are found in the cluster' % task_type) 67 | tf.logging.info('Using cluster_spec %r' % cluster_spec) 68 | return cluster_spec 69 | 70 | 71 | def input_fn(): 72 | x = np.random.random((1024, 10)) 73 | y = np.random.randint(2, size=(1024, 1)) 74 | x = tf.cast(x, tf.float32) 75 | dataset = tf.data.Dataset.from_tensor_slices((x, y)) 76 | dataset = dataset.repeat(100) 77 | dataset = dataset.batch(32) 78 | return dataset 79 | 80 | 81 | def main(args): 82 | if len(args) < 2: 83 | print('You must specify model_dir for checkpoints such as' 84 | ' /tmp/tfkeras_example/.') 85 | return 86 | 87 | model_dir = args[1] 88 | print('Using %s to store checkpoints.' % model_dir) 89 | 90 | # Define a Keras Model. 91 | model = tf.keras.Sequential() 92 | model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) 93 | model.add(tf.keras.layers.Dense(1, activation='sigmoid')) 94 | 95 | # Compile the model. 96 | optimizer = tf.train.GradientDescentOptimizer(0.2) 97 | model.compile(loss='binary_crossentropy', optimizer=optimizer) 98 | model.summary() 99 | tf.keras.backend.set_learning_phase(True) 100 | 101 | # Define DistributionStrategies and convert the Keras Model to an 102 | # Estimator that utilizes these DistributionStrateges. 103 | # Evaluator is a single worker, so using MirroredStrategy. 
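  # DistributeConfig's `remote_cluster` points at the standard TensorFlow
  # servers resolved from Kubernetes above, so this process acts as a
  # standalone client that drives training on the remote cluster.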
104 | run_config = tf.estimator.RunConfig( 105 | experimental_distribute=tf.contrib.distribute.DistributeConfig( 106 | train_distribute=tf.contrib.distribute.CollectiveAllReduceStrategy( 107 | num_gpus_per_worker=2), 108 | eval_distribute=tf.contrib.distribute.MirroredStrategy( 109 | num_gpus_per_worker=2), 110 | remote_cluster=resolve_cluster())) 111 | keras_estimator = tf.keras.estimator.model_to_estimator( 112 | keras_model=model, config=run_config, model_dir=model_dir) 113 | 114 | # Train and evaluate the model. Evaluation will be skipped if there is not an 115 | # "evaluator" job in the cluster. 116 | tf.estimator.train_and_evaluate( 117 | keras_estimator, 118 | train_spec=tf.estimator.TrainSpec(input_fn=input_fn), 119 | eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)) 120 | 121 | 122 | if __name__ == '__main__': 123 | tf.logging.set_verbosity(tf.logging.INFO) 124 | tf.app.run(argv=sys.argv) 125 | -------------------------------------------------------------------------------- /distribution_strategy/render_template.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import jinja2 6 | import sys 7 | 8 | if len(sys.argv) != 2: 9 | print("usage: {} [template-file]".format(sys.argv[0]), file=sys.stderr) 10 | sys.exit(1) 11 | with open(sys.argv[1], "r") as f: 12 | print(jinja2.Template(f.read()).render()) 13 | -------------------------------------------------------------------------------- /distribution_strategy/template.yaml.jinja: -------------------------------------------------------------------------------- 1 | {%- set name = "dist-strat-example" -%} 2 | {%- set image = "gcr.io//keras_model_to_estimator:v1" -%} 3 | {%- set worker_replicas = 2 -%} 4 | {%- set ps_replicas = 0 -%} 5 | {%- set num_gpus_per_worker = 2 -%} 6 | {%- set has_eval = False -%} 7 | {%- set has_tensorboard = False -%} 8 | {%- set train_dir = "gs://" -%} 9 | {%- set script = "/keras_model_to_estimator.py" -%} 10 | {%- set cmdline_args = train_dir -%} 11 | {%- set credential_secret_json = "key.json" -%} 12 | {%- set credential_secret_key = "credential" -%} 13 | {%- set port = 5000 -%} 14 | 15 | 16 | {%- set replicas = {"worker": worker_replicas, 17 | "ps": ps_replicas, 18 | "evaluator": has_eval|int, 19 | "tensorboard": has_tensorboard|int} -%} 20 | {% set cmdline_arg_list = cmdline_args.split(" ") %} 21 | 22 | {%- macro worker_hosts() -%} 23 | {%- for i in range(worker_replicas) -%} 24 | {%- if not loop.first -%},{%- endif -%} 25 | \"{{ name }}-worker-{{ i }}:{{ port }}\" 26 | {%- endfor -%} 27 | {%- endmacro -%} 28 | 29 | {%- macro ps_hosts() -%} 30 | {%- for i in range(ps_replicas) -%} 31 | {%- if not loop.first -%},{%- endif -%} 32 | \"{{ name }}-ps-{{ i }}:{{ port }}\" 33 | {%- endfor -%} 34 | {%- endmacro -%} 35 | 36 | {%- macro tf_config(task_type, task_id) -%} 37 | { 38 | \"cluster\": { 39 | \"worker\": [{{ worker_hosts() }}] 40 | {%- if ps_replicas > 0 -%}, \"ps\": [{{ ps_hosts() }}]{%- endif -%} 41 | {%- if has_eval -%}, 42 | \"evaluator\": [\"{{ name }}-evaluator-0:{{ port }}\"]{%- endif -%} 43 | }, 44 | \"task\": { 45 | \"type\": \"{{ task_type }}\", 46 | \"index\": \"{{ task_id }}\" 47 | } 48 | } 49 | {%- endmacro -%} 50 | 51 | {% for job in ["worker", "ps", "evaluator", "tensorboard"] -%} 52 | {%- for i in range(replicas[job]) -%} 53 | kind: Service 54 | apiVersion: v1 55 | metadata: 56 | name: {{ name }}-{{ job }}-{{ i }} 57 | spec: 58 | type: LoadBalancer 59 | selector: 60 | name: {{ 
name }} 61 | job: {{ job }} 62 | task: "{{ i }}" 63 | ports: 64 | - port: {{ port }} 65 | --- 66 | kind: ReplicationController 67 | apiVersion: v1 68 | metadata: 69 | name: {{ name }}-{{ job }}-{{ i }} 70 | spec: 71 | replicas: 1 72 | template: 73 | metadata: 74 | labels: 75 | name: {{ name }} 76 | job: {{ job }} 77 | task: "{{ i }}" 78 | spec: 79 | containers: 80 | {% if job == "tensorboard" %} 81 | - name: tensorflow 82 | image: tensorflow/tensorflow 83 | {% else %} 84 | - name: tensorflow 85 | image: {{ image }} 86 | resources: 87 | limits: 88 | nvidia.com/gpu: {{ num_gpus_per_worker }} 89 | {% endif %} 90 | env: 91 | {% if job != "tensorboard" %} 92 | - name: TF_CONFIG 93 | value: "{{ tf_config(job, i) }}" 94 | {% endif %} 95 | - name: GOOGLE_APPLICATION_CREDENTIALS 96 | value: "/var/secrets/google/{{ credential_secret_json }}" 97 | ports: 98 | - containerPort: {{ port }} 99 | {% if job == "tensorboard" %} 100 | command: 101 | - "tensorboard" 102 | args: 103 | - "--logdir={{ train_dir }}" 104 | - "--port={{ port }}" 105 | {% else %} 106 | command: 107 | - "/usr/bin/python" 108 | - "{{ script }}" 109 | {%- for cmdline_arg in cmdline_arg_list %} 110 | - "{{ cmdline_arg }}" 111 | {%- endfor -%} 112 | {% endif %} 113 | volumeMounts: 114 | - name: credential 115 | mountPath: /var/secrets/google 116 | volumes: 117 | - name: credential 118 | secret: 119 | secretName: {{ credential_secret_key }} 120 | --- 121 | {% endfor %} 122 | {%- endfor -%} 123 | -------------------------------------------------------------------------------- /distribution_strategy/tf_std_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Run a standard tensorflow server.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | 23 | def main(unused_argv): 24 | # Contrib ops are lazily loaded. So we touch one contrib module to load them 25 | # immediately. 26 | to_import_contrib_ops = tf.contrib.resampler 27 | 28 | # Load you custom ops here before starting the standard TensorFlow server. 29 | 30 | # Start and join the standard TensorFlow server. 
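  # run_standard_tensorflow_server() builds the server from the TF_CONFIG
  # environment variable (set for each pod by the Kubernetes template);
  # join() then blocks for the lifetime of the pod.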
31 | tf.contrib.distribute.run_standard_tensorflow_server().join() 32 | 33 | 34 | if __name__ == "__main__": 35 | tf.app.run() 36 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:nightly 2 | 3 | COPY mnist.py / 4 | ENTRYPOINT ["python", "/mnist.py"] 5 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Docker Images 2 | 3 | This directory contains example Dockerfiles to run TensorFlow on cluster 4 | managers. 5 | 6 | - [Dockerfile](Dockerfile) is the most basic example, which just adds a Python 7 | training program on top of the tensorflow/tensorflow Docker image. 8 | - [mnist.py](mnist.py) demonstrates the programmatic setup for distributed 9 | TensorFlow training. 10 | 11 | ## Best Practices 12 | 13 | - Always pin the TensorFlow version with the Docker image tag. This ensures that 14 | TensorFlow updates don't adversely impact your training program for future 15 | runs. 16 | - When creating an image, specify version tags (see below). If you make code 17 | changes, increment the version. Cluster managers will not pull an updated 18 | Docker image if they have them cached. Also, versions ensure that you have 19 | a single copy of the code running for each job. 20 | 21 | ## Building the Docker Files 22 | 23 | First, pick an image name for the job. When running on a cluster manager, you 24 | will want to push your images to a container registry. Note that both the 25 | [Google Container Registry](https://cloud.google.com/container-registry/) 26 | and the [Amazon EC2 Container Registry](https://aws.amazon.com/ecr/) require 27 | special paths. We append `:v1` to version our images. Versioning images is 28 | strongly recommended for reasons described in the best practices section. 29 | 30 | ```sh 31 | docker build -t :v1 -f Dockerfile . 32 | # Use gcloud docker push instead if on Google Container Registry. 33 | docker push :v1 34 | ``` 35 | 36 | If you make any updates to the code, increment the version and rerun the above 37 | commands with the new version. 38 | 39 | ## Running the mnist Example 40 | 41 | The [mnist.py](mnist.py) example reads the mnist data in the TFRecords format. 42 | You can run the [convert_to_records.py](https://github.com/tensorflow/tensorflow/blob/r0.11/tensorflow/examples/how_tos/reading_data/convert_to_records.py) 43 | program to convert mnist data to TFRecords. 44 | 45 | When running distributed TensorFlow, you should upload the converted data to 46 | a common location on distributed storage, such as GCS. 47 | -------------------------------------------------------------------------------- /docker/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from __future__ import print_function 17 | 18 | import math 19 | import os 20 | 21 | import tensorflow as tf 22 | 23 | from tensorflow.examples.tutorials.mnist import mnist 24 | 25 | flags = tf.app.flags 26 | 27 | # Flags for configuring the task 28 | flags.DEFINE_string("job_name", None, "job name: worker or ps") 29 | flags.DEFINE_integer("task_index", 0, 30 | "Worker task index, should be >= 0. task_index=0 is " 31 | "the chief worker task the performs the variable " 32 | "initialization") 33 | flags.DEFINE_string("ps_hosts", "", 34 | "Comma-separated list of hostname:port pairs") 35 | flags.DEFINE_string("worker_hosts", "", 36 | "Comma-separated list of hostname:port pairs") 37 | 38 | # Training related flags 39 | flags.DEFINE_string("data_dir", None, 40 | "Directory where the mnist data is stored") 41 | flags.DEFINE_string("train_dir", None, 42 | "Directory for storing the checkpoints") 43 | flags.DEFINE_integer("hidden1", 128, 44 | "Number of units in the 1st hidden layer of the NN") 45 | flags.DEFINE_integer("hidden2", 128, 46 | "Number of units in the 2nd hidden layer of the NN") 47 | flags.DEFINE_integer("batch_size", 100, "Training batch size") 48 | flags.DEFINE_float("learning_rate", 0.01, "Learning rate") 49 | 50 | FLAGS = flags.FLAGS 51 | TRAIN_FILE = "train.tfrecords" 52 | 53 | 54 | def read_and_decode(filename_queue): 55 | reader = tf.TFRecordReader() 56 | _, serialized_example = reader.read(filename_queue) 57 | features = tf.parse_single_example( 58 | serialized_example, 59 | # Defaults are not specified since both keys are required. 60 | features={ 61 | 'image_raw': tf.FixedLenFeature([], tf.string), 62 | 'label': tf.FixedLenFeature([], tf.int64), 63 | }) 64 | 65 | # Convert from a scalar string tensor (whose single string has 66 | # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape 67 | # [mnist.IMAGE_PIXELS]. 68 | image = tf.decode_raw(features['image_raw'], tf.uint8) 69 | image.set_shape([mnist.IMAGE_PIXELS]) 70 | 71 | # OPTIONAL: Could reshape into a 28x28 image and apply distortions 72 | # here. Since we are not applying any distortions in this 73 | # example, and the next step expects the image to be flattened 74 | # into a vector, we don't bother. 75 | 76 | # Convert from [0, 255] -> [-0.5, 0.5] floats. 77 | image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 78 | 79 | # Convert label from a scalar uint8 tensor to an int32 scalar. 80 | label = tf.cast(features['label'], tf.int32) 81 | 82 | return image, label 83 | 84 | 85 | def inputs(batch_size): 86 | """Reads input data. 87 | 88 | Args: 89 | batch_size: Number of examples per returned batch. 90 | 91 | Returns: 92 | A tuple (images, labels), where: 93 | * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS] 94 | in the range [-0.5, 0.5]. 95 | * labels is an int32 tensor with shape [batch_size] with the true label, 96 | a number in the range [0, mnist.NUM_CLASSES). 97 | """ 98 | filename = os.path.join(FLAGS.data_dir, TRAIN_FILE) 99 | 100 | with tf.name_scope('input'): 101 | filename_queue = tf.train.string_input_producer([filename]) 102 | 103 | # Even when reading in multiple threads, share the filename 104 | # queue. 105 | image, label = read_and_decode(filename_queue) 106 | 107 | # Shuffle the examples and collect them into batch_size batches. 
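    # The capacity below follows the recommended sizing rule of
    # min_after_dequeue + (num_threads + a small safety margin) * batch_size.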
108 | # (Internally uses a RandomShuffleQueue.) 109 | # We run this in two threads to avoid being a bottleneck. 110 | images, sparse_labels = tf.train.shuffle_batch( 111 | [image, label], batch_size=batch_size, num_threads=2, 112 | capacity=1000 + 3 * batch_size, 113 | # Ensures a minimum amount of shuffling of examples. 114 | min_after_dequeue=1000) 115 | 116 | return images, sparse_labels 117 | 118 | 119 | def device_and_target(): 120 | # If FLAGS.job_name is not set, we're running single-machine TensorFlow. 121 | # Don't set a device. 122 | if FLAGS.job_name is None: 123 | print("Running single-machine training") 124 | return (None, "") 125 | 126 | # Otherwise we're running distributed TensorFlow. 127 | print("Running distributed training") 128 | if FLAGS.task_index is None or FLAGS.task_index == "": 129 | raise ValueError("Must specify an explicit `task_index`") 130 | if FLAGS.ps_hosts is None or FLAGS.ps_hosts == "": 131 | raise ValueError("Must specify an explicit `ps_hosts`") 132 | if FLAGS.worker_hosts is None or FLAGS.worker_hosts == "": 133 | raise ValueError("Must specify an explicit `worker_hosts`") 134 | 135 | cluster_spec = tf.train.ClusterSpec({ 136 | "ps": FLAGS.ps_hosts.split(","), 137 | "worker": FLAGS.worker_hosts.split(","), 138 | }) 139 | server = tf.train.Server( 140 | cluster_spec, job_name=FLAGS.job_name, task_index=FLAGS.task_index) 141 | if FLAGS.job_name == "ps": 142 | server.join() 143 | 144 | worker_device = "/job:worker/task:{}".format(FLAGS.task_index) 145 | # The device setter will automatically place Variables ops on separate 146 | # parameter servers (ps). The non-Variable ops will be placed on the workers. 147 | return ( 148 | tf.train.replica_device_setter( 149 | worker_device=worker_device, 150 | cluster=cluster_spec), 151 | server.target, 152 | ) 153 | 154 | 155 | def main(unused_argv): 156 | if FLAGS.data_dir is None or FLAGS.data_dir == "": 157 | raise ValueError("Must specify an explicit `data_dir`") 158 | if FLAGS.train_dir is None or FLAGS.train_dir == "": 159 | raise ValueError("Must specify an explicit `train_dir`") 160 | 161 | device, target = device_and_target() 162 | with tf.device(device): 163 | images, labels = inputs(FLAGS.batch_size) 164 | logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2) 165 | loss = mnist.loss(logits, labels) 166 | train_op = mnist.training(loss, FLAGS.learning_rate) 167 | 168 | with tf.train.MonitoredTrainingSession( 169 | master=target, 170 | is_chief=(FLAGS.task_index == 0), 171 | checkpoint_dir=FLAGS.train_dir) as sess: 172 | while not sess.should_stop(): 173 | sess.run(train_op) 174 | 175 | 176 | if __name__ == "__main__": 177 | tf.app.run() 178 | -------------------------------------------------------------------------------- /hadoop/README.md: -------------------------------------------------------------------------------- 1 | # Hadoop MapReduce InputFormat/OutputFormat for TFRecords 2 | 3 | This directory contains a [Apache Hadoop](http://hadoop.apache.org/) MapReduce 4 | InputFormat/OutputFormat implementation for TensorFlow's TFRecords format. 5 | This can also be used with [Apache Spark](http://spark.apache.org/). 6 | 7 | ## Prerequisites 8 | 9 | 1. [Apache Maven](https://maven.apache.org/) 10 | 11 | 2. Tested with Hadoop 2.6.0. Patches are welcome if there are incompatibilities 12 | with your Hadoop version. 
13 | 14 | ## Breaking changes 15 | 16 | * 08/20/2018 - Reverted artifactId back to `org.tensorflow.tensorflow-hadoop` 17 | * 05/29/2018 - Changed the artifactId from `org.tensorflow.tensorflow-hadoop` to `org.tensorflow.hadoop` 18 | 19 | ## Build and install 20 | 21 | 1. Compile the code 22 | 23 | ```sh 24 | mvn clean package 25 | ``` 26 | 27 | Alternatively, if you would like to build jars for a different version of TensorFlow, e.g., 1.5.0: 28 | 29 | ```sh 30 | mvn versions:set -DnewVersion=1.5.0 31 | mvn clean package 32 | ``` 33 | 34 | 2. Optionally install (or deploy) the jars 35 | 36 | ```sh 37 | mvn install 38 | ``` 39 | 40 | After installation (or deployment), the package can be used with the following dependency: 41 | 42 | ```xml 43 | 44 | org.tensorflow 45 | tensorflow-hadoop 46 | 1.10.0 47 | 48 | ``` 49 | 50 | 51 | ## Use with MapReduce 52 | The Hadoop MapReduce example can be found [here](src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java). 53 | 54 | ## Use with Apache Spark 55 | The [Spark-TensorFlow-Connector](../spark/spark-tensorflow-connector) uses TensorFlow Hadoop to load and save 56 | TensorFlow's TFRecords format using Spark DataFrames. -------------------------------------------------------------------------------- /hadoop/pom.xml: -------------------------------------------------------------------------------- 1 | 4 | 4.0.0 5 | org.tensorflow 6 | tensorflow-hadoop 7 | jar 8 | 1.10.0 9 | tensorflow-hadoop 10 | https://www.tensorflow.org 11 | TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop 12 | 13 | 14 | UTF-8 15 | 1.6 16 | 1.6 17 | 2.10.1 18 | 3.16.1 19 | 4.13.1 20 | 21 | 22 | 23 | 24 | Apache License Version 2.0 25 | http://www.apache.org/licenses/LICENSE-2.0.txt 26 | 27 | 28 | 29 | 30 | https://github.com/tensorflow/ecosystem.git 31 | git@github.com:tensorflow/ecosystem.git 32 | scm:git:https://github.com/tensorflow/ecosystem.git 33 | 34 | 35 | 36 | 37 | 38 | 39 | org.apache.maven.plugins 40 | maven-gpg-plugin 41 | 1.5 42 | 43 | 44 | sign-artifacts 45 | verify 46 | 47 | sign 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | org.apache.maven.plugins 57 | maven-source-plugin 58 | 2.2.1 59 | 60 | 61 | attach-sources 62 | 63 | jar-no-fork 64 | 65 | 66 | 67 | 68 | 69 | org.apache.maven.plugins 70 | maven-javadoc-plugin 71 | 2.9.1 72 | 73 | 74 | attach-javadocs 75 | 76 | jar 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | org.tensorflow 87 | proto 88 | ${project.version} 89 | 90 | 91 | org.apache.hadoop 92 | hadoop-common 93 | ${hadoop.version} 94 | 95 | 96 | com.google.protobuf 97 | protobuf-java 98 | 99 | 100 | 101 | 102 | org.apache.hadoop 103 | hadoop-mapreduce-client-core 104 | ${hadoop.version} 105 | 106 | 107 | com.google.protobuf 108 | protobuf-java 109 | 110 | 111 | 112 | 113 | com.google.protobuf 114 | protobuf-java 115 | ${protobuf.version} 116 | 117 | 118 | junit 119 | junit 120 | ${junit.version} 121 | test 122 | 123 | 124 | org.apache.hadoop 125 | hadoop-mapreduce-client-jobclient 126 | ${hadoop.version} 127 | test-jar 128 | true 129 | test 130 | 131 | 132 | com.google.protobuf 133 | protobuf-java 134 | 135 | 136 | 137 | 138 | 139 | 142 | 143 | 144 | ossrh 145 | 146 | 147 | 148 | ossrh 149 | https://oss.sonatype.org/content/repositories/snapshots 150 | 151 | 152 | ossrh 153 | https://oss.sonatype.org/service/local/staging/deploy/maven2/ 154 | 155 | 156 | 157 | 158 | 159 | org.apache.maven.plugins 160 | maven-gpg-plugin 161 | 162 | 163 | 164 | 165 | 166 | bintray 167 | 168 | 169 | 170 | bintray 171 | 
https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0 172 | 173 | 174 | 175 | 176 | 177 | org.apache.maven.plugins 178 | maven-gpg-plugin 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | TensorFlowers 188 | TensorFlow 189 | http://www.tensorflow.org 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /hadoop/src/main/java/org/tensorflow/hadoop/example/TFRecordFileMRExample.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.example; 17 | 18 | import com.google.protobuf.ByteString; 19 | import org.apache.hadoop.mapreduce.InputFormat; 20 | import org.apache.hadoop.mapreduce.OutputFormat; 21 | import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; 22 | import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat; 23 | import org.tensorflow.example.*; 24 | import org.tensorflow.hadoop.io.TFRecordFileInputFormat; 25 | import org.tensorflow.hadoop.io.TFRecordFileOutputFormat; 26 | import org.apache.hadoop.conf.Configuration; 27 | import org.apache.hadoop.fs.FileSystem; 28 | import org.apache.hadoop.fs.Path; 29 | import org.apache.hadoop.io.*; 30 | import org.apache.hadoop.mapreduce.Job; 31 | import org.apache.hadoop.mapreduce.Mapper; 32 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; 33 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; 34 | 35 | import java.io.IOException; 36 | import java.util.Map; 37 | 38 | public class TFRecordFileMRExample { 39 | /** 40 | * Convert from text file to TFRecord file. Each line is converted into two dummy features: the 41 | * content of each line and the starting offset of each line. 42 | */ 43 | static class ToTFRecordMapper extends Mapper { 44 | ToTFRecordMapper(){} 45 | 46 | @Override protected void map(LongWritable key, Text value, 47 | Context context) throws IOException, InterruptedException { 48 | Int64List int64List = Int64List.newBuilder().addValue(key.get()).build(); 49 | Feature offset = Feature.newBuilder().setInt64List(int64List).build(); 50 | 51 | ByteString byteString = ByteString.copyFrom(value.copyBytes()); 52 | BytesList bytesList = BytesList.newBuilder().addValue(byteString).build(); 53 | Feature text = Feature.newBuilder().setBytesList(bytesList).build(); 54 | 55 | Features features = Features.newBuilder() 56 | .putFeature("offset", offset) 57 | .putFeature("text", text) 58 | .build(); 59 | Example example = Example.newBuilder().setFeatures(features).build(); 60 | context.write(new BytesWritable(example.toByteArray()), NullWritable.get()); 61 | } 62 | } 63 | 64 | /** 65 | * Convert from previous TFRecord file to text file. 
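   * The mapper parses each Example and emits its "text" feature as a line of output.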
66 | */ 67 | static class FromTFRecordMapper extends Mapper { 68 | FromTFRecordMapper(){} 69 | 70 | @Override protected void map(BytesWritable key, NullWritable value, 71 | Context context) throws IOException, InterruptedException { 72 | Example example = Example.parseFrom(key.getBytes()); 73 | Map featureMap = example.getFeatures().getFeatureMap(); 74 | byte[] text = featureMap.get("text").getBytesList().getValue(0).toByteArray(); 75 | context.write(NullWritable.get(), new Text(text)); 76 | } 77 | } 78 | 79 | public static boolean convert(String jobName, 80 | Class mapperClass, 81 | Class outputKeyClass, 82 | Class outputValueClass, 83 | Class inFormatClass, 84 | Class outFormatClass, 85 | Path input, 86 | Path output) throws InterruptedException, IOException, ClassNotFoundException { 87 | Configuration conf = new Configuration(); 88 | Job job = Job.getInstance(conf, jobName); 89 | job.setJarByClass(mapperClass); 90 | job.setMapperClass(mapperClass); 91 | job.setNumReduceTasks(0); 92 | 93 | job.setInputFormatClass(inFormatClass); 94 | job.setOutputFormatClass(outFormatClass); 95 | job.setOutputKeyClass(outputKeyClass); 96 | job.setOutputValueClass(outputValueClass); 97 | 98 | final FileSystem fs = FileSystem.get(output.toUri(), conf); 99 | fs.delete(output, true); 100 | FileInputFormat.addInputPath(job, input); 101 | FileOutputFormat.setOutputPath(job, output); 102 | 103 | return job.waitForCompletion(true); 104 | } 105 | 106 | public static void main(String[] args) throws Exception { 107 | String testRoot = "/tmp/tfrecord-file-test"; 108 | if (args.length == 1) { 109 | testRoot = args[0]; 110 | } else if (args.length > 1) { 111 | System.out.println("Usage: TFRecordFileMRExample [path]"); 112 | } 113 | 114 | Path testRootPath = new Path(testRoot); 115 | Path input = new Path(testRootPath, "input.txt"); 116 | Path tfrout = new Path(testRootPath, "output.tfr"); 117 | Path txtout = new Path(testRootPath, "output.txt"); 118 | 119 | convert("ToTFR", ToTFRecordMapper.class, BytesWritable.class, NullWritable.class, 120 | TextInputFormat.class, TFRecordFileOutputFormat.class, input, tfrout); 121 | convert("FromTFR", FromTFRecordMapper.class, NullWritable.class, Text.class, 122 | TFRecordFileInputFormat.class, TextOutputFormat.class, tfrout, txtout); 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileInputFormat.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.io; 17 | 18 | import org.apache.hadoop.fs.Seekable; 19 | import org.apache.hadoop.io.compress.CompressionCodec; 20 | import org.apache.hadoop.io.compress.CompressionCodecFactory; 21 | import org.tensorflow.hadoop.util.TFRecordReader; 22 | import org.apache.hadoop.conf.Configuration; 23 | import org.apache.hadoop.fs.FSDataInputStream; 24 | import org.apache.hadoop.fs.FileSystem; 25 | import org.apache.hadoop.fs.Path; 26 | import org.apache.hadoop.io.BytesWritable; 27 | import org.apache.hadoop.io.NullWritable; 28 | import org.apache.hadoop.mapreduce.InputSplit; 29 | import org.apache.hadoop.mapreduce.JobContext; 30 | import org.apache.hadoop.mapreduce.RecordReader; 31 | import org.apache.hadoop.mapreduce.TaskAttemptContext; 32 | import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; 33 | import org.apache.hadoop.mapreduce.lib.input.FileSplit; 34 | 35 | import java.io.IOException; 36 | import java.io.InputStream; 37 | 38 | public class TFRecordFileInputFormat extends FileInputFormat { 39 | @Override public RecordReader createRecordReader( 40 | InputSplit inputSplit, final TaskAttemptContext context) throws IOException, InterruptedException { 41 | 42 | return new RecordReader() { 43 | private InputStream fsdis; 44 | private TFRecordReader reader; 45 | private long length; 46 | private long begin; 47 | private byte[] current; 48 | 49 | @Override public void initialize(InputSplit split, TaskAttemptContext context) 50 | throws IOException, InterruptedException { 51 | Configuration conf = context.getConfiguration(); 52 | FileSplit fileSplit = (FileSplit) split; 53 | length = fileSplit.getLength(); 54 | begin = fileSplit.getStart(); 55 | 56 | final Path file = fileSplit.getPath(); 57 | CompressionCodecFactory compressionCodecFactory = new CompressionCodecFactory(conf); 58 | CompressionCodec codec = compressionCodecFactory.getCodec(file); 59 | FileSystem fs = file.getFileSystem(conf); 60 | 61 | FSDataInputStream fsIn = fs.open(file, TFRecordIOConf.getBufferSize(conf)); 62 | if (codec != null) { 63 | fsdis = codec.createInputStream(fsIn); 64 | } else { 65 | fsdis = fsIn; 66 | } 67 | reader = new TFRecordReader(fsdis, TFRecordIOConf.getDoCrc32Check(conf)); 68 | } 69 | 70 | @Override public boolean nextKeyValue() throws IOException, InterruptedException { 71 | current = reader.read(); 72 | return current != null; 73 | } 74 | 75 | @Override public BytesWritable getCurrentKey() throws IOException, InterruptedException { 76 | return new BytesWritable(current); 77 | } 78 | 79 | @Override public NullWritable getCurrentValue() throws IOException, InterruptedException { 80 | return NullWritable.get(); 81 | } 82 | 83 | @Override public float getProgress() throws IOException, InterruptedException { 84 | return (((Seekable)fsdis).getPos() - begin) / (length + 1e-6f); 85 | } 86 | 87 | @Override public void close() throws IOException { 88 | fsdis.close(); 89 | } 90 | }; 91 | } 92 | 93 | @Override 94 | protected boolean isSplitable(JobContext context, Path file) { 95 | return false; 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormat.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.io; 17 | 18 | import org.apache.hadoop.io.compress.CompressionCodec; 19 | import org.apache.hadoop.io.compress.GzipCodec; 20 | import org.apache.hadoop.util.ReflectionUtils; 21 | import org.tensorflow.hadoop.util.TFRecordWriter; 22 | import org.apache.hadoop.conf.Configuration; 23 | import org.apache.hadoop.fs.FSDataOutputStream; 24 | import org.apache.hadoop.fs.FileSystem; 25 | import org.apache.hadoop.fs.Path; 26 | import org.apache.hadoop.io.BytesWritable; 27 | import org.apache.hadoop.io.NullWritable; 28 | import org.apache.hadoop.mapreduce.RecordWriter; 29 | import org.apache.hadoop.mapreduce.TaskAttemptContext; 30 | import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; 31 | 32 | import java.io.DataOutputStream; 33 | import java.io.IOException; 34 | 35 | public class TFRecordFileOutputFormat extends FileOutputFormat { 36 | @Override public RecordWriter getRecordWriter( 37 | TaskAttemptContext context) throws IOException, InterruptedException { 38 | Configuration conf = context.getConfiguration(); 39 | boolean isCompressed = getCompressOutput(context); 40 | CompressionCodec codec = null; 41 | String extension = ""; 42 | if (isCompressed) { 43 | Class codecClass = getOutputCompressorClass(context, GzipCodec.class); 44 | codec = ReflectionUtils.newInstance(codecClass, conf); 45 | extension = codec.getDefaultExtension(); 46 | } 47 | int bufferSize = TFRecordIOConf.getBufferSize(conf); 48 | Path file = getDefaultWorkFile(context, extension); 49 | FileSystem fs = file.getFileSystem(conf); 50 | FSDataOutputStream fsDataOutputStream = fs.create(file, true, bufferSize); 51 | final DataOutputStream fsdos = isCompressed ? 52 | new DataOutputStream(codec.createOutputStream(fsDataOutputStream)) : fsDataOutputStream; 53 | final TFRecordWriter writer = new TFRecordWriter(fsdos); 54 | return new RecordWriter() { 55 | @Override public void write(BytesWritable key, NullWritable value) 56 | throws IOException, InterruptedException { 57 | writer.write(key.getBytes(), 0, key.getLength()); 58 | } 59 | 60 | @Override public void close(TaskAttemptContext context) 61 | throws IOException, InterruptedException { 62 | fsdos.close(); 63 | } 64 | }; 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormatV1.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.io; 17 | 18 | import org.apache.hadoop.fs.FSDataOutputStream; 19 | import org.apache.hadoop.fs.FileSystem; 20 | import org.apache.hadoop.fs.Path; 21 | import org.apache.hadoop.io.BytesWritable; 22 | import org.apache.hadoop.io.NullWritable; 23 | import org.apache.hadoop.io.Writable; 24 | import org.apache.hadoop.mapred.FileOutputFormat; 25 | import org.apache.hadoop.mapred.JobConf; 26 | import org.apache.hadoop.mapred.RecordWriter; 27 | import org.apache.hadoop.mapred.Reporter; 28 | import org.apache.hadoop.util.Progressable; 29 | import org.tensorflow.hadoop.util.TFRecordWriter; 30 | 31 | import java.io.IOException; 32 | 33 | public class TFRecordFileOutputFormatV1 extends FileOutputFormat { 34 | @Override 35 | public RecordWriter getRecordWriter(FileSystem ignored, 36 | JobConf job, String name, 37 | Progressable progress) throws IOException { 38 | Path file = FileOutputFormat.getTaskOutputPath(job, name); 39 | FileSystem fs = file.getFileSystem(job); 40 | 41 | int bufferSize = TFRecordIOConf.getBufferSize(job); 42 | final FSDataOutputStream fsdos = fs.create(file, true, bufferSize); 43 | final TFRecordWriter writer = new TFRecordWriter(fsdos); 44 | return new RecordWriter() { 45 | @Override 46 | public void write(BytesWritable key, Writable value) 47 | throws IOException { 48 | writer.write(key.getBytes(), 0, key.getLength()); 49 | } 50 | 51 | @Override 52 | public void close(Reporter reporter) 53 | throws IOException { 54 | fsdos.close(); 55 | } 56 | }; 57 | } 58 | 59 | } 60 | -------------------------------------------------------------------------------- /hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordIOConf.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.io; 17 | 18 | import org.apache.hadoop.conf.Configuration; 19 | 20 | public class TFRecordIOConf { 21 | static int getBufferSize(Configuration conf) { 22 | return conf.getInt("io.file.buffer.size", 4096); 23 | } 24 | 25 | static boolean getDoCrc32Check(Configuration conf) { 26 | return conf.getBoolean("tensorflow.read.crc32check", true); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /hadoop/src/main/java/org/tensorflow/hadoop/util/Crc32C.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.util; 17 | 18 | import org.apache.hadoop.util.PureJavaCrc32C; 19 | 20 | import java.util.zip.Checksum; 21 | 22 | public class Crc32C implements Checksum { 23 | private static final int MASK_DELTA = 0xa282ead8; 24 | private PureJavaCrc32C crc32C; 25 | 26 | public static int maskedCrc32c(byte[] data) { 27 | return maskedCrc32c(data, 0, data.length); 28 | } 29 | 30 | public static int maskedCrc32c(byte[] data, int offset, int length) { 31 | Crc32C crc32c = new Crc32C(); 32 | crc32c.update(data, offset, length); 33 | return crc32c.getMaskedValue(); 34 | } 35 | 36 | /** 37 | * Return a masked representation of crc. 38 | *

39 | * Motivation: it is problematic to compute the CRC of a string that 40 | * contains embedded CRCs. Therefore we recommend that CRCs stored 41 | * somewhere (e.g., in files) should be masked before being stored. 42 | *

43 | * @param crc CRC 44 | * @return masked CRC 45 | */ 46 | public static int mask(int crc) { 47 | // Rotate right by 15 bits and add a constant. 48 | return ((crc >>> 15) | (crc << 17)) + MASK_DELTA; 49 | } 50 | 51 | /** 52 | * Return the crc whose masked representation is masked_crc. 53 | * @param maskedCrc masked CRC 54 | * @return crc whose masked representation is masked_crc 55 | */ 56 | public static int unmask(int maskedCrc) { 57 | int rot = maskedCrc - MASK_DELTA; 58 | return ((rot >>> 17) | (rot << 15)); 59 | } 60 | 61 | public Crc32C() { 62 | crc32C = new PureJavaCrc32C(); 63 | } 64 | 65 | public int getMaskedValue() { 66 | return mask(getIntValue()); 67 | } 68 | 69 | public int getIntValue() { 70 | return (int) getValue(); 71 | } 72 | 73 | @Override public void update(int b) { 74 | crc32C.update(b); 75 | } 76 | 77 | @Override public void update(byte[] b, int off, int len) { 78 | crc32C.update(b, off, len); 79 | } 80 | 81 | @Override public long getValue() { 82 | return crc32C.getValue(); 83 | } 84 | 85 | @Override public void reset() { 86 | crc32C.reset(); 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordReader.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.util; 17 | 18 | import java.io.DataInput; 19 | import java.io.EOFException; 20 | import java.io.IOException; 21 | import java.io.InputStream; 22 | import java.nio.ByteBuffer; 23 | import java.nio.ByteOrder; 24 | 25 | public class TFRecordReader { 26 | private final InputStream input; 27 | private final boolean crcCheck; 28 | 29 | public TFRecordReader(InputStream input, boolean crcCheck) { 30 | this.input = input; 31 | this.crcCheck = crcCheck; 32 | } 33 | 34 | public byte[] read() throws IOException { 35 | /** 36 | * TFRecord format: 37 | * uint64 length 38 | * uint32 masked_crc32_of_length 39 | * byte data[length] 40 | * uint32 masked_crc32_of_data 41 | */ 42 | byte[] lenBytes = new byte[8]; 43 | try { 44 | // Only catch EOF here, other case means corrupted file 45 | readFully(input, lenBytes); 46 | } catch (EOFException eof) { 47 | return null; // return null means EOF 48 | } 49 | Long len = fromInt64LE(lenBytes); 50 | 51 | // Verify length crc32 52 | if (!crcCheck) { 53 | input.skip(4); 54 | } else { 55 | byte[] lenCrc32Bytes = new byte[4]; 56 | readFully(input, lenCrc32Bytes); 57 | int lenCrc32 = fromInt32LE(lenCrc32Bytes); 58 | if (lenCrc32 != Crc32C.maskedCrc32c(lenBytes)) { 59 | throw new IOException("Length header crc32 checking failed: " + lenCrc32 + " != " + 60 | Crc32C.maskedCrc32c(lenBytes) + ", length = " + len); 61 | } 62 | } 63 | 64 | if (len > Integer.MAX_VALUE) { 65 | throw new IOException("Record size exceeds max value of int32: " + len); 66 | } 67 | byte[] data = new byte[len.intValue()]; 68 | readFully(input, data); 69 | 70 | // Verify data crc32 71 | if (!crcCheck) { 72 | input.skip(4); 73 | } else { 74 | byte[] dataCrc32Bytes = new byte[4]; 75 | readFully(input, dataCrc32Bytes); 76 | int dataCrc32 = fromInt32LE(dataCrc32Bytes); 77 | if (dataCrc32 != Crc32C.maskedCrc32c(data)) { 78 | throw new IOException("Data crc32 checking failed: " + dataCrc32 + " != " + 79 | Crc32C.maskedCrc32c(data)); 80 | } 81 | } 82 | return data; 83 | } 84 | 85 | private long fromInt64LE(byte[] data) { 86 | assert data.length == 8; 87 | ByteBuffer bb = ByteBuffer.wrap(data); 88 | bb.order(ByteOrder.LITTLE_ENDIAN); 89 | return bb.getLong(); 90 | } 91 | 92 | private int fromInt32LE(byte[] data) { 93 | assert data.length == 4; 94 | ByteBuffer bb = ByteBuffer.wrap(data); 95 | bb.order(ByteOrder.LITTLE_ENDIAN); 96 | return bb.getInt(); 97 | } 98 | 99 | private void readFully(InputStream in, byte[] buffer) throws IOException { 100 | int nbytes; 101 | for(int nread = 0; nread < buffer.length; nread += nbytes) { 102 | nbytes = in.read(buffer, nread, buffer.length - nread); 103 | if (nbytes < 0) { 104 | throw new EOFException("End of file reached before reading fully."); 105 | } 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /hadoop/src/main/java/org/tensorflow/hadoop/util/TFRecordWriter.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.util; 17 | 18 | import java.io.*; 19 | import java.nio.ByteBuffer; 20 | import java.nio.ByteOrder; 21 | 22 | public class TFRecordWriter { 23 | private final DataOutput output; 24 | 25 | public TFRecordWriter(DataOutput output) { 26 | this.output = output; 27 | } 28 | 29 | public void write(byte[] record, int offset, int length) throws IOException { 30 | /** 31 | * TFRecord format: 32 | * uint64 length 33 | * uint32 masked_crc32_of_length 34 | * byte data[length] 35 | * uint32 masked_crc32_of_data 36 | */ 37 | byte[] len = toInt64LE(length); 38 | output.write(len); 39 | output.write(toInt32LE(Crc32C.maskedCrc32c(len))); 40 | output.write(record, offset, length); 41 | output.write(toInt32LE(Crc32C.maskedCrc32c(record, offset, length))); 42 | } 43 | 44 | public void write(byte[] record) throws IOException { 45 | write(record, 0, record.length); 46 | } 47 | 48 | private byte[] toInt64LE(long data) { 49 | byte[] buff = new byte[8]; 50 | ByteBuffer bb = ByteBuffer.wrap(buff); 51 | bb.order(ByteOrder.LITTLE_ENDIAN); 52 | bb.putLong(data); 53 | return buff; 54 | } 55 | 56 | private byte[] toInt32LE(int data) { 57 | byte[] buff = new byte[4]; 58 | ByteBuffer bb = ByteBuffer.wrap(buff); 59 | bb.order(ByteOrder.LITTLE_ENDIAN); 60 | bb.putInt(data); 61 | return buff; 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /hadoop/src/test/java/org/tensorflow/hadoop/io/TFRecordFileTest.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.io; 17 | 18 | import org.apache.commons.io.FileUtils; 19 | import org.apache.hadoop.conf.Configuration; 20 | import org.apache.hadoop.fs.FileSystem; 21 | import org.apache.hadoop.fs.Path; 22 | import org.apache.hadoop.io.BytesWritable; 23 | import org.apache.hadoop.io.NullWritable; 24 | import org.apache.hadoop.io.compress.GzipCodec; 25 | import org.apache.hadoop.mapreduce.*; 26 | import org.apache.hadoop.mapreduce.task.MapContextImpl; 27 | import org.junit.Test; 28 | import org.tensorflow.example.Example; 29 | import org.tensorflow.example.Feature; 30 | import org.tensorflow.example.Features; 31 | import org.tensorflow.example.Int64List; 32 | 33 | import java.io.File; 34 | import java.util.Map; 35 | import java.util.Random; 36 | import java.util.TreeMap; 37 | 38 | import static org.junit.Assert.assertEquals; 39 | 40 | public class TFRecordFileTest { 41 | private static final int RECORDS = 10000; 42 | 43 | @Test 44 | public void testZippedInputOutputFormat() throws Exception { 45 | testInputOutputFormat(true); 46 | } 47 | 48 | @Test 49 | public void testInputOutputFormat() throws Exception { 50 | testInputOutputFormat(false); 51 | } 52 | 53 | private void testInputOutputFormat(boolean zipped) throws Exception { 54 | Job job = Job.getInstance(new Configuration()); 55 | TaskAttemptContext context = 56 | MapReduceTestUtil.createDummyMapTaskAttemptContext(job.getConfiguration()); 57 | 58 | Random rand = new Random(1234); 59 | Map expectedRecords = new TreeMap(); 60 | for (int i = 0; i < RECORDS; ++i) { 61 | long randValue = rand.nextLong(); 62 | expectedRecords.put((long) i, randValue); 63 | } 64 | 65 | Path dir = new Path(getTmpDirectory().toString(), "tfr-test-zipped-" + zipped); 66 | writeTFRecords(job, 67 | context, 68 | dir, 69 | expectedRecords, 70 | zipped); 71 | 72 | Map records = readTFRecords(job, 73 | context, 74 | dir); 75 | 76 | assertEquals(expectedRecords, records); 77 | deleteDirectory(job, dir); 78 | } 79 | 80 | @Test 81 | public void testWriteSmallTfRecords() throws Exception { 82 | Job job = Job.getInstance(new Configuration()); 83 | TaskAttemptContext context = 84 | MapReduceTestUtil.createDummyMapTaskAttemptContext(job.getConfiguration()); 85 | 86 | Path dir = new Path(getTmpDirectory().toString(), "tfr-test-small"); 87 | writeTFRecords(job, context, dir, getExpectedRecords(), false); 88 | 89 | String fileName = getFileName(job, new Path(getResourcesDirectory().toString(), "tf-records")); 90 | 91 | assertEquals(FileUtils.readFileToString(new File(new File(getResourcesDirectory(), "tf-records"), fileName)), 92 | FileUtils.readFileToString(new File(dir.toString(), fileName))); 93 | 94 | deleteDirectory(job, dir); 95 | } 96 | 97 | @Test 98 | public void testReadSmallTfRecords() throws Exception { 99 | Job job = Job.getInstance(new Configuration()); 100 | TaskAttemptContext context = 101 | MapReduceTestUtil.createDummyMapTaskAttemptContext(job.getConfiguration()); 102 | 103 | Path dir = new Path(getResourcesDirectory().toString(), "tf-records"); 104 | Map records = readTFRecords(job, context, dir); 105 | 106 | assertEquals(getExpectedRecords(), records); 107 | } 108 | 109 | @Test 110 | public void testReadSmallZippedTfRecords() throws Exception { 111 | Job job = Job.getInstance(new Configuration()); 112 | TaskAttemptContext context = 113 | MapReduceTestUtil.createDummyMapTaskAttemptContext(job.getConfiguration()); 114 | 115 | Path dir = new 
Path(getResourcesDirectory().toString(), "zipped-tf-records"); 116 | Map records = readTFRecords(job, context, dir); 117 | 118 | assertEquals(getExpectedRecords(), records); 119 | } 120 | 121 | private void writeTFRecords(Job job, 122 | TaskAttemptContext context, 123 | Path dir, 124 | Map records, 125 | boolean zipped) throws Exception { 126 | TFRecordFileOutputFormat.setOutputPath(job, dir); 127 | 128 | if (zipped) { 129 | TFRecordFileOutputFormat.setCompressOutput(job, true); 130 | TFRecordFileOutputFormat.setOutputCompressorClass(job, GzipCodec.class); 131 | } 132 | 133 | OutputFormat outputFormat = 134 | new TFRecordFileOutputFormat(); 135 | OutputCommitter committer = outputFormat.getOutputCommitter(context); 136 | committer.setupJob(job); 137 | RecordWriter writer = outputFormat. 138 | getRecordWriter(context); 139 | 140 | try { 141 | for (Map.Entry entry : records.entrySet()) { 142 | Int64List data = Int64List.newBuilder().addValue(entry.getKey()).addValue(entry.getValue()).build(); 143 | Feature feature = Feature.newBuilder().setInt64List(data).build(); 144 | Features features = Features.newBuilder().putFeature("data", feature).build(); 145 | Example example = Example.newBuilder().setFeatures(features).build(); 146 | BytesWritable key = new BytesWritable(example.toByteArray()); 147 | writer.write(key, NullWritable.get()); 148 | } 149 | } finally { 150 | writer.close(context); 151 | } 152 | committer.commitTask(context); 153 | committer.commitJob(job); 154 | } 155 | 156 | private Map readTFRecords(Job job, TaskAttemptContext context, Path dir) throws Exception { 157 | Map records = new TreeMap(); 158 | TFRecordFileInputFormat.setInputPaths(job, dir); 159 | InputFormat inputFormat = new TFRecordFileInputFormat(); 160 | for (InputSplit split : inputFormat.getSplits(job)) { 161 | RecordReader reader = 162 | inputFormat.createRecordReader(split, context); 163 | MapContext mcontext = 164 | new MapContextImpl 165 | (job.getConfiguration(), context.getTaskAttemptID(), reader, null, null, 166 | MapReduceTestUtil.createDummyReporter(), 167 | split); 168 | reader.initialize(split, mcontext); 169 | try { 170 | while (reader.nextKeyValue()) { 171 | BytesWritable bytes = reader.getCurrentKey(); 172 | Example example = Example.parseFrom(bytes.getBytes()); 173 | Int64List data = example.getFeatures().getFeatureMap().get("data").getInt64List(); 174 | Long key = data.getValue(0); 175 | Long value = data.getValue(1); 176 | records.put(key, value); 177 | } 178 | } finally { 179 | reader.close(); 180 | } 181 | } 182 | return records; 183 | } 184 | 185 | private Map getExpectedRecords() { 186 | Map records = new TreeMap(); 187 | records.put(0L, -6519408338692630574L); 188 | records.put(1L, -897291810407650440L); 189 | records.put(2L, -2627029093267243214L); 190 | records.put(3L, 8452912497529882771L); 191 | records.put(4L, 6197228047171027195L); 192 | return records; 193 | } 194 | 195 | private File getResourcesDirectory() { 196 | return new File("src/test/resources"); 197 | } 198 | 199 | private File getTmpDirectory() { 200 | return new File(System.getProperty("test.build.data", "/tmp")); 201 | } 202 | 203 | private void deleteDirectory(Job job, Path dir) throws Exception { 204 | FileSystem fs = dir.getFileSystem(job.getConfiguration()); 205 | fs.delete(dir, true); 206 | } 207 | 208 | private String getFileName(Job job, Path dir) throws Exception { 209 | FileSystem fs = dir.getFileSystem(job.getConfiguration()); 210 | return fs.listFiles(dir, false).next().getPath().getName(); 211 | } 212 | } 213 | 
-------------------------------------------------------------------------------- /hadoop/src/test/java/org/tensorflow/hadoop/util/TFRecordTest.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.hadoop.util; 17 | 18 | import java.io.*; 19 | 20 | import org.junit.Test; 21 | 22 | import static org.junit.Assert.assertEquals; 23 | import static org.junit.Assert.assertNull; 24 | 25 | public class TFRecordTest { 26 | @Test 27 | public void testTFRecord() throws IOException { 28 | int count = 1000; 29 | ByteArrayOutputStream baos = new ByteArrayOutputStream(); 30 | TFRecordWriter writer = new TFRecordWriter(new DataOutputStream(baos)); 31 | for (int i = 0; i < count; ++i) { 32 | writer.write((Integer.toString(i)).getBytes()); 33 | } 34 | baos.close(); 35 | 36 | ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); 37 | TFRecordReader reader = new TFRecordReader(new DataInputStream(bais), true); 38 | for (int i = 0; i < count; ++i) { 39 | assertEquals(Integer.toString(i), new String(reader.read())); 40 | } 41 | assertNull(reader.read()); // EOF 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /hadoop/src/test/resources/tf-records/part-m-00000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/ecosystem/c39378894bf080843a86edacd0176c53f70d939f/hadoop/src/test/resources/tf-records/part-m-00000 -------------------------------------------------------------------------------- /hadoop/src/test/resources/zipped-tf-records/part-m-00000.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/ecosystem/c39378894bf080843a86edacd0176c53f70d939f/hadoop/src/test/resources/zipped-tf-records/part-m-00000.gz -------------------------------------------------------------------------------- /images/between-graph_replication.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/ecosystem/c39378894bf080843a86edacd0176c53f70d939f/images/between-graph_replication.png -------------------------------------------------------------------------------- /images/chief_worker_stdout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorflow/ecosystem/c39378894bf080843a86edacd0176c53f70d939f/images/chief_worker_stdout.png -------------------------------------------------------------------------------- /kubernetes/README.md: -------------------------------------------------------------------------------- 1 | # Running Distributed TensorFlow on Kubernetes 2 | 3 | This directory contains a 
template for running distributed TensorFlow on 4 | Kubernetes. 5 | 6 | ## Prerequisites 7 | 8 | 1. You must be running Kubernetes 1.3 or above. If you are running an earlier 9 | version, the DNS addon must be enabled. See the 10 | [Google Container Engine](https://cloud.google.com/container-engine/) if you 11 | want to quickly set up a Kubernetes cluster from scratch. 12 | 13 | 2. [Jinja templates](http://jinja.pocoo.org/) must be installed. 14 | 15 | ## Steps to Run the Job 16 | 17 | 1. Follow the instructions for creating the training program in the parent 18 | [README](../README.md). 19 | 20 | 2. Follow the instructions for building and pushing the Docker image in the 21 | [Docker README](../docker/README.md). 22 | 23 | 3. Copy the template file: 24 | 25 | ```sh 26 | cp kubernetes/template.yaml.jinja myjob.template.jinja 27 | ``` 28 | 29 | 4. Edit the `myjob.template.jinja` file to set the job parameters. At a minimum, 30 | you'll want to specify `name`, `image`, `worker_replicas`, `ps_replicas`, 31 | `script`, `data_dir`, and `train_dir`. You may optionally specify 32 | `credential_secret_name` and `credential_secret_key` if you need to read from and 33 | write to Google Cloud Storage. See the Google Cloud Storage section below. 34 | 35 | 5. Run the job: 36 | 37 | ```sh 38 | python render_template.py myjob.template.jinja | kubectl create -f - 39 | ``` 40 | 41 | If you later want to stop the job, run: 42 | ```sh 43 | python render_template.py myjob.template.jinja | kubectl delete -f - 44 | ``` 45 | 46 | ## Google Cloud Storage 47 | 48 | To support reading from and writing to Google Cloud Storage, you need to set up 49 | a [Kubernetes secret](http://kubernetes.io/docs/user-guide/secrets/) with the 50 | credentials. 51 | 52 | 1. [Set up a service 53 | account](https://cloud.google.com/vision/docs/common/auth#set_up_a_service_account) 54 | and download the JSON file. 55 | 56 | 2. Add the JSON file as a Kubernetes secret. Replace `[json_filename]` with 57 | the name of the downloaded file: 58 | 59 | ```sh 60 | kubectl create secret generic credential --from-file=[json_filename] 61 | ``` 62 | 63 | 3. In your template, set `credential_secret_name` to `"credential"` (as 64 | specified above) and `credential_secret_key` to `"[json_filename]"` in 65 | the template.
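For illustration, a minimal header for `myjob.template.jinja` might look like the following sketch. All values here are placeholders (image name, script path, and storage paths are assumptions); adjust them for your cluster, and leave the credential settings empty if you are not using Google Cloud Storage:

```
{%- set name = "myjob" -%}
{%- set image = "<your-image>" -%}
{%- set worker_replicas = 2 -%}
{%- set ps_replicas = 1 -%}
{%- set script = "/opt/mnist.py" -%}
{%- set data_dir = "gs://my-bucket/data" -%}
{%- set train_dir = "gs://my-bucket/train" -%}
{%- set credential_secret_name = "credential" -%}
{%- set credential_secret_key = "key.json" -%}
```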
66 | -------------------------------------------------------------------------------- /kubernetes/template.yaml.jinja: -------------------------------------------------------------------------------- 1 | {%- set name = "" -%} 2 | {%- set image = "" -%} 3 | {%- set worker_replicas = 2 -%} 4 | {%- set ps_replicas = 1 -%} 5 | {%- set script = "" -%} 6 | {%- set data_dir = "" -%} 7 | {%- set train_dir = "" -%} 8 | {%- set credential_secret_name = "" -%} 9 | {%- set credential_secret_key = "" -%} 10 | 11 | {%- set port = 5000 -%} 12 | {%- set replicas = {"worker": worker_replicas, "ps": ps_replicas} -%} 13 | 14 | {%- macro worker_hosts() -%} 15 | {%- for i in range(worker_replicas) -%} 16 | {%- if not loop.first -%},{%- endif -%} 17 | {{ name }}-worker-{{ i }}:{{ port }} 18 | {%- endfor -%} 19 | {%- endmacro -%} 20 | 21 | {%- macro ps_hosts() -%} 22 | {%- for i in range(ps_replicas) -%} 23 | {%- if not loop.first -%},{%- endif -%} 24 | {{ name }}-ps-{{ i }}:{{ port }} 25 | {%- endfor -%} 26 | {%- endmacro -%} 27 | 28 | {%- for job in ["worker", "ps"] -%} 29 | {%- for i in range(replicas[job]) -%} 30 | kind: Service 31 | apiVersion: v1 32 | metadata: 33 | name: {{ name }}-{{ job }}-{{ i }} 34 | spec: 35 | selector: 36 | name: {{ name }} 37 | job: {{ job }} 38 | task: "{{ i }}" 39 | ports: 40 | - port: {{ port }} 41 | --- 42 | kind: ReplicaSet 43 | apiVersion: extensions/v1beta1 44 | metadata: 45 | name: {{ name }}-{{ job }}-{{ i }} 46 | spec: 47 | replicas: 1 48 | template: 49 | metadata: 50 | labels: 51 | name: {{ name }} 52 | job: {{ job }} 53 | task: "{{ i }}" 54 | spec: 55 | containers: 56 | - name: tensorflow 57 | image: {{ image }} 58 | {% if credential_secret_name != "" %} 59 | env: 60 | - name: GOOGLE_APPLICATION_CREDENTIALS 61 | value: "/etc/credential/{{ credential_secret_key }}" 62 | {% endif %} 63 | ports: 64 | - containerPort: {{ port }} 65 | {% if job == "tensorboard" %} 66 | command: 67 | - "tensorboard" 68 | args: 69 | - "--logdir={{ train_dir }}" 70 | {% else %} 71 | command: 72 | - "/usr/bin/python" 73 | - "{{ script }}" 74 | args: 75 | - "--data_dir={{ data_dir }}" 76 | - "--train_dir={{ train_dir }}" 77 | - "--task_index={{ i }}" 78 | - "--job_name={{ job }}" 79 | - "--worker_hosts={{ worker_hosts() }}" 80 | - "--ps_hosts={{ ps_hosts() }}" 81 | {% endif %} 82 | {% if credential_secret_name != "" %} 83 | volumeMounts: 84 | - name: credential 85 | mountPath: /etc/credential 86 | readOnly: true 87 | volumes: 88 | - name: credential 89 | secret: 90 | secretName: {{ credential_secret_name }} 91 | {% endif %} 92 | --- 93 | {% endfor %} 94 | {%- endfor -%} 95 | -------------------------------------------------------------------------------- /marathon/README.md: -------------------------------------------------------------------------------- 1 | # Running Distributed TensorFlow on Mesos/Marathon 2 | 3 | ## Prerequisite 4 | Before you start, you need to set up a Mesos cluster with Marathon installed and Mesos-DNS enabled. It is also preferable to set up some shared storage such as HDFS in the cluster. All of these could be easily installed and configured with the help of [DC/OS](https://dcos.io/docs/1.9/administration/installing/custom/gui/). You need to remember the master target, DNS domain and HDFS namenode which are needed to bring up the TensorFlow cluster. 5 | 6 | ## Write the Training Program 7 | This section covers instructions on how to write your training program and build your docker image. 8 | 9 | 1. Write your own training program. 
This program must accept `worker_hosts`, `ps_hosts`, `job_name`, and `task_index` as command-line flags, which are then parsed to build the `ClusterSpec`. After that, the task either joins the server or starts building graphs. Please refer to the [main page](../README.md) for code snippets and a description of between-graph replication. An example can be found in `docker/mnist.py`. 10 | 11 | If the training program needs a large amount of training input, we recommend copying your data to shared storage first and then pointing each worker to the data. You may want to add a flag called `data_dir`. Please refer to the [adding flags](#add-commandline-flags) section for adding this flag to the Marathon config. 12 | 13 | 2. Write your own Dockerfile, which simply copies your training program into the image and optionally specifies an entrypoint. An example is located in `docker/Dockerfile`, or `docker/Dockerfile.hdfs` if you need HDFS support. TensorBoard can also use the same image, but with a different entry point. 14 | 15 | 3. Build your Docker image and push it to a Docker repository: 16 | 17 | ```sh 18 | cd docker 19 | docker build -t <image_name> -f Dockerfile.hdfs . 20 | # Use gcloud docker push instead if on Google Container Registry. 21 | docker push <image_name> 22 | ``` 23 | Please refer to the [docker images](../docker/README.md) page for best practices on building Docker images. 24 | 25 | 26 | ## Generate Marathon Config 27 | The Marathon config is generated from a Jinja template, in which you customize your own cluster configuration in the file header. 28 | 29 | 1. Copy over the template file: 30 | 31 | ```sh 32 | cp marathon/template.json.jinja mycluster.json.jinja 33 | ``` 34 | 35 | 2. Edit the `mycluster.json.jinja` file. You need to specify `name`, `image`, and `train_dir`, and optionally change the number of worker and ps replicas. The `train_dir` must point to a directory on shared storage if you would like to use TensorBoard or sharded checkpoints. 36 | 37 | 3. Generate the Marathon json config: 38 | 39 | ```sh 40 | python render_template.py mycluster.json.jinja > mycluster.json 41 | ``` 42 | 43 | ## Start the TensorFlow Cluster 44 | To start the cluster, simply post the Marathon json config file to the Marathon master target, which is `marathon.mesos:8080` by default: 45 | 46 | ```sh 47 | curl -i -H 'Content-Type: application/json' -d @mycluster.json http://marathon.mesos:8080/v2/groups 48 | ``` 49 | 50 | You may want to make sure your cluster is running the training program correctly. Navigate to the DC/OS web console and look for the stdout or stderr of the chief worker. The `mnist.py` example prints the loss for each step and the final loss when training is done. 51 | 52 | ![Screenshot of the chief worker](../images/chief_worker_stdout.png "Screenshot of the chief worker") 53 | 54 | If TensorBoard is enabled, navigate to `tensorboard.marathon.mesos:6006` with your browser or find out its IP address from the DC/OS web console. 55 | 56 | 57 | ## Add Commandline Flags 58 | 59 | Let's suppose you would like to add a flag called `data_dir` to the rendered config. Before rendering the template, make the following changes: 60 | 61 | 1. Add a variable in the header of `mycluster.json.jinja`: 62 | ``` 63 | {%- set data_dir = "hdfs://namenode/data_dir" %} 64 | ``` 65 | 66 | 2. Add the flag to the `args` section of the template: 67 | ``` 68 | # replace "args": ["--worker_hosts", ...] with 69 | "args": ["--data_dir", "{{ data_dir }}", "--worker_hosts", ...]
70 | ``` 71 | -------------------------------------------------------------------------------- /marathon/template.json.jinja: -------------------------------------------------------------------------------- 1 | {%- set name = "mnist" %} 2 | {%- set image = "" %} 3 | {%- set worker_replicas = 3 %} 4 | {%- set ps_replicas = 2 -%} 5 | {%- set dns_domain = "marathon.mesos" %} 6 | {%- set port = 2333 %} 7 | {%- set train_dir = "hdfs://namenode/train_dir" %} 8 | {%- set tensorboard = true %} 9 | {%- set cpu = 2 %} 10 | {%- set mem = 4096 %} 11 | 12 | {%- set replicas = {"worker": worker_replicas, "ps": ps_replicas} -%} 13 | 14 | {%- macro worker_hosts() -%} 15 | {%- for i in range(worker_replicas) -%} 16 | {%- if not loop.first -%},{%- endif -%} 17 | worker-{{ i }}-{{ name }}.{{ dns_domain }}:{{ port + i }} 18 | {%- endfor -%} 19 | {%- endmacro -%} 20 | 21 | {%- macro ps_hosts() -%} 22 | {%- for i in range(ps_replicas) -%} 23 | {%- if not loop.first -%},{%- endif -%} 24 | ps-{{ i }}-{{ name }}.{{ dns_domain }}:{{ port + worker_replicas + i }} 25 | {%- endfor -%} 26 | {%- endmacro -%} 27 | 28 | { 29 | "id": "{{ name }}", 30 | "apps": [ 31 | {%- for job in ["worker", "ps"] -%} 32 | {%- for i in range(replicas[job]) %} 33 | { 34 | "id": "{{ job }}-{{ i }}", 35 | "container": { 36 | "docker": { 37 | "image": "{{ image }}" 38 | }, 39 | "network": "HOST", 40 | "type": "MESOS" 41 | }, 42 | "args": ["--worker_hosts", "{{ worker_hosts() }}", "--ps_hosts", "{{ ps_hosts() }}", "--job_name", "{{ job }}", "--task_index", "{{ i }}", "--train_dir", "{{ train_dir }}", "--sync_replica", "True", "--train_steps", "2000"], 43 | "cpus": {{ cpu }}, 44 | "mem": {{ mem }}, 45 | "instances": 1 46 | }{%- if not loop.last or tensorboard %},{% endif %} 47 | {%- endfor %} 48 | {%- endfor %} 49 | { 50 | "id": "tensorboard", 51 | "container": { 52 | "docker": { 53 | "image": "{{ image }}" 54 | }, 55 | "network": "HOST", 56 | "type": "MESOS" 57 | }, 58 | "cmd": "tensorboard --logdir {{ train_dir }}", 59 | "cpus": {{ cpu }}, 60 | "mem": {{ mem }}, 61 | "instances": 1 62 | } 63 | ] 64 | } 65 | -------------------------------------------------------------------------------- /render_template.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import jinja2 6 | import sys 7 | 8 | if len(sys.argv) != 2: 9 | print("usage: {} [template-file]".format(sys.argv[0]), file=sys.stderr) 10 | sys.exit(1) 11 | with open(sys.argv[1], "r") as f: 12 | print(jinja2.Template(f.read()).render()) 13 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=0.13.13 -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/" 2 | 3 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3") 4 | 5 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.5") 6 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: 
-------------------------------------------------------------------------------- 1 | org.tensorflow.spark.datasources.tfrecords.DefaultSource -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords 17 | 18 | import java.io._ 19 | import java.nio.file.Paths 20 | 21 | import org.apache.hadoop.fs.Path 22 | import org.apache.hadoop.io.{BytesWritable, NullWritable} 23 | import org.apache.spark.rdd.RDD 24 | import org.apache.hadoop.conf.Configuration 25 | import org.apache.spark.sql._ 26 | import org.apache.spark.sql.sources._ 27 | import org.apache.spark.sql.types.StructType 28 | import org.tensorflow.hadoop.io.TFRecordFileOutputFormat 29 | import org.tensorflow.hadoop.util._ 30 | import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder 31 | 32 | /** 33 | * Provides access to TensorFlow record source 34 | */ 35 | class DefaultSource extends DataSourceRegister 36 | with CreatableRelationProvider 37 | with RelationProvider 38 | with SchemaRelationProvider{ 39 | 40 | /** 41 | * Short alias for spark-tensorflow data source. 
42 | */ 43 | override def shortName(): String = "tfrecords" 44 | 45 | // Writes DataFrame as TensorFlow Records 46 | override def createRelation( 47 | sqlContext: SQLContext, 48 | mode: SaveMode, 49 | parameters: Map[String, String], 50 | data: DataFrame): BaseRelation = { 51 | 52 | val path = parameters("path") 53 | val codec = parameters.getOrElse("codec", "") 54 | 55 | val recordType = parameters.getOrElse("recordType", "Example") 56 | 57 | //Export DataFrame as TFRecords 58 | val features = data.rdd.map(row => { 59 | recordType match { 60 | case "Example" => 61 | val example = DefaultTfRecordRowEncoder.encodeExample(row) 62 | (new BytesWritable(example.toByteArray), NullWritable.get()) 63 | case "SequenceExample" => 64 | val sequenceExample = DefaultTfRecordRowEncoder.encodeSequenceExample(row) 65 | (new BytesWritable(sequenceExample.toByteArray), NullWritable.get()) 66 | case _ => 67 | throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be Example or SequenceExample") 68 | } 69 | }) 70 | 71 | parameters.getOrElse("writeLocality", "distributed") match { 72 | case "distributed" => 73 | saveDistributed(features, path, sqlContext, mode, codec) 74 | case "local" => 75 | saveLocal(features, path, mode, codec) 76 | case s: String => 77 | throw new IllegalArgumentException( 78 | s"Expected 'distributed' or 'local', got $s") 79 | } 80 | TensorflowRelation(parameters)(sqlContext.sparkSession) 81 | } 82 | 83 | private def save(sqlContext: SQLContext, features: RDD[(BytesWritable, NullWritable)], path: String, codec: String) = { 84 | val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) 85 | if (!codec.isEmpty) { 86 | hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") 87 | hadoopConf.set("mapreduce.output.fileoutputformat.compress.codec", codec) 88 | } 89 | features.saveAsNewAPIHadoopFile( 90 | path, 91 | classOf[NullWritable], 92 | classOf[BytesWritable], 93 | classOf[TFRecordFileOutputFormat], 94 | hadoopConf 95 | ) 96 | } 97 | 98 | private def saveDistributed( 99 | features: RDD[(BytesWritable, NullWritable)], 100 | path: String, 101 | sqlContext: SQLContext, 102 | mode: SaveMode, 103 | codec: String): Unit = { 104 | val hadoopConf = sqlContext.sparkContext.hadoopConfiguration 105 | val outputPath = new Path(path) 106 | val fs = outputPath.getFileSystem(hadoopConf) 107 | val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) 108 | 109 | val pathExists = fs.exists(qualifiedOutputPath) 110 | 111 | mode match { 112 | case SaveMode.Overwrite => 113 | fs.delete(qualifiedOutputPath, true) 114 | save(sqlContext, features, path, codec) 115 | 116 | case SaveMode.Append => 117 | throw new IllegalArgumentException("Append mode is not supported") 118 | 119 | case SaveMode.ErrorIfExists => 120 | if (pathExists) 121 | throw new IllegalStateException( 122 | s"Path $path already exists. SaveMode: ErrorIfExists.") 123 | save(sqlContext, features, path, codec) 124 | 125 | case SaveMode.Ignore => 126 | // With `SaveMode.Ignore` mode, if data already exists, the save operation is expected 127 | // to not save the contents of the DataFrame and to not change the existing data. 128 | // Therefore, it is okay to do nothing here and then just return the relation below. 
129 | if (pathExists == false) 130 | save(sqlContext, features, path, codec) 131 | } 132 | } 133 | 134 | private def saveLocal( 135 | features: RDD[(BytesWritable, NullWritable)], 136 | localPath: String, 137 | mode: SaveMode, 138 | codec: String): Unit = { 139 | val cleanedPath = Paths.get(localPath).toAbsolutePath.toString 140 | if (!codec.isEmpty) { 141 | throw new IllegalArgumentException("codec can not be used in local write mode") 142 | } 143 | if (mode == SaveMode.Append) { 144 | throw new IllegalArgumentException("Append mode is not supported in local write mode") 145 | } 146 | // Not supported now, but it should be a small fix eventually. 147 | if (mode == SaveMode.Overwrite) { 148 | throw new IllegalArgumentException("Overwrite mode is not supported in local write mode") 149 | } 150 | 151 | val f = DefaultSource.writePartitionLocalFun(localPath, mode) 152 | 153 | // Perform the action. 154 | features.mapPartitionsWithIndex(f).collect() 155 | } 156 | 157 | // Reads TensorFlow Records into DataFrame with Custom Schema 158 | override def createRelation(sqlContext: SQLContext, 159 | parameters: Map[String, String], 160 | schema: StructType): BaseRelation = { 161 | TensorflowRelation(parameters, Some(schema))(sqlContext.sparkSession) 162 | } 163 | 164 | // Reads TensorFlow Records into DataFrame with schema inferred 165 | override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): TensorflowRelation = { 166 | TensorflowRelation(parameters)(sqlContext.sparkSession) 167 | } 168 | } 169 | 170 | object DefaultSource extends scala.Serializable { 171 | // The function run on each worker. 172 | // Writes the partition to a file and returns the number of records output. 173 | private def writePartitionLocal( 174 | index: Int, 175 | part: Iterator[(BytesWritable, NullWritable)], 176 | localPath: String, 177 | mode: SaveMode): Iterator[Int] = { 178 | val dir = new File(localPath) 179 | if (dir.exists()) { 180 | if (mode == SaveMode.ErrorIfExists) { 181 | throw new IllegalStateException( 182 | s"LocalPath $localPath already exists. SaveMode: ErrorIfExists.") 183 | } 184 | if (mode == SaveMode.Ignore) { 185 | return Iterator.empty 186 | } 187 | } 188 | 189 | // Make the directory if it does not exist 190 | dir.mkdirs() 191 | // The path to the partition file. 192 | val filePath = localPath + s"/part-" + String.format("%05d", java.lang.Integer.valueOf(index)) 193 | val fos = new DataOutputStream(new FileOutputStream(filePath)) 194 | var count = 0 195 | try { 196 | val tfw = new TFRecordWriter(fos) 197 | for((bw, _) <- part) { 198 | tfw.write(bw.getBytes) 199 | count += 1 200 | } 201 | } finally { 202 | fos.close() 203 | } 204 | Iterator(count) 205 | } 206 | 207 | // Working around the closure variable captures. 208 | private def writePartitionLocalFun( 209 | localPath: String, 210 | mode: SaveMode): (Int, Iterator[(BytesWritable, NullWritable)]) => Iterator[Int] = { 211 | def mapFun(index: Int, part: Iterator[(BytesWritable, NullWritable)]) = { 212 | writePartitionLocal(index, part, localPath, mode) 213 | } 214 | mapFun 215 | } 216 | 217 | } 218 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords 17 | 18 | import org.apache.spark.rdd.RDD 19 | import org.apache.spark.sql.types._ 20 | import org.tensorflow.example.{FeatureList, SequenceExample, Example, Feature} 21 | import scala.collection.JavaConverters._ 22 | import scala.collection.mutable 23 | import scala.reflect.runtime.universe._ 24 | 25 | object TensorFlowInferSchema { 26 | 27 | /** 28 | * Similar to the JSON schema inference. 29 | * [[org.apache.spark.sql.execution.datasources.json.InferSchema]] 30 | * 1. Infer type of each row 31 | * 2. Merge row types to find common type 32 | * 3. Replace any null types with string type 33 | */ 34 | def apply[T : TypeTag](rdd: RDD[T]): StructType = { 35 | val startType: mutable.Map[String, DataType] = mutable.Map.empty[String, DataType] 36 | 37 | val rootTypes: mutable.Map[String, DataType] = typeOf[T] match { 38 | case t if t =:= typeOf[Example] => { 39 | rdd.asInstanceOf[RDD[Example]].aggregate(startType)(inferExampleRowType, mergeFieldTypes) 40 | } 41 | case t if t =:= typeOf[SequenceExample] => { 42 | rdd.asInstanceOf[RDD[SequenceExample]].aggregate(startType)(inferSequenceExampleRowType, mergeFieldTypes) 43 | } 44 | case _ => throw new IllegalArgumentException(s"Unsupported recordType: recordType can be Example or SequenceExample") 45 | } 46 | 47 | val columnsList = rootTypes.map { 48 | case (featureName, featureType) => 49 | if (featureType == null) { 50 | StructField(featureName, StringType) 51 | } 52 | else { 53 | StructField(featureName, featureType) 54 | } 55 | } 56 | StructType(columnsList.toSeq) 57 | } 58 | 59 | private def inferSequenceExampleRowType(schemaSoFar: mutable.Map[String, DataType], 60 | next: SequenceExample): mutable.Map[String, DataType] = { 61 | val featureMap = next.getContext.getFeatureMap.asScala 62 | val updatedSchema = inferFeatureTypes(schemaSoFar, featureMap) 63 | 64 | val featureListMap = next.getFeatureLists.getFeatureListMap.asScala 65 | inferFeatureListTypes(updatedSchema, featureListMap) 66 | } 67 | 68 | private def inferExampleRowType(schemaSoFar: mutable.Map[String, DataType], 69 | next: Example): mutable.Map[String, DataType] = { 70 | val featureMap = next.getFeatures.getFeatureMap.asScala 71 | inferFeatureTypes(schemaSoFar, featureMap) 72 | } 73 | 74 | private def inferFeatureTypes(schemaSoFar: mutable.Map[String, DataType], 75 | featureMap: mutable.Map[String, Feature]): mutable.Map[String, DataType] = { 76 | featureMap.foreach { 77 | case (featureName, feature) => { 78 | val currentType = inferField(feature) 79 | if (schemaSoFar.contains(featureName)) { 80 | val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType) 81 | schemaSoFar(featureName) = updatedType.orNull 82 | } 83 | else { 84 | schemaSoFar += (featureName -> currentType) 85 | } 86 | } 87 | } 88 | schemaSoFar 89 | } 90 | 91 | def inferFeatureListTypes(schemaSoFar: mutable.Map[String, DataType], 92 | 
featureListMap: mutable.Map[String, FeatureList]): mutable.Map[String, DataType] = { 93 | featureListMap.foreach { 94 | case (featureName, featureList) => { 95 | val featureType = featureList.getFeatureList.asScala.map(f => inferField(f)) 96 | .reduceLeft((a, b) => findTightestCommonType(a, b).orNull) 97 | val currentType = featureType match { 98 | case ArrayType(_, _) => ArrayType(featureType) 99 | case _ => ArrayType(ArrayType(featureType)) 100 | } 101 | if (schemaSoFar.contains(featureName)) { 102 | val updatedType = findTightestCommonType(schemaSoFar(featureName), currentType) 103 | schemaSoFar(featureName) = updatedType.orNull 104 | } 105 | else { 106 | schemaSoFar += (featureName -> currentType) 107 | } 108 | } 109 | } 110 | schemaSoFar 111 | } 112 | 113 | private def mergeFieldTypes(first: mutable.Map[String, DataType], 114 | second: mutable.Map[String, DataType]): mutable.Map[String, DataType] = { 115 | //Merge two maps and do the comparison. 116 | val mutMap = collection.mutable.Map[String, DataType]((first.keySet ++ second.keySet) 117 | .map(key => (key, findTightestCommonType(first.getOrElse(key, null), second.getOrElse(key, null)).get)) 118 | .toSeq: _*) 119 | mutMap 120 | } 121 | 122 | /** 123 | * Infer Feature datatype based on field number 124 | */ 125 | private def inferField(feature: Feature): DataType = { 126 | feature.getKindCase.getNumber match { 127 | case Feature.BYTES_LIST_FIELD_NUMBER => { 128 | parseBytesList(feature) 129 | } 130 | case Feature.INT64_LIST_FIELD_NUMBER => { 131 | parseInt64List(feature) 132 | } 133 | case Feature.FLOAT_LIST_FIELD_NUMBER => { 134 | parseFloatList(feature) 135 | } 136 | case _ => throw new RuntimeException("unsupported type ...") 137 | } 138 | } 139 | 140 | private def parseBytesList(feature: Feature): DataType = { 141 | val length = feature.getBytesList.getValueCount 142 | 143 | if (length == 0) { 144 | null 145 | } 146 | else if (length > 1) { 147 | ArrayType(StringType) 148 | } 149 | else { 150 | StringType 151 | } 152 | } 153 | 154 | private def parseInt64List(feature: Feature): DataType = { 155 | val int64List = feature.getInt64List.getValueList.asScala.toArray 156 | val length = int64List.length 157 | 158 | if (length == 0) { 159 | null 160 | } 161 | else if (length > 1) { 162 | ArrayType(LongType) 163 | } 164 | else { 165 | LongType 166 | } 167 | } 168 | 169 | private def parseFloatList(feature: Feature): DataType = { 170 | val floatList = feature.getFloatList.getValueList.asScala.toArray 171 | val length = floatList.length 172 | if (length == 0) { 173 | null 174 | } 175 | else if (length > 1) { 176 | ArrayType(FloatType) 177 | } 178 | else { 179 | FloatType 180 | } 181 | } 182 | 183 | /** 184 | * Copied from internal Spark api 185 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] 186 | */ 187 | private def getNumericPrecedence(dataType: DataType): Int = { 188 | dataType match { 189 | case LongType => 1 190 | case FloatType => 2 191 | case StringType => 3 192 | case ArrayType(LongType, _) => 4 193 | case ArrayType(FloatType, _) => 5 194 | case ArrayType(StringType, _) => 6 195 | case ArrayType(ArrayType(LongType, _), _) => 7 196 | case ArrayType(ArrayType(FloatType, _), _) => 8 197 | case ArrayType(ArrayType(StringType, _), _) => 9 198 | case _ => throw new RuntimeException("Unable to get the precedence for given datatype...") 199 | } 200 | } 201 | 202 | /** 203 | * Copied from internal Spark api 204 | * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] 205 | */ 206 | private def 
findTightestCommonType(tt1: DataType, tt2: DataType) : Option[DataType] = { 207 | val currType = (tt1, tt2) match { 208 | case (t1, t2) if t1 == t2 => Some(t1) 209 | case (null, t2) => Some(t2) 210 | case (t1, null) => Some(t1) 211 | 212 | // Promote types based on numeric precedence 213 | case (t1, t2) => 214 | val t1Precedence = getNumericPrecedence(t1) 215 | val t2Precedence = getNumericPrecedence(t2) 216 | val newType = if (t1Precedence > t2Precedence) t1 else t2 217 | Some(newType) 218 | case _ => None 219 | } 220 | currType 221 | } 222 | } 223 | 224 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorflowRelation.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords 17 | 18 | import org.apache.hadoop.io.{BytesWritable, NullWritable} 19 | import org.apache.spark.rdd.RDD 20 | import org.apache.spark.sql.sources.{BaseRelation, TableScan} 21 | import org.apache.spark.sql.types.StructType 22 | import org.apache.spark.sql.{Row, SQLContext, SparkSession} 23 | import org.tensorflow.example.{SequenceExample, Example} 24 | import org.tensorflow.hadoop.io.TFRecordFileInputFormat 25 | import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowDecoder 26 | 27 | 28 | case class TensorflowRelation(options: Map[String, String], customSchema: Option[StructType]=None) 29 | (@transient val session: SparkSession) extends BaseRelation with TableScan { 30 | 31 | //Import TFRecords as DataFrame happens here 32 | lazy val (tfRdd, tfSchema) = { 33 | val rdd = session.sparkContext.newAPIHadoopFile(options("path"), classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable]) 34 | 35 | val recordType = options.getOrElse("recordType", "Example") 36 | 37 | recordType match { 38 | case "Example" => 39 | val exampleRdd = rdd.map{case (bytesWritable, nullWritable) => 40 | Example.parseFrom(bytesWritable.getBytes) 41 | } 42 | val finalSchema = customSchema.getOrElse(TensorFlowInferSchema(exampleRdd)) 43 | val rowRdd = exampleRdd.map(example => DefaultTfRecordRowDecoder.decodeExample(example, finalSchema)) 44 | (rowRdd, finalSchema) 45 | case "SequenceExample" => 46 | val sequenceExampleRdd = rdd.map{case (bytesWritable, nullWritable) => 47 | SequenceExample.parseFrom(bytesWritable.getBytes) 48 | } 49 | val finalSchema = customSchema.getOrElse(TensorFlowInferSchema(sequenceExampleRdd)) 50 | val rowRdd = sequenceExampleRdd.map(example => DefaultTfRecordRowDecoder.decodeSequenceExample(example, finalSchema)) 51 | (rowRdd, finalSchema) 52 | case _ => 53 | throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be Example or SequenceExample") 54 | } 55 | } 56 | 57 | 
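// The overrides below wire this relation into Spark's data source API: `sqlContext` exposes the
// session's SQLContext, `schema` reports the inferred (or user-supplied) StructType, and
// `buildScan` returns the decoded Row RDD built above, so `spark.read.format("tfrecords").load(path)`
// comes back as an ordinary DataFrame.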
override def sqlContext: SQLContext = session.sqlContext 58 | 59 | override def schema: StructType = tfSchema 60 | 61 | override def buildScan(): RDD[Row] = tfRdd 62 | } 63 | 64 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/DefaultTfRecordRowDecoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords.serde 17 | 18 | import org.apache.spark.sql.types._ 19 | import org.apache.spark.sql.Row 20 | import org.apache.spark.ml.linalg.SQLDataTypes.VectorType 21 | import org.apache.spark.ml.linalg.Vectors 22 | import org.tensorflow.example._ 23 | import scala.collection.JavaConverters._ 24 | 25 | trait TfRecordRowDecoder { 26 | /** 27 | * Decodes each TensorFlow "Example" as DataFrame "Row" 28 | * 29 | * Maps each feature in Example to element in Row with DataType based on custom schema 30 | * 31 | * @param example TensorFlow Example to decode 32 | * @param schema Decode Example using specified schema 33 | * @return a DataFrame row 34 | */ 35 | def decodeExample(example: Example, schema: StructType): Row 36 | 37 | /** 38 | * Decodes each TensorFlow "SequenceExample" as DataFrame "Row" 39 | * 40 | * Maps each feature in SequenceExample to element in Row with DataType based on custom schema 41 | * 42 | * @param sequenceExample TensorFlow SequenceExample to decode 43 | * @param schema Decode SequenceExample using specified schema 44 | * @return a DataFrame row 45 | */ 46 | def decodeSequenceExample(sequenceExample: SequenceExample, schema: StructType): Row 47 | } 48 | 49 | object DefaultTfRecordRowDecoder extends TfRecordRowDecoder { 50 | 51 | /** 52 | * Decodes each TensorFlow "Example" as DataFrame "Row" 53 | * 54 | * Maps each feature in Example to element in Row with DataType based on custom schema 55 | * 56 | * @param example TensorFlow Example to decode 57 | * @param schema Decode Example using specified schema 58 | * @return a DataFrame row 59 | */ 60 | def decodeExample(example: Example, schema: StructType): Row = { 61 | val row = Array.fill[Any](schema.length)(null) 62 | val featureMap = example.getFeatures.getFeatureMap.asScala 63 | schema.fields.zipWithIndex.foreach { 64 | case (field, index) => 65 | val feature = featureMap.get(field.name) 66 | feature match { 67 | case Some(f) => row(index) = decodeFeature(f, schema, index) 68 | case None => if (!field.nullable) throw new NullPointerException(s"Field ${field.name} does not allow null values") 69 | } 70 | } 71 | Row.fromSeq(row) 72 | } 73 | 74 | /** 75 | * Decodes each TensorFlow "SequenceExample" as DataFrame "Row" 76 | * 77 | * Maps each feature in SequenceExample to element in Row with DataType based on custom schema 78 | * 79 | * @param 
sequenceExample TensorFlow SequenceExample to decode 80 | * @param schema Decode Example using specified schema 81 | * @return a DataFrame row 82 | */ 83 | def decodeSequenceExample(sequenceExample: SequenceExample, schema: StructType): Row = { 84 | val row = Array.fill[Any](schema.length)(null) 85 | 86 | //Decode features 87 | val featureMap = sequenceExample.getContext.getFeatureMap.asScala 88 | val featureListMap = sequenceExample.getFeatureLists.getFeatureListMap.asScala 89 | 90 | schema.fields.zipWithIndex.foreach { 91 | case (field, index) => 92 | val feature = featureMap.get(field.name) 93 | 94 | feature match { 95 | case Some(f) => row(index) = decodeFeature(f, schema, index) 96 | case None => { 97 | featureListMap.get(field.name) match { 98 | case Some(list) => row(index) = decodeFeatureList(list, schema, index) 99 | case None => if (!field.nullable) throw new NullPointerException(s"Field ${field.name} does not allow null values") 100 | } 101 | } 102 | } 103 | } 104 | 105 | Row.fromSeq(row) 106 | } 107 | 108 | // Decode Feature to Scala Type based on field in schema 109 | private def decodeFeature(feature: Feature, schema: StructType, fieldIndex: Int): Any = { 110 | val colDataType = schema.fields(fieldIndex).dataType 111 | 112 | colDataType match { 113 | case IntegerType => IntFeatureDecoder.decode(feature) 114 | case LongType => LongFeatureDecoder.decode(feature) 115 | case FloatType => FloatFeatureDecoder.decode(feature) 116 | case DoubleType => DoubleFeatureDecoder.decode(feature) 117 | case DecimalType() => DecimalFeatureDecoder.decode(feature) 118 | case StringType => StringFeatureDecoder.decode(feature) 119 | case BinaryType => BinaryFeatureDecoder.decode(feature) 120 | case ArrayType(IntegerType, _) => IntListFeatureDecoder.decode(feature) 121 | case ArrayType(LongType, _) => LongListFeatureDecoder.decode(feature) 122 | case ArrayType(FloatType, _) => FloatListFeatureDecoder.decode(feature) 123 | case ArrayType(DoubleType, _) => DoubleListFeatureDecoder.decode(feature) 124 | case ArrayType(DecimalType(), _) => DecimalListFeatureDecoder.decode(feature) 125 | case ArrayType(StringType, _) => StringListFeatureDecoder.decode(feature) 126 | case ArrayType(BinaryType, _) => BinaryListFeatureDecoder.decode(feature) 127 | case VectorType => Vectors.dense(DoubleListFeatureDecoder.decode(feature).toArray) 128 | case _ => throw new scala.RuntimeException(s"Cannot convert Feature to unsupported data type ${colDataType}") 129 | } 130 | } 131 | 132 | // Decode FeatureList to Scala Type based on field in schema 133 | private def decodeFeatureList(featureList: FeatureList, schema: StructType, fieldIndex: Int): Any = { 134 | val colDataType = schema.fields(fieldIndex).dataType 135 | colDataType match { 136 | case ArrayType(ArrayType(IntegerType, _), _) => IntFeatureListDecoder.decode(featureList) 137 | case ArrayType(ArrayType(LongType, _), _) => LongFeatureListDecoder.decode(featureList) 138 | case ArrayType(ArrayType(FloatType, _), _) => FloatFeatureListDecoder.decode(featureList) 139 | case ArrayType(ArrayType(DoubleType, _), _) => DoubleFeatureListDecoder.decode(featureList) 140 | case ArrayType(ArrayType(DecimalType(), _), _) => DecimalFeatureListDecoder.decode(featureList) 141 | case ArrayType(ArrayType(StringType, _), _) => StringFeatureListDecoder.decode(featureList) 142 | case ArrayType(ArrayType(BinaryType, _), _) => BinaryFeatureListDecoder.decode(featureList) 143 | case _ => throw new scala.RuntimeException(s"Cannot convert FeatureList to unsupported data type ${colDataType}") 
144 | } 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords.serde 17 | 18 | import org.tensorflow.example._ 19 | import com.google.protobuf.ByteString 20 | 21 | trait FeatureEncoder[T] { 22 | /** 23 | * Encodes input value as TensorFlow "Feature" 24 | * 25 | * Maps input value to one of Int64List, FloatList, BytesList 26 | * 27 | * @param value Input value 28 | * @return TensorFlow Feature 29 | */ 30 | def encode(value: T): Feature 31 | } 32 | 33 | /** 34 | * Encode input value to Int64List 35 | */ 36 | object Int64ListFeatureEncoder extends FeatureEncoder[Seq[Long]] { 37 | override def encode(value: Seq[Long]): Feature = { 38 | val intListBuilder = Int64List.newBuilder() 39 | value.foreach {x => 40 | intListBuilder.addValue(x) 41 | } 42 | val int64List = intListBuilder.build() 43 | Feature.newBuilder().setInt64List(int64List).build() 44 | } 45 | } 46 | 47 | /** 48 | * Encode input value to FloatList 49 | */ 50 | object FloatListFeatureEncoder extends FeatureEncoder[Seq[Float]] { 51 | override def encode(value: Seq[Float]): Feature = { 52 | val floatListBuilder = FloatList.newBuilder() 53 | value.foreach {x => 54 | floatListBuilder.addValue(x) 55 | } 56 | val floatList = floatListBuilder.build() 57 | Feature.newBuilder().setFloatList(floatList).build() 58 | } 59 | } 60 | 61 | /** 62 | * Encode input value to ByteList 63 | */ 64 | object BytesListFeatureEncoder extends FeatureEncoder[Seq[Array[Byte]]] { 65 | override def encode(value: Seq[Array[Byte]]): Feature = { 66 | val bytesListBuilder = BytesList.newBuilder() 67 | value.foreach {x => 68 | bytesListBuilder.addValue(ByteString.copyFrom(x)) 69 | } 70 | val bytesList = bytesListBuilder.build() 71 | Feature.newBuilder().setBytesList(bytesList).build() 72 | } 73 | } -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListDecoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords.serde 17 | 18 | import org.apache.spark.sql.types.Decimal 19 | import org.tensorflow.example.FeatureList 20 | import scala.collection.JavaConverters._ 21 | 22 | trait FeatureListDecoder[T] extends Serializable{ 23 | /** 24 | * Decodes each TensorFlow "FeatureList" to desired Scala type 25 | * 26 | * @param featureList TensorFlow FeatureList 27 | * @return Decoded featureList 28 | */ 29 | def decode(featureList: FeatureList): T 30 | } 31 | 32 | /** 33 | * Decode TensorFlow "FeatureList" to 2-dimensional Integer array 34 | */ 35 | object IntFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Int]]] { 36 | override def decode(featureList: FeatureList): Seq[Seq[Int]] = { 37 | featureList.getFeatureList.asScala.map(x => IntListFeatureDecoder.decode(x)).toSeq 38 | } 39 | } 40 | 41 | /** 42 | * Decode TensorFlow "FeatureList" to 2-dimensional Long array 43 | */ 44 | object LongFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Long]]] { 45 | override def decode(featureList: FeatureList): Seq[Seq[Long]] = { 46 | featureList.getFeatureList.asScala.map(x => LongListFeatureDecoder.decode(x)).toSeq 47 | } 48 | } 49 | 50 | /** 51 | * Decode TensorFlow "FeatureList" to 2-dimensional Float array 52 | */ 53 | object FloatFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Float]]] { 54 | override def decode(featureList: FeatureList): Seq[Seq[Float]] = { 55 | featureList.getFeatureList.asScala.map(x => FloatListFeatureDecoder.decode(x)).toSeq 56 | } 57 | } 58 | 59 | /** 60 | * Decode TensorFlow "FeatureList" to 2-dimensional Double array 61 | */ 62 | object DoubleFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Double]]] { 63 | override def decode(featureList: FeatureList): Seq[Seq[Double]] = { 64 | featureList.getFeatureList.asScala.map(x => DoubleListFeatureDecoder.decode(x)).toSeq 65 | } 66 | } 67 | 68 | /** 69 | * Decode TensorFlow "FeatureList" to 2-dimensional Decimal array 70 | */ 71 | object DecimalFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Decimal]]] { 72 | override def decode(featureList: FeatureList): Seq[Seq[Decimal]] = { 73 | featureList.getFeatureList.asScala.map(x => DecimalListFeatureDecoder.decode(x)).toSeq 74 | } 75 | } 76 | 77 | /** 78 | * Decode TensorFlow "FeatureList" to 2-dimensional String array 79 | */ 80 | object StringFeatureListDecoder extends FeatureListDecoder[Seq[Seq[String]]] { 81 | override def decode(featureList: FeatureList): Seq[Seq[String]] = { 82 | featureList.getFeatureList.asScala.map(x => StringListFeatureDecoder.decode(x)).toSeq 83 | } 84 | } 85 | 86 | /** 87 | * Decode TensorFlow "FeatureList" to 2-dimensional array of Array[Byte] (a 3-dimensional array) 88 | */ 89 | object BinaryFeatureListDecoder extends FeatureListDecoder[Seq[Seq[Array[Byte]]]] { 90 | override def decode(featureList: FeatureList): Seq[Seq[Array[Byte]]] = { 91 | featureList.getFeatureList.asScala.map(x => BinaryListFeatureDecoder.decode(x)).toSeq 92 | } 93 | } 94 | -------------------------------------------------------------------------------- 
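A minimal round-trip sketch of the serde objects above (the values are illustrative; it assumes the connector and the org.tensorflow.example protobuf classes are on the classpath): Int64ListFeatureEncoder wraps a Seq[Long] in a TensorFlow Feature, and LongFeatureListDecoder turns a FeatureList of such features back into a two-dimensional Scala sequence.

import org.tensorflow.example.FeatureList
import org.tensorflow.spark.datasources.tfrecords.serde.{Int64ListFeatureEncoder, LongFeatureListDecoder}

// Encode two rows of longs as TensorFlow Features and group them into a FeatureList.
val featureList: FeatureList = FeatureList.newBuilder()
  .addFeature(Int64ListFeatureEncoder.encode(Seq(1L, 2L, 3L)))
  .addFeature(Int64ListFeatureEncoder.encode(Seq(4L, 5L)))
  .build()

// Decode it back into Seq(Seq(1L, 2L, 3L), Seq(4L, 5L)).
val decoded: Seq[Seq[Long]] = LongFeatureListDecoder.decode(featureList)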
/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords.serde 17 | 18 | import org.tensorflow.example.FeatureList 19 | 20 | trait FeatureListEncoder[T] extends Serializable{ 21 | /** 22 | * Encodes input value as TensorFlow "FeatureList" 23 | * 24 | * Maps input value to a feature list of type Int64List, FloatList, or BytesList 25 | * 26 | * @param values Input values 27 | * @return TensorFlow FeatureList 28 | */ 29 | def encode(values: T): FeatureList 30 | } 31 | 32 | 33 | /** 34 | * Encode 2-dimensional Long array to TensorFlow "FeatureList" of type Int64List 35 | */ 36 | object Int64FeatureListEncoder extends FeatureListEncoder[Seq[Seq[Long]]] { 37 | def encode(values: Seq[Seq[Long]]) : FeatureList = { 38 | val builder = FeatureList.newBuilder() 39 | values.foreach { x => 40 | val int64list = Int64ListFeatureEncoder.encode(x) 41 | builder.addFeature(int64list) 42 | } 43 | builder.build() 44 | } 45 | } 46 | 47 | /** 48 | * Encode 2-dimensional Float array to TensorFlow "FeatureList" of type FloatList 49 | */ 50 | object FloatFeatureListEncoder extends FeatureListEncoder[Seq[Seq[Float]]] { 51 | def encode(value: Seq[Seq[Float]]) : FeatureList = { 52 | val builder = FeatureList.newBuilder() 53 | value.foreach { x => 54 | val floatList = FloatListFeatureEncoder.encode(x) 55 | builder.addFeature(floatList) 56 | } 57 | builder.build() 58 | } 59 | } 60 | 61 | /** 62 | * Encode 2-dimensional String array to TensorFlow "FeatureList" of type BytesList 63 | */ 64 | object BytesFeatureListEncoder extends FeatureListEncoder[Seq[Seq[Array[Byte]]]] { 65 | def encode(value: Seq[Seq[Array[Byte]]]) : FeatureList = { 66 | val builder = FeatureList.newBuilder() 67 | value.foreach { x => 68 | val bytesList = BytesListFeatureEncoder.encode(x) 69 | builder.addFeature(bytesList) 70 | } 71 | builder.build() 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/udf/DataFrameTfrConverter.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.spark.datasources.tfrecords.udf 2 | 3 | import org.apache.spark.sql.Row 4 | import org.apache.spark.sql.expressions.UserDefinedFunction 5 | import org.apache.spark.sql.functions.udf 6 | import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowEncoder 7 | 8 | object DataFrameTfrConverter { 9 | def getRowToTFRecordExampleUdf: UserDefinedFunction = udf(rowToTFRecordExampleUdf _ ) 10 | 11 | private def rowToTFRecordExampleUdf(row: Row): Array[Byte] = { 12 | 
DefaultTfRecordRowEncoder.encodeExample(row).toByteArray 13 | } 14 | 15 | def getRowToTFRecordSequenceExampleUdf: UserDefinedFunction = udf(rowToTFRecordSequenceExampleUdf _ ) 16 | 17 | private def rowToTFRecordSequenceExampleUdf(row: Row): Array[Byte] = { 18 | DefaultTfRecordRowEncoder.encodeSequenceExample(row).toByteArray 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/InferSchemaSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords 17 | 18 | import org.apache.spark.rdd.RDD 19 | import org.apache.spark.sql.types._ 20 | import org.tensorflow.example._ 21 | import com.google.protobuf.ByteString 22 | 23 | class InferSchemaSuite extends SharedSparkSessionSuite { 24 | 25 | val longFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(Int.MaxValue + 10L)).build() 26 | val floatFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(10.0F).build()).build() 27 | val strFeature = Feature.newBuilder().setBytesList( 28 | BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes))).build() 29 | 30 | val longList = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(-2L).addValue(20L).build()).build() 31 | val floatList = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(2.5F).addValue(7F).build()).build() 32 | val strList = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)) 33 | .addValue(ByteString.copyFrom("r2".getBytes)).build()).build() 34 | 35 | "InferSchema" should { 36 | 37 | "Infer schema from Example records" in { 38 | //Build example1 39 | val features1 = Features.newBuilder() 40 | .putFeature("LongFeature", longFeature) 41 | .putFeature("FloatFeature", floatFeature) 42 | .putFeature("StrFeature", strFeature) 43 | .putFeature("LongList", longFeature) 44 | .putFeature("FloatList", floatFeature) 45 | .putFeature("StrList", strFeature) 46 | .putFeature("MixedTypeList", longList) 47 | .build() 48 | val example1 = Example.newBuilder() 49 | .setFeatures(features1) 50 | .build() 51 | 52 | //Example2 contains subset of features in example1 to test behavior with missing features 53 | val features2 = Features.newBuilder() 54 | .putFeature("StrFeature", strFeature) 55 | .putFeature("LongList", longList) 56 | .putFeature("FloatList", floatList) 57 | .putFeature("StrList", strList) 58 | .putFeature("MixedTypeList", floatList) 59 | .build() 60 | val example2 = Example.newBuilder() 61 | .setFeatures(features2) 62 | .build() 63 | 64 | val exampleRdd: RDD[Example] = spark.sparkContext.parallelize(List(example1, example2)) 65 | val inferredSchema = 
TensorFlowInferSchema(exampleRdd) 66 | 67 | //Verify each TensorFlow Datatype is inferred as one of our Datatype 68 | assert(inferredSchema.fields.length == 7) 69 | val schemaMap = inferredSchema.map(f => (f.name, f.dataType)).toMap 70 | assert(schemaMap("LongFeature") === LongType) 71 | assert(schemaMap("FloatFeature") === FloatType) 72 | assert(schemaMap("StrFeature") === StringType) 73 | assert(schemaMap("LongList") === ArrayType(LongType)) 74 | assert(schemaMap("FloatList") === ArrayType(FloatType)) 75 | assert(schemaMap("StrList") === ArrayType(StringType)) 76 | assert(schemaMap("MixedTypeList") === ArrayType(FloatType)) 77 | } 78 | 79 | "Infer schema from SequenceExample records" in { 80 | 81 | //Build sequence example1 82 | val features1 = Features.newBuilder() 83 | .putFeature("FloatFeature", floatFeature) 84 | 85 | val longFeatureList1 = FeatureList.newBuilder().addFeature(longFeature).addFeature(longList).build() 86 | val floatFeatureList1 = FeatureList.newBuilder().addFeature(floatFeature).addFeature(floatList).build() 87 | val strFeatureList1 = FeatureList.newBuilder().addFeature(strFeature).build() 88 | val mixedFeatureList1 = FeatureList.newBuilder().addFeature(floatFeature).addFeature(strList).build() 89 | 90 | val featureLists1 = FeatureLists.newBuilder() 91 | .putFeatureList("LongListOfLists", longFeatureList1) 92 | .putFeatureList("FloatListOfLists", floatFeatureList1) 93 | .putFeatureList("StringListOfLists", strFeatureList1) 94 | .putFeatureList("MixedListOfLists", mixedFeatureList1) 95 | .build() 96 | 97 | val seqExample1 = SequenceExample.newBuilder() 98 | .setContext(features1) 99 | .setFeatureLists(featureLists1) 100 | .build() 101 | 102 | //SequenceExample2 contains subset of features in example1 to test behavior with missing features 103 | val longFeatureList2 = FeatureList.newBuilder().addFeature(longList).build() 104 | val floatFeatureList2 = FeatureList.newBuilder().addFeature(floatFeature).build() 105 | val strFeatureList2 = FeatureList.newBuilder().addFeature(strFeature).build() //test both string lists of length=1 106 | val mixedFeatureList2 = FeatureList.newBuilder().addFeature(longFeature).addFeature(strFeature).build() 107 | 108 | val featureLists2 = FeatureLists.newBuilder() 109 | .putFeatureList("LongListOfLists", longFeatureList2) 110 | .putFeatureList("FloatListOfLists", floatFeatureList2) 111 | .putFeatureList("StringListOfLists", strFeatureList2) 112 | .putFeatureList("MixedListOfLists", mixedFeatureList2) 113 | .build() 114 | 115 | val seqExample2 = SequenceExample.newBuilder() 116 | .setFeatureLists(featureLists2) 117 | .build() 118 | 119 | val seqExampleRdd: RDD[SequenceExample] = spark.sparkContext.parallelize(List(seqExample1, seqExample2)) 120 | val inferredSchema = TensorFlowInferSchema(seqExampleRdd) 121 | 122 | //Verify each TensorFlow Datatype is inferred as one of our Datatype 123 | assert(inferredSchema.fields.length == 5) 124 | val schemaMap = inferredSchema.map(f => (f.name, f.dataType)).toMap 125 | assert(schemaMap("FloatFeature") === FloatType) 126 | assert(schemaMap("LongListOfLists") === ArrayType(ArrayType(LongType))) 127 | assert(schemaMap("FloatListOfLists") === ArrayType(ArrayType(FloatType))) 128 | assert(schemaMap("StringListOfLists") === ArrayType(ArrayType(StringType))) 129 | assert(schemaMap("MixedListOfLists") === ArrayType(ArrayType(StringType))) 130 | } 131 | } 132 | 133 | "Throw an exception for unsupported record types" in { 134 | intercept[Exception] { 135 | val rdd: RDD[Long] = 
spark.sparkContext.parallelize(List(5L, 6L)) 136 | TensorFlowInferSchema(rdd) 137 | } 138 | 139 | } 140 | } 141 | 142 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/LocalWriteSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords 17 | 18 | import java.nio.file.Files 19 | import java.nio.file.Paths 20 | 21 | import org.apache.spark.sql.Row 22 | import org.apache.spark.sql.catalyst.expressions.GenericRow 23 | import org.apache.spark.sql.types._ 24 | 25 | import org.apache.commons.io.FileUtils 26 | 27 | class LocalWriteSuite extends SharedSparkSessionSuite { 28 | 29 | val testRows: Array[Row] = Array( 30 | new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 3.0), "r1")), 31 | new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 3.0), "r2")), 32 | new GenericRow(Array[Any](31, 3, 25L, 14.0F, 16.0, List(3.0, 3.0), "r3"))) 33 | val schema = StructType(List(StructField("id", IntegerType), 34 | StructField("IntegerTypeLabel", IntegerType), 35 | StructField("LongTypeLabel", LongType), 36 | StructField("FloatTypeLabel", FloatType), 37 | StructField("DoubleTypeLabel", DoubleType), 38 | StructField("VectorLabel", ArrayType(DoubleType, true)), 39 | StructField("name", StringType))) 40 | 41 | 42 | "Propagate" should { 43 | "write data locally" in { 44 | // Create a dataframe with 2 partitions 45 | val rdd = spark.sparkContext.parallelize(testRows, numSlices = 2) 46 | val df = spark.createDataFrame(rdd, schema) 47 | 48 | // Write the partitions onto the local hard drive. Since it is going to be the 49 | // local file system, the partitions will be written in the same directory of the 50 | // same machine. 51 | // In a distributed setting though, two different machines would each hold a single 52 | // partition. 53 | val localPath = Files.createTempDirectory("spark-connector-propagate").toAbsolutePath.toString 54 | val savePath = localPath + "/testResult" 55 | df.write.format("tfrecords") 56 | .option("recordType", "Example") 57 | .option("writeLocality", "local") 58 | .save(savePath) 59 | 60 | // Read again this directory, this time using the Hadoop file readers, it should 61 | // return the same data. 62 | // This only works in this test and does not hold in general, because the partitions 63 | // will be written on the workers. Everything runs locally for tests. 64 | val df2 = spark.read.format("tfrecords").option("recordType", "Example") 65 | .load(savePath).sort("id").select("id", "IntegerTypeLabel", "LongTypeLabel", 66 | "FloatTypeLabel", "DoubleTypeLabel", "VectorLabel", "name") // Correct column order. 
67 | 68 | assert(df2.collect().toSeq === testRows.toSeq) 69 | } 70 | } 71 | } -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/SharedSparkSessionSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords 17 | 18 | import java.io.File 19 | 20 | import org.apache.commons.io.FileUtils 21 | import org.apache.spark.SharedSparkSession 22 | import org.junit.{After, Before} 23 | import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} 24 | 25 | 26 | trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll 27 | 28 | class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite { 29 | val TF_SANDBOX_DIR = "tf-sandbox" 30 | val file = new File(TF_SANDBOX_DIR) 31 | 32 | @Before 33 | override def beforeAll() = { 34 | super.setUp() 35 | FileUtils.deleteQuietly(file) 36 | file.mkdirs() 37 | } 38 | 39 | @After 40 | override def afterAll() = { 41 | FileUtils.deleteQuietly(file) 42 | super.tearDown() 43 | } 44 | } 45 | 46 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords 17 | 18 | import org.apache.hadoop.fs.Path 19 | import org.apache.spark.sql.catalyst.expressions.GenericRow 20 | import org.apache.spark.sql.types._ 21 | import org.apache.spark.sql.{DataFrame, Row, SaveMode} 22 | 23 | import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ 24 | 25 | class TensorFlowSuite extends SharedSparkSessionSuite { 26 | 27 | val exampleSchema = StructType(List( 28 | StructField("id", IntegerType), 29 | StructField("IntegerLabel", IntegerType), 30 | StructField("LongLabel", LongType), 31 | StructField("FloatLabel", FloatType), 32 | StructField("DoubleLabel", DoubleType), 33 | StructField("DoubleArrayLabel", ArrayType(DoubleType, true)), 34 | StructField("StrLabel", StringType), 35 | StructField("BinaryLabel", BinaryType))) 36 | 37 | val exampleTestRows: Array[Row] = Array( 38 | new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1", 39 | Array[Byte](0xff.toByte, 0xf0.toByte))), 40 | new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2", 41 | Array[Byte](0xff.toByte, 0xf1.toByte)))) 42 | 43 | 44 | val sequenceExampleTestRows: Array[Row] = Array( 45 | new GenericRow(Array[Any](23L, Seq(Seq(2.0F, 4.5F)), Seq(Seq("r1", "r2")))), 46 | new GenericRow(Array[Any](24L, Seq(Seq(-1.0F, 0F)), Seq(Seq("r3"))))) 47 | 48 | val sequenceExampleSchema = StructType(List( 49 | StructField("id",LongType), 50 | StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))), 51 | StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType))) 52 | )) 53 | 54 | private def createDataFrameForExampleTFRecord() : DataFrame = { 55 | val rdd = spark.sparkContext.parallelize(exampleTestRows) 56 | spark.createDataFrame(rdd, exampleSchema) 57 | } 58 | 59 | private def createDataFrameForSequenceExampleTFRecords() : DataFrame = { 60 | val rdd = spark.sparkContext.parallelize(sequenceExampleTestRows) 61 | spark.createDataFrame(rdd, sequenceExampleSchema) 62 | } 63 | 64 | 65 | "Spark TensorFlow module" should { 66 | 67 | "Test Import/Export of Example records" in { 68 | val path = s"$TF_SANDBOX_DIR/example.tfrecord" 69 | 70 | val df: DataFrame = createDataFrameForExampleTFRecord() 71 | df.write.format("tfrecords").option("recordType", "Example").save(path) 72 | 73 | //If schema is not provided. 
It will automatically infer schema 74 | val importedDf: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").schema(exampleSchema).load(path) 75 | val actualDf = importedDf.select("id", "IntegerLabel", "LongLabel", "FloatLabel", 76 | "DoubleLabel", "DoubleArrayLabel", "StrLabel", "BinaryLabel").sort("StrLabel") 77 | 78 | val expectedRows = df.collect() 79 | val actualRows = actualDf.collect() 80 | 81 | expectedRows.zip(actualRows).foreach { case (expected: Row, actual: Row) => 82 | assert(expected ~== actual, exampleSchema) 83 | } 84 | } 85 | 86 | "Test Import/Export of SequenceExample records" in { 87 | 88 | val path = s"$TF_SANDBOX_DIR/sequenceExample.tfrecord" 89 | 90 | val df: DataFrame = createDataFrameForSequenceExampleTFRecords() 91 | df.write.format("tfrecords").option("recordType", "SequenceExample").save(path) 92 | 93 | val importedDf: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").schema(sequenceExampleSchema).load(path) 94 | val actualDf = importedDf.select("id", "FloatArrayOfArrayLabel", "StrArrayOfArrayLabel").sort("id") 95 | 96 | val expectedRows = df.collect() 97 | val actualRows = actualDf.collect() 98 | 99 | assert(expectedRows === actualRows) 100 | } 101 | 102 | "Test overwrite mode during export of Example records" in { 103 | 104 | val path = s"$TF_SANDBOX_DIR/example_overwrite.tfrecord" 105 | 106 | val df: DataFrame = createDataFrameForExampleTFRecord() 107 | df.write.format("tfrecords").option("recordType", "Example").save(path) 108 | 109 | df.write.format("tfrecords").mode(SaveMode.Overwrite).option("recordType", "Example").save(path) 110 | 111 | //If schema is not provided. It will automatically infer schema 112 | val importedDf: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").schema(exampleSchema).load(path) 113 | val actualDf = importedDf.select("id", "IntegerLabel", "LongLabel", "FloatLabel", 114 | "DoubleLabel", "DoubleArrayLabel", "StrLabel", "BinaryLabel").sort("StrLabel") 115 | 116 | val expectedRows = df.collect() 117 | val actualRows = actualDf.collect() 118 | 119 | expectedRows.zip(actualRows).foreach { case (expected: Row, actual: Row) => 120 | assert(expected ~== actual, exampleSchema) 121 | } 122 | 123 | } 124 | 125 | "Test append mode during export of Example records" in { 126 | 127 | val path = s"$TF_SANDBOX_DIR/example_append.tfrecord" 128 | 129 | val df: DataFrame = createDataFrameForExampleTFRecord() 130 | df.write.format("tfrecords").option("recordType", "Example").save(path) 131 | 132 | intercept [IllegalArgumentException] { 133 | df.write.format("tfrecords").mode(SaveMode.Append).option("recordType", "Example").save(path) 134 | } 135 | } 136 | 137 | "Test errorIfExists mode during export of Example records" in { 138 | 139 | val path = s"$TF_SANDBOX_DIR/example_errorIfExists.tfrecord" 140 | 141 | val df: DataFrame = createDataFrameForExampleTFRecord() 142 | df.write.format("tfrecords").mode(SaveMode.ErrorIfExists).option("recordType", "Example").save(path) 143 | 144 | intercept [IllegalStateException] { 145 | df.write.format("tfrecords").mode(SaveMode.ErrorIfExists).option("recordType", "Example").save(path) 146 | } 147 | } 148 | 149 | "Test ignore mode during export of Example records" in { 150 | 151 | val path = s"$TF_SANDBOX_DIR/example_ignore.tfrecord" 152 | 153 | val hadoopConf = spark.sparkContext.hadoopConfiguration 154 | val outputPath = new Path(path) 155 | val fs = outputPath.getFileSystem(hadoopConf) 156 | val qualifiedOutputPath = 
outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) 157 | 158 | val df: DataFrame = createDataFrameForExampleTFRecord() 159 | df.write.format("tfrecords").mode(SaveMode.Ignore).option("recordType", "Example").save(path) 160 | 161 | assert(fs.exists(qualifiedOutputPath)) 162 | val timestamp1 = fs.getFileStatus(qualifiedOutputPath).getModificationTime 163 | 164 | df.write.format("tfrecords").mode(SaveMode.Ignore).option("recordType", "Example").save(path) 165 | 166 | val timestamp2 = fs.getFileStatus(qualifiedOutputPath).getModificationTime 167 | 168 | assert(timestamp1 == timestamp2, "SaveMode.Ignore Error: File was overwritten. Timestamps do not match") 169 | } 170 | 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/TestingUtils.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords 17 | 18 | import org.apache.spark.sql.Row 19 | import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema 20 | import org.apache.spark.sql.catalyst.util.ArrayData 21 | import org.apache.spark.sql.types._ 22 | import org.scalatest.Matchers 23 | 24 | object TestingUtils extends Matchers { 25 | 26 | /** 27 | * Implicit class for comparing two double values using absolute tolerance. 28 | */ 29 | implicit class FloatArrayWithAlmostEquals(val left: Seq[Float]) { 30 | 31 | /** 32 | * When the difference of two values are within eps, returns true; otherwise, returns false. 33 | */ 34 | def ~==(right: Seq[Float], epsilon : Float = 1E-6F): Boolean = { 35 | if (left.size === right.size) { 36 | (left zip right) forall { case (a, b) => a === (b +- epsilon) } 37 | } 38 | else false 39 | } 40 | } 41 | 42 | /** 43 | * Implicit class for comparing two double values using absolute tolerance. 44 | */ 45 | implicit class DoubleArrayWithAlmostEquals(val left: Seq[Double]) { 46 | 47 | /** 48 | * When the difference of two values are within eps, returns true; otherwise, returns false. 49 | */ 50 | def ~==(right: Seq[Double], epsilon : Double = 1E-6): Boolean = { 51 | if (left.size === right.size) { 52 | (left zip right) forall { case (a, b) => a === (b +- epsilon) } 53 | } 54 | else false 55 | } 56 | } 57 | 58 | /** 59 | * Implicit class for comparing two decimal values using absolute tolerance. 60 | */ 61 | implicit class DecimalArrayWithAlmostEquals(val left: Seq[Decimal]) { 62 | 63 | /** 64 | * When the difference of two values are within eps, returns true; otherwise, returns false. 
65 | */ 66 | def ~==(right: Seq[Decimal], epsilon : Double = 1E-6): Boolean = { 67 | if (left.size === right.size) { 68 | (left zip right) forall { case (a, b) => a.toDouble === (b.toDouble +- epsilon) } 69 | } 70 | else false 71 | } 72 | } 73 | 74 | /** 75 | * Implicit class for comparing two double values using absolute tolerance. 76 | */ 77 | implicit class FloatMatrixWithAlmostEquals(val left: Seq[Seq[Float]]) { 78 | 79 | /** 80 | * When the difference of two values are within eps, returns true; otherwise, returns false. 81 | */ 82 | def ~==(right: Seq[Seq[Float]], epsilon : Float = 1E-6F): Boolean = { 83 | if (left.size === right.size) { 84 | (left zip right) forall { case (a, b) => a ~== (b, epsilon) } 85 | } 86 | else false 87 | } 88 | } 89 | 90 | /** 91 | * Implicit class for comparing two double values using absolute tolerance. 92 | */ 93 | implicit class DoubleMatrixWithAlmostEquals(val left: Seq[Seq[Double]]) { 94 | 95 | /** 96 | * When the difference of two values are within eps, returns true; otherwise, returns false. 97 | */ 98 | def ~==(right: Seq[Seq[Double]], epsilon : Double = 1E-6): Boolean = { 99 | if (left.size === right.size) { 100 | (left zip right) forall { case (a, b) => a ~== (b, epsilon) } 101 | } 102 | else false 103 | } 104 | } 105 | 106 | /** 107 | * Implicit class for comparing two decimal values using absolute tolerance. 108 | */ 109 | implicit class DecimalMatrixWithAlmostEquals(val left: Seq[Seq[Decimal]]) { 110 | 111 | /** 112 | * When the difference of two values are within eps, returns true; otherwise, returns false. 113 | */ 114 | def ~==(right: Seq[Seq[Decimal]], epsilon : Double = 1E-6): Boolean = { 115 | if (left.size === right.size) { 116 | (left zip right) forall { case (a, b) => a ~== (b, epsilon) } 117 | } 118 | else false 119 | } 120 | } 121 | 122 | /** 123 | * Implicit class for comparing two rows using absolute tolerance. 124 | */ 125 | implicit class RowWithAlmostEquals(val left: Row) { 126 | 127 | /** 128 | * When all fields in row with given schema are equal or are within eps, returns true; otherwise, returns false. 129 | */ 130 | def ~==(right: Row, schema: StructType): Boolean = { 131 | if (schema != null && schema.fields.size == left.size && schema.fields.size == right.size) { 132 | val leftRowWithSchema = new GenericRowWithSchema(left.toSeq.toArray, schema) 133 | val rightRowWithSchema = new GenericRowWithSchema(right.toSeq.toArray, schema) 134 | leftRowWithSchema ~== rightRowWithSchema 135 | } 136 | else false 137 | } 138 | 139 | /** 140 | * When all fields in row are equal or are within eps, returns true; otherwise, returns false. 
141 | */ 142 | def ~==(right: Row, epsilon : Float = 1E-6F): Boolean = { 143 | if (left.size === right.size) { 144 | val leftDataTypes = left.schema.fields.map(_.dataType) 145 | val rightDataTypes = right.schema.fields.map(_.dataType) 146 | 147 | (leftDataTypes zip rightDataTypes).zipWithIndex.forall { 148 | case (x, i) if left.get(i) == null || right.get(i) == null => 149 | left.get(i) == null && right.get(i) == null 150 | 151 | case ((FloatType, FloatType), i) => 152 | left.getFloat(i) === (right.getFloat(i) +- epsilon) 153 | 154 | case ((DoubleType, DoubleType), i) => 155 | left.getDouble(i) === (right.getDouble(i) +- epsilon) 156 | 157 | case ((BinaryType, BinaryType), i) => 158 | left.getAs[Array[Byte]](i).toSeq === right.getAs[Array[Byte]](i).toSeq 159 | 160 | case ((ArrayType(FloatType,_), ArrayType(FloatType,_)), i) => 161 | val leftArray = ArrayData.toArrayData(left.get(i)).toFloatArray().toSeq 162 | val rightArray = ArrayData.toArrayData(right.get(i)).toFloatArray().toSeq 163 | leftArray ~== (rightArray, epsilon) 164 | 165 | case ((ArrayType(DoubleType,_), ArrayType(DoubleType,_)), i) => 166 | val leftArray = ArrayData.toArrayData(left.get(i)).toDoubleArray().toSeq 167 | val rightArray = ArrayData.toArrayData(right.get(i)).toDoubleArray().toSeq 168 | leftArray ~== (rightArray, epsilon) 169 | 170 | case ((ArrayType(BinaryType,_), ArrayType(BinaryType,_)), i) => 171 | val leftArray = ArrayData.toArrayData(left.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq 172 | val rightArray = ArrayData.toArrayData(right.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq 173 | leftArray === rightArray 174 | 175 | case ((ArrayType(ArrayType(FloatType,_),_), ArrayType(ArrayType(FloatType,_),_)), i) => 176 | val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => 177 | ArrayData.toArrayData(arr).toFloatArray().toSeq 178 | } 179 | val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => 180 | ArrayData.toArrayData(arr).toFloatArray().toSeq 181 | } 182 | leftArrays ~== (rightArrays, epsilon) 183 | 184 | case ((ArrayType(ArrayType(DoubleType,_),_), ArrayType(ArrayType(DoubleType,_),_)), i) => 185 | val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => 186 | ArrayData.toArrayData(arr).toDoubleArray().toSeq 187 | } 188 | val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => 189 | ArrayData.toArrayData(arr).toDoubleArray().toSeq 190 | } 191 | leftArrays ~== (rightArrays, epsilon) 192 | 193 | case ((ArrayType(ArrayType(BinaryType,_),_), ArrayType(ArrayType(BinaryType,_),_)), i) => 194 | val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => 195 | ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq 196 | } 197 | val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => 198 | ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq 199 | } 200 | leftArrays === rightArrays 201 | 202 | case((a,b), i) => left.get(i) === right.get(i) 203 | } 204 | } 205 | else false 206 | } 207 | } 208 | } -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureEncoderTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords.serde 17 | 18 | import org.scalatest.{Matchers, WordSpec} 19 | import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ 20 | import scala.collection.JavaConverters._ 21 | 22 | class FeatureEncoderTest extends WordSpec with Matchers { 23 | 24 | "Int64List feature encoder" should { 25 | "Encode inputs to Int64List" in { 26 | val longFeature = Int64ListFeatureEncoder.encode(Seq(10L)) 27 | val longListFeature = Int64ListFeatureEncoder.encode(Seq(3L,5L,6L)) 28 | 29 | assert(longFeature.getInt64List.getValueList.asScala.toSeq === Seq(10L)) 30 | assert(longListFeature.getInt64List.getValueList.asScala.toSeq === Seq(3L, 5L, 6L)) 31 | } 32 | 33 | "Encode empty list to empty feature" in { 34 | val longListFeature = Int64ListFeatureEncoder.encode(Seq.empty[Long]) 35 | assert(longListFeature.getInt64List.getValueList.size() === 0) 36 | } 37 | } 38 | 39 | "FloatList feature encoder" should { 40 | "Encode inputs to FloatList" in { 41 | val floatFeature = FloatListFeatureEncoder.encode(Seq(2.5F)) 42 | val floatListFeature = FloatListFeatureEncoder.encode(Seq(1.5F,6.8F,-3.2F)) 43 | 44 | assert(floatFeature.getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== Seq(2.5F)) 45 | assert(floatListFeature.getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== Seq(1.5F,6.8F,-3.2F)) 46 | } 47 | 48 | "Encode empty list to empty feature" in { 49 | val floatListFeature = FloatListFeatureEncoder.encode(Seq.empty[Float]) 50 | assert(floatListFeature.getFloatList.getValueList.size() === 0) 51 | } 52 | } 53 | 54 | "ByteList feature encoder" should { 55 | "Encode inputs to ByteList" in { 56 | val binFeature = BytesListFeatureEncoder.encode(Seq(Array(0xff.toByte, 0xd8.toByte))) 57 | val binListFeature = BytesListFeatureEncoder.encode(Seq(Array(0xff.toByte, 0xd8.toByte), Array(0xff.toByte, 0xd9.toByte))) 58 | 59 | assert(binFeature.getBytesList.getValueList.asScala.toSeq.map(_.toByteArray.deep) === Seq(Array(0xff.toByte, 0xd8.toByte).deep)) 60 | assert(binListFeature.getBytesList.getValueList.asScala.map(_.toByteArray.deep) === Seq(Array(0xff.toByte, 0xd8.toByte).deep, Array(0xff.toByte, 0xd9.toByte).deep)) 61 | } 62 | 63 | "Encode empty list to empty feature" in { 64 | val binListFeature = BytesListFeatureEncoder.encode(Seq.empty[Array[Byte]]) 65 | assert(binListFeature.getBytesList.getValueList.size() === 0) 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/serde/FeatureListEncoderTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | *       http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorflow.spark.datasources.tfrecords.serde 17 | 18 | import org.scalatest.{Matchers, WordSpec} 19 | import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ 20 | 21 | import scala.collection.JavaConverters._ 22 | 23 | class FeatureListEncoderTest extends WordSpec with Matchers { 24 | 25 | "Int64 feature list encoder" should { 26 | 27 | "Encode inputs to feature list of Int64" in { 28 | val longListOfLists = Seq(Seq(3L,5L,Int.MaxValue+6L), Seq(-1L,-6L)) 29 | val longFeatureList = Int64FeatureListEncoder.encode(longListOfLists) 30 | 31 | longFeatureList.getFeatureList.asScala.map(_.getInt64List.getValueList.asScala.toSeq) should equal (longListOfLists) 32 | } 33 | 34 | "Encode empty array to empty feature list" in { 35 | val longFeatureList = Int64FeatureListEncoder.encode(Seq.empty[Seq[Long]]) 36 | assert(longFeatureList.getFeatureList.size() === 0) 37 | } 38 | } 39 | 40 | "Float feature list encoder" should { 41 | 42 | "Encode inputs to feature list of Float" in { 43 | val floatListOfLists = Seq(Seq(-2.67F, 1.5F, 0F), Seq(-1.4F,-6F)) 44 | val floatFeatureList = FloatFeatureListEncoder.encode(floatListOfLists) 45 | 46 | assert(floatFeatureList.getFeatureList.asScala.map(_.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists) 47 | } 48 | 49 | "Encode empty array to empty feature list" in { 50 | val floatFeatureList = FloatFeatureListEncoder.encode(Seq.empty[Seq[Float]]) 51 | assert(floatFeatureList.getFeatureList.size() === 0) 52 | } 53 | } 54 | 55 | "Bytes feature list encoder" should { 56 | 57 | "Encode inputs to feature list of bytes" in { 58 | val bytesListOfLists = Seq(Seq("alice".getBytes, "bob".getBytes), Seq("charles".getBytes)) 59 | val bytesFeatureList = BytesFeatureListEncoder.encode(bytesListOfLists) 60 | 61 | assert(bytesFeatureList.getFeatureList.asScala.map(_.getBytesList.getValueList.asScala.toSeq.map(_.toByteArray.deep)) === bytesListOfLists.map(_.map(_.deep))) 62 | } 63 | 64 | "Encode empty array to empty feature list" in { 65 | val bytesFeatureList = BytesFeatureListEncoder.encode(Seq.empty[Seq[Array[Byte]]]) 66 | assert(bytesFeatureList.getFeatureList.size() === 0) 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-connector/src/test/scala/org/tensorflow/spark/datasources/tfrecords/udf/DataFrameTfrConverterTest.scala: -------------------------------------------------------------------------------- 1 | package org.tensorflow.spark.datasources.tfrecords.udf 2 | 3 | import org.apache.spark.ml.linalg.SQLDataTypes.VectorType 4 | import org.apache.spark.ml.linalg.Vectors 5 | import org.apache.spark.sql.{Row, SparkSession} 6 | import org.apache.spark.sql.functions.{col, struct} 7 | import org.apache.spark.sql.types._ 8 | import org.tensorflow.example.{Example, Feature, SequenceExample} 9 | import org.tensorflow.spark.datasources.tfrecords.SharedSparkSessionSuite 10 | import org.tensorflow.spark.datasources.tfrecords.TestingUtils._ 11 | 12 | import scala.collection.JavaConverters._ 13 | 14 | class 
DataFrameTfrConverterTest extends SharedSparkSessionSuite { 15 | 16 | "DataFrame to tfr" should { 17 | "Encode given Row as TensorFlow Example" in { 18 | 19 | val schema = StructType(Array( 20 | StructField("IntegerLabel", IntegerType), 21 | StructField("LongLabel", LongType), 22 | StructField("FloatLabel", FloatType), 23 | StructField("DoubleLabel", DoubleType), 24 | StructField("DoubleArrayLabel", ArrayType(DoubleType)), 25 | StructField("StrLabel", StringType), 26 | StructField("StrArrayLabel", ArrayType(StringType)), 27 | StructField("DenseVectorLabel", VectorType), 28 | StructField("SparseVectorLabel", VectorType), 29 | StructField("BinaryLabel", BinaryType), 30 | StructField("BinaryArrayLabel", ArrayType(BinaryType)) 31 | )) 32 | val doubleArray = Array(1.1, 111.1, 11111.1) 33 | val sparseVector = Vectors.sparse(3, Seq((1, 2.0), (2, 1.5))) 34 | val denseVector = Vectors.dense(Array(5.6, 7.0)) 35 | val byteArray = Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte) 36 | val byteArray1 = Array[Byte](-128, 23, 127) 37 | 38 | val data = 39 | Row(1, 23L, 10.0F, 14.0, doubleArray, 40 | "r1", Seq("r2", "r3"), denseVector, sparseVector, 41 | byteArray, Seq(byteArray, byteArray1)) :: 42 | Nil 43 | 44 | val input = spark 45 | .createDataFrame(spark.sparkContext.makeRDD(data), schema) 46 | 47 | val examples = input 48 | .select(DataFrameTfrConverter.getRowToTFRecordExampleUdf(struct(input.columns.map(col): _*)).as("tfr")) 49 | .collect() 50 | .map {case Row(tfr: Array[Byte]) => 51 | Example.parseFrom(tfr) 52 | } 53 | .toList 54 | 55 | val featureMap = examples.head.getFeatures.getFeatureMap.asScala 56 | 57 | assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 1) 58 | assert(featureMap("LongLabel").getInt64List.getValue(0).toInt == 23) 59 | assert(featureMap("FloatLabel").getFloatList.getValue(0) == 10.0F) 60 | assert(featureMap("DoubleLabel").getFloatList.getValue(0) == 14.0F) 61 | assert(featureMap("DoubleArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== doubleArray.map(_.toFloat)) 62 | assert(featureMap("StrLabel").getBytesList.getValue(0).toStringUtf8 == "r1") 63 | assert(featureMap("StrArrayLabel").getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("r2", "r3")) 64 | assert(featureMap("DenseVectorLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== denseVector.toArray.map(_.toFloat)) 65 | assert(featureMap("SparseVectorLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== sparseVector.toDense.toArray.map(_.toFloat)) 66 | assert(featureMap("BinaryLabel").getBytesList.getValue(0).toByteArray.deep == byteArray.deep) 67 | val binaryArrayValue = featureMap("BinaryArrayLabel").getBytesList.getValueList.asScala.map(byteArray => byteArray.asScala.toArray.map(_.toByte)) 68 | assert(binaryArrayValue.toArray.deep == Array(byteArray, byteArray1).deep) 69 | } 70 | 71 | "Encode given Row as TensorFlow SequenceExample" in { 72 | 73 | val schemaStructType = StructType(Array( 74 | StructField("IntegerLabel", IntegerType), 75 | StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), 76 | StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))), 77 | StructField("DoubleArrayOfArrayLabel", ArrayType(ArrayType(DoubleType))), 78 | StructField("StringArrayOfArrayLabel", ArrayType(ArrayType(StringType))), 79 | StructField("BinaryArrayOfArrayLabel", ArrayType(ArrayType(BinaryType))) 80 | )) 81 | 82 | val longListOfLists = Seq(Seq(3L, 5L), Seq(-8L, 0L)) 83 | val floatListOfLists = Seq(Seq(1.5F, -6.5F), 
Seq(-8.2F, 0F)) 84 | val doubleListOfLists = Seq(Seq(3.0), Seq(6.0, 9.0)) 85 | val stringListOfLists = Seq(Seq("r1"), Seq("r2", "r3"), Seq("r4")) 86 | val binaryListOfLists = stringListOfLists.map(stringList => stringList.map(_.getBytes)) 87 | 88 | val data = 89 | Row(10, longListOfLists, 90 | floatListOfLists, doubleListOfLists, 91 | stringListOfLists, binaryListOfLists) :: 92 | Nil 93 | 94 | val input = spark 95 | .createDataFrame(spark.sparkContext.makeRDD(data), schemaStructType) 96 | 97 | val examples = input 98 | .select(DataFrameTfrConverter.getRowToTFRecordSequenceExampleUdf(struct(input.columns.map(col): _*)).as("tfr")) 99 | .collect() 100 | .map {case Row(tfr: Array[Byte]) => 101 | SequenceExample.parseFrom(tfr) 102 | } 103 | .toList 104 | 105 | val featureMap = examples.head.getContext.getFeatureMap.asScala 106 | val featureListMap = examples.head.getFeatureLists.getFeatureListMap.asScala 107 | 108 | assert(featureMap.size == 1) 109 | assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) 110 | assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 10) 111 | 112 | assert(featureListMap.size == 5) 113 | assert(featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala.map( 114 | _.getInt64List.getValueList.asScala.toSeq) === longListOfLists) 115 | assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.map( 116 | _.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists) 117 | assert(featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.map( 118 | _.getFloatList.getValueList.asScala.map(_.toDouble).toSeq) ~== doubleListOfLists) 119 | assert(featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map( 120 | _.getBytesList.getValueList.asScala.map(_.toStringUtf8).toSeq) === stringListOfLists) 121 | assert(featureListMap("BinaryArrayOfArrayLabel").getFeatureList.asScala.map( 122 | _.getBytesList.getValueList.asScala.map(byteList => byteList.asScala.toSeq)) === binaryListOfLists.map(_.map(_.toSeq))) 123 | } 124 | } 125 | } -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | # IDE 109 | .idea/ 110 | 111 | # Generated files 112 | tests/integration/spark_conf/spark-defaults.conf 113 | tests/integration/spark_conf/spark-custom.conf 114 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | # YAPF uses the google style 3 | based_on_style = google 4 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | ARG PYTHON_INSTALL_VERSION 4 | 5 | RUN apt-get update && \ 6 | apt-get install -y --fix-missing unzip bzip2 build-essential openjdk-8-jdk sudo wget 7 | 8 | RUN apt-get update && \ 9 | apt install -y software-properties-common && \ 10 | apt-get install -y python${PYTHON_INSTALL_VERSION} python${PYTHON_INSTALL_VERSION}-dev python${PYTHON_INSTALL_VERSION}-distutils && \ 11 | wget https://bootstrap.pypa.io/get-pip.py && \ 12 | python$PYTHON_INSTALL_VERSION get-pip.py && \ 13 | rm get-pip.py && \ 14 | update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python${PYTHON_INSTALL_VERSION} 1 && \ 15 | update-alternatives --install /usr/local/bin/python python /usr/bin/python${PYTHON_INSTALL_VERSION} 1 16 | 17 | # Install Spark and update env variables 18 | # Modify this when the next Spark 3 release includes the allGather barrier mode API 19 | RUN pip install \ 20 | https://ml-team-public-read.s3-us-west-2.amazonaws.com/pyspark-3.1.0.dev0-60dd1a690fed62b1d6442cdc8cf3f89ef4304d5a.tar.gz \ 21 | --force-reinstall --upgrade 22 | 23 | # Set SPARK_HOME so that tests can easily find the scripts in $SPARK_HOME/sbin 24 | ENV SPARK_HOME /usr/local/lib/python${PYTHON_INSTALL_VERSION}/dist-packages/pyspark 25 | 26 | COPY ./requirements.txt /opt/ 27 | RUN pip install -r /opt/requirements.txt && \ 28 | rm /opt/requirements.txt 29 | 30 | RUN sudo apt-get update && \ 31 | sudo apt-get install -y \ 32 | apt-transport-https \ 33 | ca-certificates \ 34 | curl \ 35 | gnupg-agent \ 36 | software-properties-common && \ 37 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - && \ 38 | sudo add-apt-repository \ 39 | "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ 40 
| $(lsb_release -cs) \ 41 | stable" && \ 42 | sudo apt-get update && \ 43 | sudo apt-get install -y docker-ce docker-ce-cli containerd.io 44 | 45 | # Install docker-compose 46 | RUN apt-get install -y curl && \ 47 | curl -L https://github.com/docker/compose/releases/download/1.22.0/docker-compose-`uname -s`-`uname -m` > docker-compose && \ 48 | chmod +x docker-compose && \ 49 | sudo mv docker-compose /usr/local/bin 50 | 51 | # The spark-tensorflow-distributor dir will be mounted here. 52 | VOLUME /mnt/spark-tensorflow-distributor 53 | WORKDIR /mnt/spark-tensorflow-distributor 54 | 55 | ENV PYTHONPATH=/mnt/spark-tensorflow-distributor 56 | 57 | ENTRYPOINT /bin/bash 58 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/README.md: -------------------------------------------------------------------------------- 1 | # Spark TensorFlow Distributor 2 | 3 | This package helps users do distributed training with TensorFlow on their Spark clusters. 4 | 5 | ## Installation 6 | 7 | This package requires Python 3.6+, `tensorflow>=2.1.0` and `pyspark>=3.0.0` to run. 8 | To install `spark-tensorflow-distributor`, run: 9 | 10 | ```bash 11 | pip install spark-tensorflow-distributor 12 | ``` 13 | 14 | The installation does not install PySpark because for most users, PySpark is already installed. 15 | If you do not have PySpark installed, you can install it directly (quoted so the shell does not treat `>` as a redirect): 16 | 17 | ```bash 18 | pip install 'pyspark>=3.0.0' 19 | ``` 20 | 21 | Note that to use many features of this package, you must set up Spark custom 22 | resource scheduling for GPUs on your cluster. See the Spark documentation for details. 23 | 24 | ## Running Tests 25 | 26 | For integration tests, first build the master and worker images and then run the test script. 27 | 28 | ```bash 29 | docker-compose build --build-arg PYTHON_INSTALL_VERSION=3.7 30 | ./tests/integration/run.sh 31 | ``` 32 | 33 | For linting, run the following. 34 | 35 | ```bash 36 | ./tests/lint.sh 37 | ``` 38 | 39 | To use the autoformatter, run the following.
40 | 41 | ```bash 42 | yapf --recursive --in-place spark_tensorflow_distributor 43 | ``` 44 | 45 | ## Examples 46 | 47 | Run following example code in `pyspark` shell: 48 | 49 | ```python 50 | from spark_tensorflow_distributor import MirroredStrategyRunner 51 | 52 | # Adapted from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras 53 | def train(): 54 | import tensorflow as tf 55 | import uuid 56 | 57 | BUFFER_SIZE = 10000 58 | BATCH_SIZE = 64 59 | 60 | def make_datasets(): 61 | (mnist_images, mnist_labels), _ = \ 62 | tf.keras.datasets.mnist.load_data(path=str(uuid.uuid4())+'mnist.npz') 63 | 64 | dataset = tf.data.Dataset.from_tensor_slices(( 65 | tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32), 66 | tf.cast(mnist_labels, tf.int64)) 67 | ) 68 | dataset = dataset.repeat().shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 69 | return dataset 70 | 71 | def build_and_compile_cnn_model(): 72 | model = tf.keras.Sequential([ 73 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 74 | tf.keras.layers.MaxPooling2D(), 75 | tf.keras.layers.Flatten(), 76 | tf.keras.layers.Dense(64, activation='relu'), 77 | tf.keras.layers.Dense(10, activation='softmax'), 78 | ]) 79 | model.compile( 80 | loss=tf.keras.losses.sparse_categorical_crossentropy, 81 | optimizer=tf.keras.optimizers.SGD(learning_rate=0.001), 82 | metrics=['accuracy'], 83 | ) 84 | return model 85 | 86 | train_datasets = make_datasets() 87 | options = tf.data.Options() 88 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA 89 | train_datasets = train_datasets.with_options(options) 90 | multi_worker_model = build_and_compile_cnn_model() 91 | multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5) 92 | 93 | MirroredStrategyRunner(num_slots=8).run(train) 94 | ``` 95 | 96 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | master: 4 | image: sarthfrey/spark-tensorflow-distributor:3.7_master 5 | build: . 6 | tty: true 7 | hostname: master 8 | environment: 9 | - MASTER=spark://master:7077 10 | ports: 11 | - "4040:4040" # driver UI 12 | - "8080:8080" # master UI 13 | volumes: 14 | - .:/mnt/spark-tensorflow-distributor 15 | - /var/run/docker.sock:/var/run/docker.sock 16 | worker: 17 | image: sarthfrey/spark-tensorflow-distributor:3.7_worker 18 | build: . 
19 | tty: true 20 | ports: 21 | - "8081-8090:8081" # worker UI 22 | links: 23 | - master 24 | volumes: 25 | - .:/mnt/spark-tensorflow-distributor 26 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/examples/simple/example.py: -------------------------------------------------------------------------------- 1 | from spark_tensorflow_distributor import MirroredStrategyRunner 2 | 3 | # Adapted from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras 4 | def train(): 5 | import tensorflow as tf 6 | import uuid 7 | 8 | BUFFER_SIZE = 10000 9 | BATCH_SIZE = 64 10 | 11 | def make_datasets(): 12 | (mnist_images, mnist_labels), _ = \ 13 | tf.keras.datasets.mnist.load_data(path=str(uuid.uuid4())+'mnist.npz') 14 | 15 | dataset = tf.data.Dataset.from_tensor_slices(( 16 | tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32), 17 | tf.cast(mnist_labels, tf.int64)) 18 | ) 19 | dataset = dataset.repeat().shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 20 | return dataset 21 | 22 | def build_and_compile_cnn_model(): 23 | model = tf.keras.Sequential([ 24 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 25 | tf.keras.layers.MaxPooling2D(), 26 | tf.keras.layers.Flatten(), 27 | tf.keras.layers.Dense(64, activation='relu'), 28 | tf.keras.layers.Dense(10, activation='softmax'), 29 | ]) 30 | model.compile( 31 | loss=tf.keras.losses.sparse_categorical_crossentropy, 32 | optimizer=tf.keras.optimizers.SGD(learning_rate=0.001), 33 | metrics=['accuracy'], 34 | ) 35 | return model 36 | 37 | train_datasets = make_datasets() 38 | options = tf.data.Options() 39 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA 40 | train_datasets = train_datasets.with_options(options) 41 | multi_worker_model = build_and_compile_cnn_model() 42 | multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5) 43 | 44 | MirroredStrategyRunner(num_slots=8).run(train) 45 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/requirements.txt: -------------------------------------------------------------------------------- 1 | # Library dependencies for the python code. You need to install these with 2 | # `pip install -r requirements.txt` before you can run this. 3 | 4 | #### ESSENTIAL RUNTIME DEPENDENCIES AND TEST DEPENDENCIES FOR MAIN FUNCTIONALITY #### 5 | 6 | ## Tensorflow Requirements ## 7 | # Tensorflow is required to run this code but depends on specific configurations. Install from: 8 | # https://www.tensorflow.org/get_started/os_setup#overview 9 | 10 | # NOTE: If you want to use the GPU version, you will also need Nvidia's CUDA toolkit and cuDNN: 11 | # https://developer.nvidia.com/cuda-downloads 12 | # https://developer.nvidia.com/cudnn 13 | # Note that if you want to use the GPU version, you have to `pip uninstall tensorflow` 14 | # and `pip install tensorflow-gpu`, since both cannot coexist. 15 | tensorflow>=2.1.0 16 | 17 | # We'll use pytest to run our tests; this isn't really necessary to run the code, but it is to run 18 | # the tests. With this here, you can run the tests with `py.test` from the base directory. 
19 | pytest 20 | 21 | # This is used for autoformatting and testing formatting 22 | yapf 23 | 24 | # Linting 25 | pylint 26 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | with open("README.md", "r") as fh: 5 | long_description = fh.read() 6 | 7 | setuptools.setup( 8 | name="spark_tensorflow_distributor", 9 | version="0.0.3", 10 | author="sarthfrey", 11 | author_email="sarth.frey@gmail.com", 12 | description="This package helps users do distributed training with TensorFlow on their Spark clusters.", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-distributor", 16 | packages=setuptools.find_packages(), 17 | classifiers=[ 18 | "Development Status :: 1 - Planning", 19 | "Intended Audience :: Developers", 20 | "Operating System :: OS Independent", 21 | "Programming Language :: Python :: 3", 22 | "Natural Language :: English", 23 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 24 | "Topic :: Software Development :: Libraries :: Python Modules", 25 | "Topic :: Software Development :: Version Control :: Git", 26 | ], 27 | install_requires=[ 28 | "tensorflow>=2.1.0", 29 | ], 30 | python_requires='>=3.6', 31 | ) 32 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/spark_tensorflow_distributor/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | from .mirrored_strategy_runner import MirroredStrategyRunner 18 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import logging 3 | import subprocess 4 | 5 | 6 | from pyspark.sql import SparkSession 7 | 8 | 9 | def pytest_addoption(parser): 10 | parser.addoption('--max_num_workers', action='store') 11 | 12 | def restart_spark(num_workers, num_gpus_per_worker, max_num_workers): 13 | subprocess.run( 14 | [ 15 | 'cat /dev/null | tests/integration/restart_spark.sh --num_workers {} ' 16 | '--num_gpus_per_worker {} ' 17 | '--max_num_workers {}'.format(num_workers, num_gpus_per_worker, max_num_workers) 18 | ], 19 | shell=True, 20 | check=True, 21 | ) 22 | 23 | @pytest.fixture(scope='session') 24 | def max_num_workers(request): 25 | return int(request.config.getoption("--max_num_workers")) 26 | 27 | @pytest.fixture(scope='session') 28 | def extra_spark_configs(request, autouse=True): 29 | if hasattr(request, 'param'): 30 | conf = request.param 31 | else: 32 | conf = {} 33 | with open('tests/integration/spark_conf/spark-custom.conf', 'w') as f: 34 | f.writelines( 35 | ['{} {}\n'.format(k, v) for k, v in conf.items()] 36 | ) 37 | return conf 38 | 39 | @pytest.fixture(scope='session', autouse=True) 40 | def num_workers(request, max_num_workers): 41 | if not hasattr(request, 'param'): 42 | raise Exception( 43 | 'num_workers is a required fixture for Spark ' 44 | 'TensorFlow Distributor tests, but test `{}` does not ' 45 | 'use it.'.format(request.node.name) 46 | ) 47 | num_workers_value = request.param 48 | if num_workers_value > max_num_workers: 49 | raise Exception( 50 | 'num_workers cannot be greater than {max_num_workers} but ' 51 | 'test `{test_name}` requested num_workers = `{num_workers}`.' 
52 | .format( 53 | max_num_workers=max_num_workers, 54 | test_name=request.node.name, 55 | num_workers=num_workers_value, 56 | ) 57 | ) 58 | return num_workers_value 59 | 60 | @pytest.fixture(scope='session', autouse=True) 61 | def num_gpus_per_worker(request): 62 | if not hasattr(request, 'param'): 63 | raise Exception( 64 | 'num_gpus_per_worker is a required fixture for Spark ' 65 | 'TensorFlow Distributor tests, but test `{}` does not ' 66 | 'use it.'.format(request.node.name) 67 | ) 68 | num_gpus_per_worker_value = request.param 69 | return num_gpus_per_worker_value 70 | 71 | @pytest.fixture(scope='session', autouse=True) 72 | def spark_session(num_workers, num_gpus_per_worker, max_num_workers, extra_spark_configs): 73 | restart_spark(num_workers, num_gpus_per_worker, max_num_workers) 74 | builder = SparkSession.builder.appName('Spark TensorFlow Distributor Tests') 75 | logging.getLogger().info('Creating spark session with the following confs.') 76 | with open('tests/integration/spark_conf/spark-defaults.conf') as f: 77 | for line in f: 78 | l = line.strip() 79 | if l: 80 | k, v = l.split(None, 1) 81 | builder.config(k, v) 82 | session = builder.getOrCreate() 83 | yield session 84 | session.stop() 85 | 86 | 87 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/pytest.ini: -------------------------------------------------------------------------------- 1 | # content of pytest.ini 2 | [pytest] 3 | markers = 4 | valid_gpus: mark a test for how many gpus on each machine it requires to run 5 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/restart_spark.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | if [ "$1" != "--num_workers" ]; then 6 | echo 'Usage: restart_spark.sh --num_workers <num_workers> --num_gpus_per_worker <num_gpus_per_worker> --max_num_workers <max_num_workers>'; 7 | exit 1; 8 | fi 9 | 10 | if [ "$3" != "--num_gpus_per_worker" ]; then 11 | echo 'Usage: restart_spark.sh --num_workers <num_workers> --num_gpus_per_worker <num_gpus_per_worker> --max_num_workers <max_num_workers>'; 12 | exit 1; 13 | fi 14 | 15 | if [ "$5" != "--max_num_workers" ]; then 16 | echo 'Usage: restart_spark.sh --num_workers <num_workers> --num_gpus_per_worker <num_gpus_per_worker> --max_num_workers <max_num_workers>'; 17 | exit 1; 18 | fi 19 | 20 | python3 tests/integration/stop_spark.py --num_workers $6 || true && \ 21 | python3 tests/integration/set_spark_conf.py --num_workers $2 --num_gpus_per_worker $4 && \ 22 | python3 tests/integration/start_spark.py --num_workers $2 23 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | chmod +x tests/integration/spark_conf/gpuDiscoveryScriptStub.sh 6 | chmod +x tests/integration/start_spark.py 7 | chmod +x tests/integration/stop_spark.py 8 | chmod +x tests/integration/start_master.sh 9 | chmod +x tests/integration/start_worker.sh 10 | chmod +x tests/integration/stop_worker.sh 11 | chmod +x tests/integration/stop_master.sh 12 | chmod +x tests/integration/restart_spark.sh 13 | python3 tests/integration/run_tests.py 14 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/run_tests.py:
-------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | 4 | 5 | MAX_NUM_WORKERS = 2 6 | 7 | 8 | subprocess.run( 9 | [ 10 | 'docker-compose down && ' 11 | 'docker-compose up -d --scale worker={max_num_workers} && ' 12 | 'docker-compose exec -T master python -m pytest -s tests/integration ' 13 | '--max_num_workers {max_num_workers} && ' 14 | 'docker-compose down'.format(max_num_workers=MAX_NUM_WORKERS) 15 | ], 16 | shell=True, 17 | check=True, 18 | ) 19 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/set_spark_conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script sets the contents of spark-defaults.conf and 3 | gpuDiscoveryScriptStub.sh for the purpose of integration 4 | testing. It does so based on the arg inputs, spark-base.conf 5 | which is static, and spark-custom.conf which is dynamically 6 | changed by tests. 7 | """ 8 | 9 | 10 | import argparse 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument( 15 | '--num_workers', 16 | help='Number of workers to be set in the spark conf' 17 | ) 18 | parser.add_argument( 19 | '--num_gpus_per_worker', 20 | help='Number of gpus on each worker to be set in the spark conf' 21 | ) 22 | args = parser.parse_args() 23 | num_workers = int(args.num_workers) 24 | num_gpus_per_worker = str(args.num_gpus_per_worker) 25 | 26 | conf = {} 27 | 28 | with open('tests/integration/spark_conf/spark-base.conf', 'r') as f: 29 | lines = f.readlines() 30 | for i in range(len(lines)): 31 | l = lines[i].strip() 32 | if l: 33 | k, v = l.split(None, 1) 34 | conf[k] = v 35 | 36 | with open('tests/integration/spark_conf/spark-custom.conf', 'r') as f: 37 | lines = f.readlines() 38 | for i in range(len(lines)): 39 | l = lines[i].strip() 40 | if l: 41 | k, v = l.split(None, 1) 42 | conf[k] = v 43 | 44 | with open('tests/integration/spark_conf/spark-defaults.conf', 'w') as f: 45 | f.writelines( 46 | ['{} {}\n'.format(k, v) for k, v in conf.items()] 47 | ) 48 | 49 | with open('tests/integration/spark_conf/gpuDiscoveryScriptStub.sh', 'w+') as f: 50 | original_file_content = f.read() 51 | gpus = '","'.join(str(e) for e in range(int(num_gpus_per_worker))) 52 | cmd = "echo '{\"name\": \"gpu\", \"addresses\":[\"" + gpus + "\"]}'" 53 | f.writelines([ 54 | '#!/usr/bin/env bash\n', 55 | '\n', 56 | cmd, 57 | ]) 58 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/spark_conf/gpuDiscoveryScriptStub.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | echo '{"name": "gpu", "addresses":["0","1","2","3"]}' -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/spark_conf/spark-base.conf: -------------------------------------------------------------------------------- 1 | spark.cores.max 2 2 | spark.executor.cores 1 3 | spark.task.cpus 1 4 | spark.driver.resource.gpu.amount 4 5 | spark.driver.resource.gpu.discoveryScript /mnt/spark-tensorflow-distributor/tests/integration/spark_conf/gpuDiscoveryScriptStub.sh 6 | spark.executor.resource.gpu.amount 4 7 | spark.task.resource.gpu.amount 4 8 | spark.worker.resource.gpu.discoveryScript /mnt/spark-tensorflow-distributor/tests/integration/spark_conf/gpuDiscoveryScriptStub.sh 9 | 
spark.worker.resource.gpu.amount 4 10 | spark.master spark://master:7077 11 | spark.driver.memory 512m 12 | spark.executor.memory 512m 13 | spark.scheduler.minRegisteredResourcesRatio 1.0 14 | spark.scheduler.maxRegisteredResourcesWaitingTime 180s 15 | spark.python.worker.reuse false 16 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/start_master.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | ${SPARK_HOME}/sbin/spark-daemon.sh \ 6 | start org.apache.spark.deploy.master.Master 1 \ 7 | --properties-file /mnt/spark-tensorflow-distributor/tests/integration/spark_conf/spark-defaults.conf \ 8 | -h master 9 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/start_spark.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument( 7 | '--num_workers', 8 | help='Number of docker workers', 9 | required=True, 10 | ) 11 | args = parser.parse_args() 12 | num_workers = int(args.num_workers) 13 | 14 | # Start spark daemon from master node 15 | subprocess.run( 16 | [ 17 | '/mnt/spark-tensorflow-distributor/tests/integration/start_master.sh' 18 | ], 19 | shell=True, 20 | ) 21 | 22 | # Start spark daemon on worker nodes 23 | for worker_index in range(1, num_workers + 1): 24 | print(f'Starting worker {worker_index}') 25 | subprocess.run( 26 | [ 27 | 'docker-compose exec -T --index={} worker ' 28 | 'tests/integration/start_worker.sh'.format(worker_index) 29 | ], 30 | shell=True, 31 | ) 32 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/start_worker.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | ${SPARK_HOME}/sbin/spark-daemon.sh \ 6 | start org.apache.spark.deploy.worker.Worker 1 \ 7 | --properties-file /mnt/spark-tensorflow-distributor/tests/integration/spark_conf/spark-defaults.conf \ 8 | spark://master:7077 9 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/stop_master.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | ${SPARK_HOME}/sbin/spark-daemon.sh \ 6 | stop org.apache.spark.deploy.master.Master 1 7 | 8 | exit 0 9 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/stop_spark.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | import sys 4 | 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument( 8 | '--num_workers', 9 | help='Number of docker workers', 10 | required=True, 11 | ) 12 | args = parser.parse_args() 13 | num_workers = int(args.num_workers) 14 | 15 | # Stop spark daemon on worker nodes 16 | for worker_index in range(1, num_workers + 1): 17 | subprocess.run( 18 | [ 19 | 'docker-compose exec -T --index={} worker ' 20 | '/mnt/spark-tensorflow-distributor/tests/integration/stop_worker.sh'.format(worker_index) 21 | ], 22 | shell=True, 23 | ) 24 | 25 | # Stop spark daemon on master node 26 | 
subprocess.run( 27 | [ 28 | '/mnt/spark-tensorflow-distributor/tests/integration/stop_master.sh' 29 | ], 30 | shell=True, 31 | ) 32 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/stop_worker.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | ${SPARK_HOME}/sbin/spark-daemon.sh \ 6 | stop org.apache.spark.deploy.worker.Worker 1 7 | 8 | exit 0 9 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/integration/test_mirrored_strategy_runner.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pytest 4 | from pyspark.sql import SparkSession 5 | 6 | from spark_tensorflow_distributor import MirroredStrategyRunner 7 | from unittest import mock 8 | 9 | 10 | @pytest.mark.parametrize('num_workers', [2], indirect=True) 11 | @pytest.mark.parametrize('num_gpus_per_worker', [4], indirect=True) 12 | @pytest.mark.parametrize( 13 | 'extra_spark_configs', 14 | [{'spark.task.resource.gpu.amount': '1', 'spark.cores.max': 8, 'spark.executor.cores': 4}, 15 | {'spark.task.resource.gpu.amount': '2', 'spark.cores.max': 4, 'spark.executor.cores': 2}, 16 | {'spark.task.resource.gpu.amount': '4', 'spark.cores.max': 2, 'spark.executor.cores': 1}], 17 | indirect=True, 18 | ) 19 | def test_equal_gpu_allocation(num_workers, num_gpus_per_worker): 20 | def train_fn(): 21 | import os 22 | from pyspark import BarrierTaskContext 23 | context = BarrierTaskContext.get() 24 | cuda_state = os.environ['CUDA_VISIBLE_DEVICES'] 25 | if cuda_state: 26 | num_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) 27 | else: 28 | num_gpus = 0 29 | return [int(e) for e in context.allGather(str(num_gpus))] 30 | 31 | for num_slots in [2, 4, 6, 8]: 32 | runner = MirroredStrategyRunner(num_slots=num_slots) 33 | task_gpu_amount = int(runner.sc.getConf().get('spark.task.resource.gpu.amount')) 34 | expected_num_task = math.ceil(num_slots / task_gpu_amount) 35 | assert runner.get_num_tasks() == expected_num_task 36 | gpus_used_by_each_task = runner.run(train_fn) 37 | assert gpus_used_by_each_task == [ 38 | (num_slots // expected_num_task) + (i < (num_slots % expected_num_task)) 39 | for i in range(expected_num_task) 40 | ] 41 | 42 | @pytest.mark.parametrize('num_workers', [2], indirect=True) 43 | @pytest.mark.parametrize('num_gpus_per_worker', [4], indirect=True) 44 | def test_zero_num_slots(num_workers, num_gpus_per_worker): 45 | with pytest.raises(ValueError): 46 | result = MirroredStrategyRunner(num_slots=0).run(lambda: None) 47 | 48 | @pytest.mark.parametrize('num_workers', [2], indirect=True) 49 | @pytest.mark.parametrize('num_gpus_per_worker', [4], indirect=True) 50 | @pytest.mark.parametrize('num_slots', [1, 2, 3]) 51 | @pytest.mark.parametrize('old_cuda_state', [None, '10,11,12,13']) 52 | def test_local_run(num_workers, num_gpus_per_worker, num_slots, old_cuda_state): 53 | def train_fn(): 54 | import os 55 | return os.environ['CUDA_VISIBLE_DEVICES'] 56 | 57 | if old_cuda_state is not None: 58 | mock_env = {'CUDA_VISIBLE_DEVICES': old_cuda_state} 59 | else: 60 | mock_env = {} 61 | 62 | with mock.patch.dict(os.environ, mock_env, clear=True): 63 | task_cuda_env = MirroredStrategyRunner(num_slots=num_slots, local_mode=True, gpu_resource_name='gpu').run(train_fn) 64 | gpu_set = {int(i) for i in task_cuda_env.split(',')} 65 | assert 
len(gpu_set) == num_slots 66 | for gpu_id in gpu_set: 67 | if old_cuda_state is not None: 68 | assert gpu_id in [10, 11, 12, 13] 69 | else: 70 | assert gpu_id in [0, 1, 2, 3] 71 | new_cuda_state = os.environ.get('CUDA_VISIBLE_DEVICES') 72 | assert old_cuda_state == new_cuda_state 73 | 74 | @pytest.mark.parametrize('num_workers', [2], indirect=True) 75 | @pytest.mark.parametrize('num_gpus_per_worker', [4], indirect=True) 76 | @pytest.mark.parametrize( 77 | 'extra_spark_configs', 78 | [{'spark.ssl.enabled': 'true'}], 79 | indirect=True, 80 | ) 81 | def test_run_on_ssl_cluster(num_workers, num_gpus_per_worker, extra_spark_configs): 82 | with pytest.raises(Exception): 83 | MirroredStrategyRunner(num_slots=2, gpu_resource_name='gpu').run(lambda: None) 84 | 85 | @pytest.mark.parametrize('num_workers', [2], indirect=True) 86 | @pytest.mark.parametrize('num_gpus_per_worker', [4], indirect=True) 87 | @pytest.mark.parametrize( 88 | 'extra_spark_configs', 89 | [{'spark.ssl.enabled': 'true', 'tensorflow.spark.distributor.ignoreSsl': 'true'}], 90 | indirect=True, 91 | ) 92 | def test_run_on_ssl_cluster_override(num_workers, num_gpus_per_worker, extra_spark_configs): 93 | MirroredStrategyRunner(num_slots=2, gpu_resource_name='gpu').run(lambda: None) 94 | 95 | @pytest.mark.parametrize('num_workers', [2], indirect=True) 96 | @pytest.mark.parametrize('num_gpus_per_worker', [4], indirect=True) 97 | def test_cpu_training_with_gpus(num_workers, num_gpus_per_worker): 98 | def train_fn(): 99 | from pyspark import BarrierTaskContext 100 | context = BarrierTaskContext.get() 101 | cuda_state = os.environ['CUDA_VISIBLE_DEVICES'] 102 | if cuda_state: 103 | num_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) 104 | else: 105 | num_gpus = 0 106 | return [int(e) for e in context.allGather(str(num_gpus))] 107 | 108 | runner = MirroredStrategyRunner(num_slots=2, use_gpu=False) 109 | assert runner.get_num_tasks() == 2 110 | gpus_used_by_each_task = runner.run(train_fn) 111 | assert gpus_used_by_each_task == [0, 0] 112 | 113 | @pytest.mark.parametrize('num_workers', [1], indirect=True) 114 | @pytest.mark.parametrize('num_gpus_per_worker', [4], indirect=True) 115 | @pytest.mark.parametrize( 116 | 'extra_spark_configs', 117 | [{'spark.executorEnv.CUDA_VISIBLE_DEVICES': '10,11,12,13'}], 118 | indirect=True, 119 | ) 120 | def test_spark_task_cuda_devices_env_support(num_workers, num_gpus_per_worker): 121 | def train_fn(): 122 | import os 123 | return os.environ['CUDA_VISIBLE_DEVICES'] 124 | 125 | for num_slots in [2, 3, 4]: 126 | runner = MirroredStrategyRunner(num_slots=num_slots) 127 | task_cuda_env = runner.run(train_fn) 128 | gpu_set = {int(i) for i in task_cuda_env.split(',')} 129 | assert len(gpu_set) == num_slots 130 | for gpu_id in gpu_set: 131 | assert gpu_id in [10, 11, 12, 13] 132 | -------------------------------------------------------------------------------- /spark/spark-tensorflow-distributor/tests/lint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | yapf --recursive --diff spark_tensorflow_distributor 6 | pylint spark_tensorflow_distributor 7 | -------------------------------------------------------------------------------- /swarm/README.md: -------------------------------------------------------------------------------- 1 | # Running Distributed TensorFlow on Docker Compose v2 And Swarm 2 | 3 | ## Prerequisite 4 | 5 | 1. You must be running Docker 1.11 or above. 
See the 6 | [Docker Documentation](https://docs.docker.com/v1.11/) if you 7 | want to quickly set up a Swarm cluster and Compose from scratch. 8 | 9 | 2. You should set up shared storage such as HDFS in the cluster. If you'd like to deploy HDFS with Docker, see [Run Hadoop Cluster in Docker](http://kiwenlau.blogspot.com/2015/05/quickly-build-arbitrary-size-hadoop.html) 10 | 11 | 3. [Jinja templates](http://jinja.pocoo.org/) must be installed. 12 | 13 | Before you start, you need to set up a Docker Swarm cluster and Compose. It is also preferable to set up some shared storage such as HDFS. You need to know the HDFS namenode, which is needed to bring up the TensorFlow cluster. 14 | 15 | ## Steps to Run the Job 16 | 17 | 1. Follow the instructions for creating the training program in the parent 18 | [README](../README.md). 19 | 20 | 2. Follow the instructions for building and pushing the Docker image in the 21 | [Docker README](../docker/README.md). 22 | 23 | 3. Copy the template file: 24 | 25 | ```sh 26 | cd ecosystem 27 | cp swarm/template.yaml.jinja docker-compose.template.jinja 28 | ``` 29 | 30 | 4. Edit the `docker-compose.template.jinja` file to set the job parameters. You need to specify `name`, `image`, and `train_dir`, and optionally change the number of worker and ps replicas. The `train_dir` must point to a directory on shared storage if you would like to use TensorBoard or sharded checkpoints. 31 | 32 | 5. Generate the compose file (a minimal sketch of the rendering step appears at the end of this document): 33 | 34 | ```sh 35 | mkdir /distribute-tensorflow 36 | python render_template.py docker-compose.template.jinja | tee /distribute-tensorflow/docker-compose.yml 37 | ``` 38 | 39 | 6. Run the TensorFlow cluster: 40 | 41 | 42 | ```sh 43 | cd /distribute-tensorflow 44 | docker-compose up -d 45 | ``` 46 | 47 | -------------------------------------------------------------------------------- /swarm/template.yaml.jinja: -------------------------------------------------------------------------------- 1 | {%- set name = "mnist" -%} 2 | {%- set image = "" -%} 3 | {%- set worker_replicas = 2 -%} 4 | {%- set ps_replicas = 1 -%} 5 | {%- set script = "mnist.py" -%} 6 | {%- set data_dir = "hdfs://namenode/data_dir" -%} 7 | {%- set train_dir = "hdfs://namenode/train_dir" -%} 8 | {%- set tensorboard = true %} 9 | 10 | {%- set port = 5000 -%} 11 | {%- set replicas = {"worker": worker_replicas, "ps": ps_replicas} -%} 12 | 13 | {%- macro worker_hosts() -%} 14 | {%- for i in range(worker_replicas) -%} 15 | {%- if not loop.first -%},{%- endif -%} 16 | {{ name }}-worker-{{ i }}:{{ port }} 17 | {%- endfor -%} 18 | {%- endmacro -%} 19 | 20 | {%- macro ps_hosts() -%} 21 | {%- for i in range(ps_replicas) -%} 22 | {%- if not loop.first -%},{%- endif -%} 23 | {{ name }}-ps-{{ i }}:{{ port }} 24 | {%- endfor -%} 25 | {%- endmacro -%} 26 | 27 | version: '2' 28 | services: 29 | 30 | {%- for job in ["worker", "ps"] -%} 31 | {%- for i in range(replicas[job]) %} 32 | {{ job }}-{{ i }}: 33 | image: {{ image }} 34 | container_name: {{ name }}-{{ job }}-{{ i }} 35 | command: 36 | - python 37 | - {{ script }} 38 | - "--data_dir={{ data_dir }}" 39 | - "--train_dir={{ train_dir }}" 40 | - "--task_index={{ i }}" 41 | - "--job_name={{ job }}" 42 | - "--worker_hosts={{ worker_hosts() }}" 43 | - "--ps_hosts={{ ps_hosts() }}" 44 | {% endfor %} 45 | {%- endfor -%} 46 | {%- if tensorboard %} 47 | tensorboard: 48 | image: {{ image }} 49 | command: 50 | - tensorboard 51 | - --logdir={{ train_dir }} 52 | {% endif %} 53 | --------------------------------------------------------------------------------
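The swarm README above pipes `docker-compose.template.jinja` through `render_template.py` (a script at the repository root that is not reproduced in this part of the listing) to produce the final compose file. The snippet below is only a minimal sketch of that rendering step, assuming the `jinja2` package is installed; the repository's own `render_template.py` may differ in its details. Because all parameters are declared in the `{%- set ... -%}` block at the top of the template, `render()` needs no extra context.

```python
# Minimal sketch of the Jinja rendering step used in step 5 of the swarm
# README (illustrative only; the real render_template.py may differ).
# It reads a .jinja template and prints the rendered result to stdout so
# the output can be piped or tee'd into docker-compose.yml.
import sys

from jinja2 import Template


def render(template_path):
    with open(template_path) as f:
        # The template defines its own parameters via {%- set ... -%},
        # so no additional variables are passed to render().
        print(Template(f.read()).render())


if __name__ == '__main__':
    render(sys.argv[1])
```

Saved under a hypothetical name such as `render_template_sketch.py`, it would be invoked exactly as in step 5: `python render_template_sketch.py docker-compose.template.jinja | tee /distribute-tensorflow/docker-compose.yml`.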