├── examples ├── resnet │ ├── __init__.py │ ├── resnet_cifar_spark.py │ ├── README_orig.md │ └── README.md ├── utils │ ├── mnist_reshape.py │ └── stop_streaming.py ├── mnist │ ├── mnist_data_setup.py │ ├── keras │ │ ├── mnist_inference.py │ │ ├── mnist_tf.py │ │ ├── mnist_spark.py │ │ ├── mnist_tf_ds.py │ │ └── mnist_pipeline.py │ └── estimator │ │ ├── mnist_inference.py │ │ ├── mnist_tf.py │ │ ├── mnist_spark_streaming.py │ │ ├── mnist_spark.py │ │ └── mnist_pipeline.py └── segmentation │ ├── segmentation.py │ ├── README.md │ ├── segmentation_dist.py │ └── segmentation_spark.py ├── sd.allow ├── .tidelift.yml ├── requirements.txt ├── scripts ├── stop_spark.sh ├── ec2-cloud-config.txt ├── install_spark.sh ├── start_spark.sh ├── spark-ec2 └── deploy.generic │ └── root │ └── spark-ec2 │ └── ec2-variables.sh ├── lib └── tensorflow-hadoop-1.0-SNAPSHOT.jar ├── tensorflowonspark ├── __init__.py ├── marker.py ├── compat.py ├── TFManager.py ├── TFParallel.py ├── util.py ├── gpu_info.py └── dfutil.py ├── doc └── source │ ├── tensorflowonspark.util.rst │ ├── tensorflowonspark.TFNode.rst │ ├── tensorflowonspark.dfutil.rst │ ├── tensorflowonspark.marker.rst │ ├── tensorflowonspark.gpu_info.rst │ ├── tensorflowonspark.pipeline.rst │ ├── tensorflowonspark.TFCluster.rst │ ├── tensorflowonspark.TFManager.rst │ ├── tensorflowonspark.TFParallel.rst │ ├── tensorflowonspark.TFSparkNode.rst │ ├── tensorflowonspark.reservation.rst │ ├── tensorflowonspark.reservation_client.rst │ ├── index.rst │ ├── tensorflowonspark.rst │ └── conf.py ├── .gitignore ├── src ├── test │ └── scala │ │ └── com │ │ └── yahoo │ │ └── tensorflowonspark │ │ ├── SimpleTypeParserTest.scala │ │ ├── TestData.scala │ │ ├── DFUtilTest.scala │ │ └── TFModelTest.scala └── main │ └── scala │ └── com │ └── yahoo │ └── tensorflowonspark │ ├── TFParams.scala │ ├── SimpleTypeParser.scala │ └── Inference.scala ├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── setup.py ├── tests ├── run_tests.sh ├── test.py ├── README.md ├── test_TFParallel.py ├── test_TFNode.py ├── test_dfutil.py ├── test_reservation.py ├── test_TFCluster.py └── test_pipeline.py ├── Contributing.md ├── screwdriver.yaml ├── setup.cfg ├── tox.ini ├── README.md ├── Code-of-Conduct.md └── pom.xml /examples/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sd.allow: -------------------------------------------------------------------------------- 1 | version: 1 2 | push: 3 | - screwdriver:6384 4 | -------------------------------------------------------------------------------- /.tidelift.yml: -------------------------------------------------------------------------------- 1 | tests: 2 | removed: warn 3 | unlicensed: warn 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | numpy 3 | packaging 4 | py4j 5 | pyspark 6 | scipy 7 | setuptools 8 | tensorflow 9 | -------------------------------------------------------------------------------- /scripts/stop_spark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | 3 | ${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh 4 | -------------------------------------------------------------------------------- /lib/tensorflow-hadoop-1.0-SNAPSHOT.jar: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yahoo/TensorFlowOnSpark/HEAD/lib/tensorflow-hadoop-1.0-SNAPSHOT.jar -------------------------------------------------------------------------------- /scripts/ec2-cloud-config.txt: -------------------------------------------------------------------------------- 1 | disable_root: false 2 | runcmd: 3 | - source /root/.bash_profile 4 | - sudo ln /dev/null /dev/raw1394 5 | 6 | -------------------------------------------------------------------------------- /tensorflowonspark/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s (%(threadName)s-%(process)d) %(message)s") 4 | 5 | __version__ = "2.2.5" 6 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.util.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.util module 2 | ============================== 3 | 4 | .. automodule:: tensorflowonspark.util 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.TFNode.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.TFNode module 2 | ================================ 3 | 4 | .. automodule:: tensorflowonspark.TFNode 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.dfutil.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.dfutil module 2 | ================================ 3 | 4 | .. automodule:: tensorflowonspark.dfutil 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.marker.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.marker module 2 | ================================ 3 | 4 | .. automodule:: tensorflowonspark.marker 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.gpu_info.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.gpu\_info module 2 | =================================== 3 | 4 | .. automodule:: tensorflowonspark.gpu_info 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.pipeline.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.pipeline module 2 | ================================== 3 | 4 | .. automodule:: tensorflowonspark.pipeline 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.TFCluster.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.TFCluster module 2 | =================================== 3 | 4 | .. 
automodule:: tensorflowonspark.TFCluster 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.TFManager.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.TFManager module 2 | =================================== 3 | 4 | .. automodule:: tensorflowonspark.TFManager 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.TFParallel.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.TFParallel module 2 | =================================== 3 | 4 | .. automodule:: tensorflowonspark.TFParallel 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | docs/.doctrees 4 | *.py[cod] 5 | *$py.class 6 | *.doctree 7 | *.log 8 | *.jar 9 | .DS_Store 10 | target 11 | test-data 12 | dependency-reduced-pom.xml 13 | venv 14 | .idea -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.TFSparkNode.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.TFSparkNode module 2 | ===================================== 3 | 4 | .. automodule:: tensorflowonspark.TFSparkNode 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.reservation.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.reservation module 2 | ===================================== 3 | 4 | .. automodule:: tensorflowonspark.reservation 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.reservation_client.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark\.reservation\_client module 2 | ============================================= 3 | 4 | .. automodule:: tensorflowonspark.reservation_client 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /examples/utils/mnist_reshape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 
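# Usage sketch (illustrative; the input path assumes the CSV layout written by examples/mnist/mnist_data_setup.py with its default --output=data/mnist): pipe one CSV record ("label,pixel0,...,pixel783") into stdin, e.g. `head -n 1 data/mnist/csv/train/part-00000 | python examples/utils/mnist_reshape.py`, and the script prints the 784 pixel values reshaped to a 28x28x1 array.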
4 | 5 | import sys 6 | import numpy as np 7 | vec = [int(x) for x in next(sys.stdin).split(',')] 8 | img = np.reshape(vec[1:], (28, 28, 1)) 9 | print(np.array2string(img).replace('\n ', ',')) 10 | -------------------------------------------------------------------------------- /scripts/install_spark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | 3 | # Install JDK8 4 | yum install -y java-1.8.0-openjdk 5 | export JAVA_HOME=/usr/lib/jvm/jre-1.8.0 6 | 7 | # Install Spark 8 | export SPARK_VERSION=3.1.2 9 | export HADOOP_VERSION=2.7 10 | curl -LO https://downloads.apache.org/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz 11 | 12 | export SPARK_HOME=/opt/spark 13 | mkdir $SPARK_HOME 14 | tar -xf spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz -C $SPARK_HOME --strip-components=1 15 | -------------------------------------------------------------------------------- /scripts/start_spark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | #export SPARK_HOME=/opt/spark 3 | #export SPARK_LOCAL_IP=127.0.0.1 4 | #export PATH=$SPARK_HOME/bin:$PATH 5 | # 6 | ## Start Spark Standalone Cluster 7 | #export SPARK_CLASSPATH=./lib/tensorflow-hadoop-1.0-SNAPSHOT.jar 8 | #export MASTER=spark://$(hostname):7077 9 | #export SPARK_WORKER_INSTANCES=2; export CORES_PER_WORKER=1 10 | #export TOTAL_CORES=$((${CORES_PER_WORKER}*${SPARK_WORKER_INSTANCES})) 11 | 12 | ${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 1G ${MASTER} 13 | -------------------------------------------------------------------------------- /src/test/scala/com/yahoo/tensorflowonspark/SimpleTypeParserTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2018 Yahoo Inc. 3 | * Licensed under the terms of the Apache 2.0 license. 4 | * Please see LICENSE file in the project root for terms. 5 | */ 6 | package com.yahoo.tensorflowonspark 7 | 8 | import org.scalatest.FunSuite 9 | 10 | class SimpleTypeParserTest extends FunSuite with TestData { 11 | test("parse simple type string as schema") { 12 | val s = schema.simpleString 13 | val parsed = SimpleTypeParser.parse(s) 14 | assert(parsed == schema) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /tensorflowonspark/marker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import nested_scopes 8 | from __future__ import print_function 9 | 10 | 11 | class Marker(object): 12 | """Base class for special marker objects in the data queue""" 13 | pass 14 | 15 | 16 | class EndPartition(Marker): 17 | """Marks the end of an RDD Partition during data feeding""" 18 | pass 19 | -------------------------------------------------------------------------------- /doc/source/index.rst: -------------------------------------------------------------------------------- 1 | .. TensorFlowOnSpark documentation master file, created by 2 | sphinx-quickstart on Fri Sep 7 15:30:10 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | Welcome to TensorFlowOnSpark's documentation! 7 | ============================================= 8 | 9 | .. toctree:: 10 | :maxdepth: 3 11 | :caption: Contents: 12 | 13 | tensorflowonspark 14 | 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | -------------------------------------------------------------------------------- /examples/utils/stop_streaming.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 4 | """ 5 | Simple utility to shutdown a Spark StreamingContext by signaling the reservation Server. 6 | Note: use the reservation server address (host, port) reported in the driver logs. 7 | """ 8 | 9 | from tensorflowonspark import reservation 10 | import sys 11 | 12 | if __name__ == "__main__": 13 | host = sys.argv[1] 14 | port = int(sys.argv[2]) 15 | addr = (host, port) 16 | client = reservation.Client(addr) 17 | client.request_stop() 18 | client.close() 19 | -------------------------------------------------------------------------------- /doc/source/tensorflowonspark.rst: -------------------------------------------------------------------------------- 1 | tensorflowonspark package 2 | ========================= 3 | 4 | .. automodule:: tensorflowonspark 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | 14 | tensorflowonspark.TFCluster 15 | tensorflowonspark.TFManager 16 | tensorflowonspark.TFNode 17 | tensorflowonspark.TFParallel 18 | tensorflowonspark.TFSparkNode 19 | tensorflowonspark.dfutil 20 | tensorflowonspark.gpu_info 21 | tensorflowonspark.marker 22 | tensorflowonspark.pipeline 23 | tensorflowonspark.reservation 24 | tensorflowonspark.reservation_client 25 | tensorflowonspark.util 26 | 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | 5 | --- 6 | 7 | **Environment:** 8 | - Python version [e.g. 2.7, 3.6] 9 | - Spark version [e.g. 2.1, 2.3.1] 10 | - TensorFlow version [e.g. 1.5, 1.9.0] 11 | - TensorFlowOnSpark version [e.g. 1.1, 1.3.2] 12 | - Cluster version [e.g. Standalone, Hadoop 2.8, CDH5] 13 | 14 | **Describe the bug:** 15 | A clear and concise description of what the bug is. 16 | 17 | **Logs:** 18 | If applicable, add logs to help explain your problem. Note: errors may not be fully described in the driver/console logs. Make sure to check the executor logs for possible root causes. 19 | 20 | **Spark Submit Command Line:** 21 | If applicable, add your spark-submit command line. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2017, Yahoo Inc. 3 | # Licensed under the terms of the apache license. 
See the LICENSE file in the project root for terms 4 | """ 5 | Package setup file for python module 'tensorflowonspark' 6 | """ 7 | import setuptools 8 | import sys 9 | 10 | 11 | def setuptools_version_supported(): 12 | major, minor, patch = setuptools.__version__.split('.') 13 | if int(major) > 38: 14 | return True 15 | return False 16 | 17 | 18 | if __name__ == '__main__': 19 | # Check for a working version of setuptools here because earlier versions did not 20 | # support python_requires. 21 | if not setuptools_version_supported(): 22 | print('Setuptools version 38.0.0 or higher is needed to install this package') 23 | sys.exit(1) 24 | 25 | # We're being run from the command line so call setup with our arguments 26 | setuptools.setup() 27 | -------------------------------------------------------------------------------- /tests/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) 4 | 5 | if [ -z "$SPARK_HOME" ]; then 6 | echo "Please set SPARK_HOME environment variable" 7 | exit 1 8 | fi 9 | 10 | if [ -z "$SPARK_CLASSPATH" ]; then 11 | echo "Please add the path to tensorflow-hadoop-*.jar to the SPARK_CLASSPATH environment variable" 12 | exit 1 13 | fi 14 | 15 | # Start Spark Standalone Cluster 16 | export MASTER=spark://$(hostname):7077 17 | export SPARK_WORKER_INSTANCES=2; export CORES_PER_WORKER=1 18 | export TOTAL_CORES=$((${CORES_PER_WORKER}*${SPARK_WORKER_INSTANCES})) 19 | ${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 1G ${MASTER} 20 | 21 | # Run tests 22 | python -m unittest discover -s $DIR 23 | EXIT_CODE=$? 24 | 25 | # Stop Spark Standalone Cluster 26 | ${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh 27 | 28 | exit $EXIT_CODE 29 | -------------------------------------------------------------------------------- /examples/resnet/resnet_cifar_spark.py: -------------------------------------------------------------------------------- 1 | import resnet_cifar_dist 2 | 3 | if __name__ == '__main__': 4 | # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 5 | # absl_app.run(main) 6 | from pyspark.context import SparkContext 7 | from pyspark.conf import SparkConf 8 | from tensorflowonspark import TFCluster 9 | import argparse 10 | 11 | sc = SparkContext(conf=SparkConf().setAppName("resnet_cifar")) 12 | executors = sc._conf.get("spark.executor.instances") 13 | num_executors = int(executors) if executors is not None else 1 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors) 17 | parser.add_argument("--num_ps", help="number of parameter servers", type=int, default=0) 18 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 19 | args, rem = parser.parse_known_args() 20 | 21 | cluster = TFCluster.run(sc, resnet_cifar_dist.main_fun, rem, args.cluster_size, args.num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW, master_node='chief') 22 | cluster.shutdown() 23 | -------------------------------------------------------------------------------- /scripts/spark-ec2: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one 5 | # or more contributor license agreements. 
See the NOTICE file 6 | # distributed with this work for additional information 7 | # regarding copyright ownership. The ASF licenses this file 8 | # to you under the Apache License, Version 2.0 (the 9 | # "License"); you may not use this file except in compliance 10 | # with the License. You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | # 21 | # This script is cloned from https://github.com/amplab/spark-ec2/blob/branch-1.6/spark-ec2. 22 | # 23 | 24 | # Preserve the user's CWD so that relative paths are passed correctly to 25 | #+ the underlying Python script. 26 | SPARK_EC2_DIR="$(dirname "$0")" 27 | 28 | python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" 29 | -------------------------------------------------------------------------------- /tensorflowonspark/compat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Yahoo Inc 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 4 | """Helper functions to abstract API changes between TensorFlow versions, intended for end-user TF code.""" 5 | 6 | import tensorflow as tf 7 | from packaging import version 8 | 9 | 10 | def export_saved_model(model, export_dir, is_chief=False): 11 | if version.parse(tf.__version__) < version.parse('2.1.0'): 12 | if is_chief: 13 | tf.keras.experimental.export_saved_model(model, export_dir) 14 | else: 15 | # non-chief nodes save to dummy location on local disk 16 | export_dir = export_dir if is_chief else 'worker_model' 17 | model.save(export_dir, save_format='tf') 18 | 19 | 20 | def disable_auto_shard(options): 21 | if version.parse(tf.__version__) < version.parse('2.1.0'): 22 | options.experimental_distribute.auto_shard = False 23 | else: 24 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF 25 | 26 | 27 | def is_gpu_available(): 28 | if version.parse(tf.__version__) < version.parse('2.1.0'): 29 | return tf.test.is_built_with_cuda() 30 | else: 31 | return len(tf.config.list_physical_devices('GPU')) > 0 32 | -------------------------------------------------------------------------------- /scripts/deploy.generic/root/spark-ec2/ec2-variables.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | # These variables are automatically filled in by the spark-ec2 script. 21 | export MASTERS="{{master_list}}" 22 | export SLAVES="{{slave_list}}" 23 | export HDFS_DATA_DIRS="{{hdfs_data_dirs}}" 24 | export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" 25 | export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" 26 | export MODULES="{{modules}}" 27 | export SPARK_VERSION="{{spark_version}}" 28 | export TACHYON_VERSION="{{tachyon_version}}" 29 | export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" 30 | export SWAP_MB="{{swap}}" 31 | export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" 32 | export SPARK_MASTER_OPTS="{{spark_master_opts}}" 33 | export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" 34 | export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" 35 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | from pyspark import SparkConf, SparkContext 5 | from pyspark.sql import SparkSession 6 | 7 | 8 | class SparkTest(unittest.TestCase): 9 | """Base class for unittests using Spark. Sets up and tears down a cluster per test class""" 10 | 11 | @classmethod 12 | def setUpClass(cls): 13 | master = os.getenv('MASTER') 14 | assert master is not None, "Please start a Spark standalone cluster and export MASTER to your env." 15 | 16 | num_workers = os.getenv('SPARK_WORKER_INSTANCES') 17 | assert num_workers is not None, "Please export SPARK_WORKER_INSTANCES to your env." 18 | cls.num_workers = int(num_workers) 19 | 20 | spark_jars = os.getenv('SPARK_CLASSPATH') 21 | assert spark_jars, "Please add path to tensorflow/ecosystem/hadoop jar to SPARK_CLASSPATH." 22 | 23 | cls.conf = SparkConf().set('spark.jars', spark_jars).set('spark.scheduler.barrier.maxConcurrentTasksCheck.maxFailures', 3) 24 | 25 | cls.sc = SparkContext(master, cls.__name__, conf=cls.conf) 26 | cls.spark = SparkSession.builder.getOrCreate() 27 | 28 | @classmethod 29 | def tearDownClass(cls): 30 | cls.spark.stop() 31 | cls.sc.stop() 32 | 33 | def setUp(self): 34 | print("\n===========================================================") 35 | print(self.id()) 36 | print("===========================================================\n") 37 | 38 | 39 | class SimpleTest(SparkTest): 40 | """Check that basic Spark is working""" 41 | def test_spark(self): 42 | sum = self.sc.parallelize(range(1000)).sum() 43 | self.assertEqual(sum, 499500) 44 | 45 | 46 | if __name__ == '__main__': 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Unit/Integration Tests 2 | 3 | ## Requirements 4 | 5 | Since TensorFlowOnSpark (TFoS) is literally an integration of TensorFlow and Spark, these tests assume your environment has: 6 | - Spark installed at `${SPARK_HOME}` 7 | - Python installed with tensorflow 8 | - TFoS installed via `pip install -e .` (for easier coding/iteration) 9 | 10 | Note: the tests that use Spark will require a local Spark Standalone cluster (vs. Spark Local mode), since TFoS assumes that the executors run in separate processes. 
This is true for distributed clusters (Standalone and YARN), but not true for the non-distributed Spark Local mode, since the executors are just launched as threads in a single process. 11 | 12 | ## Instructions 13 | 14 | 1. Setup ENV variables to point to your Spark and TensorFlowOnSpark. 15 | ```bash 16 | export SPARK_HOME= 17 | export TFoS_HOME= 18 | export PYTHONPATH=${SPARK_HOME}/python 19 | export SPARK_CLASSPATH=${TFoS_HOME}/lib/tensorflow-hadoop-1.0-SNAPSHOT.jar 20 | ``` 21 | 2. Run script to automatically start Spark Standalone cluster, run all tests, and shutdown the cluster, OR 22 | ```bash 23 | cd ${TFoS_HOME}/tests 24 | ./run_tests.sh 25 | ``` 26 | 3. OPTIONAL: manually start/stop the Spark Standalone cluster (when iterating on code). 27 | ``` 28 | # Start Spark Standalone cluster 29 | export MASTER=spark://$(hostname):7077 30 | export SPARK_WORKER_INSTANCES=2; export CORES_PER_WORKER=1 31 | export TOTAL_CORES=$((${CORES_PER_WORKER}*${SPARK_WORKER_INSTANCES})) 32 | ${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 3G ${MASTER} 33 | 34 | # Develop code, run tests, repeat... 35 | cd ${TFoS_HOME}/tests 36 | python -m unittest discover 37 | 38 | # Stop Spark Standalone cluster when done 39 | ${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh 40 | ``` 41 | -------------------------------------------------------------------------------- /src/test/scala/com/yahoo/tensorflowonspark/TestData.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2018 Yahoo Inc. 3 | * Licensed under the terms of the Apache 2.0 license. 4 | * Please see LICENSE file in the project root for terms. 5 | */ 6 | package com.yahoo.tensorflowonspark 7 | 8 | import org.apache.spark.sql.Row 9 | import org.apache.spark.sql.types._ 10 | 11 | trait TestData { 12 | val row1 = Row("one".getBytes, true, 1, 1L, 1.0f, 1.0, "one", 13 | Seq[Array[Byte]]("one".getBytes, "two".getBytes, "three".getBytes), 14 | Seq[Boolean](true, true, true), 15 | Seq[Int](1, 2, 3), 16 | Seq[Long](1L, 2L, 3L), 17 | Seq[Float](1.0f, 1.1f, 1.2f), 18 | Seq[Double](1.0, 1.1, 1.2), 19 | Seq[String]("one", "two", "three")) 20 | val row2 = Row("foo".getBytes, false, 2, 2L, 2.0f, 2.0, "foo", 21 | Seq[Array[Byte]]("foo".getBytes, "bar".getBytes, "baz".getBytes), 22 | Seq[Boolean](false, false, false), 23 | Seq[Int](4, 5, 6), 24 | Seq[Long](4L, 5L, 6L), 25 | Seq[Float](2.0f, 2.1f, 2.2f), 26 | Seq[Double](2.0, 2.1, 2.2), 27 | Seq[String]("foo", "bar", "baz")) 28 | 29 | val listRows = List(row1, row2) 30 | 31 | val schema = StructType(Array( 32 | StructField("binary", BinaryType), 33 | StructField("bool", BooleanType), 34 | StructField("int", IntegerType), 35 | StructField("long", LongType), 36 | StructField("float", FloatType), 37 | StructField("double", DoubleType), 38 | StructField("string", StringType), 39 | StructField("arrayBinary", ArrayType(BinaryType)), 40 | StructField("arrayBool", ArrayType(BooleanType)), 41 | StructField("arrayInt", ArrayType(IntegerType)), 42 | StructField("arrayLong", ArrayType(LongType)), 43 | StructField("arrayFloat", ArrayType(FloatType)), 44 | StructField("arrayDouble", ArrayType(DoubleType)), 45 | StructField("arrayString", ArrayType(StringType)) 46 | )) 47 | } 48 | -------------------------------------------------------------------------------- /tests/test_TFParallel.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import test 3 | from 
tensorflowonspark import TFParallel 4 | 5 | 6 | class TFParallelTest(test.SparkTest): 7 | 8 | @classmethod 9 | def setUpClass(cls): 10 | super(TFParallelTest, cls).setUpClass() 11 | 12 | @classmethod 13 | def tearDownClass(cls): 14 | super(TFParallelTest, cls).tearDownClass() 15 | 16 | def test_basic_tf(self): 17 | """Single-node TF graph (w/ args) running independently on multiple executors.""" 18 | def _map_fun(args, ctx): 19 | import tensorflow as tf 20 | x = tf.constant(args['x']) 21 | y = tf.constant(args['y']) 22 | sum = tf.math.add(x, y) 23 | assert sum.numpy() == 3 24 | 25 | args = {'x': 1, 'y': 2} 26 | TFParallel.run(self.sc, _map_fun, tf_args=args, num_executors=self.num_workers, use_barrier=False) 27 | 28 | def test_basic_tf_barrier(self): 29 | """Single-node TF graph (w/ args) running independently on multiple executors using Spark barrier.""" 30 | def _map_fun(args, ctx): 31 | import tensorflow as tf 32 | x = tf.constant(args['x']) 33 | y = tf.constant(args['y']) 34 | sum = tf.math.add(x, y) 35 | assert sum.numpy() == 3 36 | 37 | args = {'x': 1, 'y': 2} 38 | TFParallel.run(self.sc, _map_fun, tf_args=args, num_executors=self.num_workers) 39 | 40 | def test_basic_tf_barrier_insufficient_resources(self): 41 | """Single-node TF graph (w/ args) running independently on multiple executors using Spark barrier with insufficient resource.""" 42 | def _map_fun(args, ctx): 43 | import tensorflow as tf 44 | x = tf.constant(args['x']) 45 | y = tf.constant(args['y']) 46 | sum = tf.math.add(x, y) 47 | assert sum.numpy() == 3 48 | 49 | args = {'x': 1, 'y': 2} 50 | with self.assertRaises(Exception): 51 | TFParallel.run(self.sc, _map_fun, tf_args=args, num_executors=self.num_workers + 1) 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /Contributing.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | First, thanks for taking the time to contribute to our project! There are many ways you can help out. 3 | 4 | ### Questions 5 | 6 | If you have a question that needs an answer, [create an issue](https://help.github.com/articles/creating-an-issue/), and label it as a question. 7 | 8 | ### Issues for bugs or feature requests 9 | 10 | If you encounter any bugs in the code, or want to request a new feature or enhancement, please [create an issue](https://help.github.com/articles/creating-an-issue/) to report it. Kindly add a label to indicate what type of issue it is. 11 | 12 | ### Contribute Code 13 | We welcome your pull requests for bug fixes. To implement something new, please create an issue first so we can discuss it together. 14 | 15 | ***Creating a Pull Request*** 16 | Please follow [best practices](https://github.com/trein/dev-best-practices/wiki/Git-Commit-Best-Practices) for creating git commits. 17 | 18 | When your code is ready to be submitted, [submit a pull request](https://help.github.com/articles/creating-a-pull-request/) to begin the code review process. 19 | 20 | We only seek to accept code that you are authorized to contribute to the project. We have added a pull request template on our projects so that your contributions are made with the following confirmation: 21 | 22 | > I confirm that this contribution is made under the terms of the license found in the root directory of this repository's source tree and that I have the authority necessary to make this contribution on behalf of its copyright owner. 
23 | 24 | ## Code of Conduct 25 | 26 | We encourage inclusive and professional interactions on our project. We welcome everyone to open an issue, improve the documentation, report a bug, or submit a pull request. By participating in this project, you agree to abide by the [Yahoo Code of Conduct](Code-of-Conduct.md). If you feel there is a conduct issue related to this project, please raise it per the Code of Conduct process and we will address it. 27 | -------------------------------------------------------------------------------- /src/main/scala/com/yahoo/tensorflowonspark/TFParams.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2018 Yahoo Inc. 3 | * Licensed under the terms of the Apache 2.0 license. 4 | * Please see LICENSE file in the project root for terms. 5 | */ 6 | package com.yahoo.tensorflowonspark 7 | 8 | import org.apache.spark.ml.param.{Param, Params} 9 | 10 | trait TFParams extends Params { 11 | 12 | final val batchSize: Param[Int] = new Param[Int](this, "batchSize", 13 | "Batch size for consuming input data. Default: 128") 14 | final def getBatchSize: Int = $(batchSize) 15 | final def setBatchSize(i: Int): this.type = set(batchSize, i) 16 | setDefault(batchSize, 128) 17 | 18 | final val model: Param[String] = new Param[String](this, "model", 19 | "Path to TensorFlow saved_model file") 20 | final def getModel: String = $(model) 21 | final def setModel(s: String): this.type = set(model, s) 22 | setDefault(model, "") 23 | 24 | final val tag: Param[String] = new Param[String](this, "tag", 25 | "String tag for graph within model. Default: \"serve\"") 26 | final def getTag: String = $(tag) 27 | final def setTag(s: String): this.type = set(tag, s) 28 | setDefault(tag, "serve") 29 | 30 | final val inputMapping: Param[Map[String, String]] = new Param[Map[String, String]](this, "inputMapping", 31 | "mapping of input DataFrame column name to TensorFlow input tensor name") 32 | final def getInputMapping: Map[String, String] = $(inputMapping) 33 | final def setInputMapping(m: Map[String, String]): this.type = set(inputMapping, m) 34 | setDefault(inputMapping, Map.empty[String, String]) 35 | 36 | final val outputMapping: Param[Map[String, String]] = new Param[Map[String, String]](this, "outputMapping", 37 | "mapping of TensorFlow output tensor name to output DataFrame column name") 38 | final def getOutputMapping: Map[String, String] = $(outputMapping) 39 | final def setOutputMapping(m: Map[String, String]): this.type = set(outputMapping, m) 40 | setDefault(outputMapping, Map.empty[String, String]) 41 | 42 | 43 | } 44 | -------------------------------------------------------------------------------- /src/main/scala/com/yahoo/tensorflowonspark/SimpleTypeParser.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2018 Yahoo Inc. 3 | * Licensed under the terms of the Apache 2.0 license. 4 | * Please see LICENSE file in the project root for terms. 5 | */ 6 | package com.yahoo.tensorflowonspark 7 | 8 | import org.apache.spark.sql.types._ 9 | 10 | import scala.util.parsing.combinator.RegexParsers 11 | 12 | /** 13 | * Parser which generates a StructType from a string of StructType.simpleString format.
14 | * 15 | * Currently, this supports the following base types: 16 | * - binary 17 | * - boolean 18 | * - int 19 | * - long (not a simpleString keyword, but provided here for ease of use) 20 | * - bigint 21 | * - float 22 | * - double 23 | * - string 24 | * 25 | * Additionally, this supports single-dimensional arrays of the base types. 26 | */ 27 | object SimpleTypeParser { 28 | def parse(simpleString: String): StructType = { 29 | val parser = new SimpleTypeParser 30 | parser.parseAll(parser.struct, simpleString).get 31 | } 32 | } 33 | 34 | class SimpleTypeParser extends RegexParsers { 35 | val name = "[a-zA-Z][/a-zA-Z_-]*".r 36 | 37 | def baseType: Parser[DataType] = ("binary" | "boolean" | "int" | "long" | "bigint" | "float" | "double" | "string") ^^ { 38 | case "binary" => BinaryType 39 | case "boolean" => BooleanType 40 | case "int" => IntegerType 41 | case "long" => LongType 42 | case "bigint" => LongType 43 | case "float" => FloatType 44 | case "double" => DoubleType 45 | case "string" => StringType 46 | } 47 | 48 | def arrayType: Parser[DataType] = ("array<" ~ baseType ~ ">") ^^ { 49 | case "array<" ~ bt ~ ">" => ArrayType(bt) 50 | } 51 | 52 | def dataType: Parser[DataType] = baseType | arrayType 53 | 54 | def field: Parser[StructField] = (name ~ ":" ~ dataType) ^^ { 55 | case n ~ ":" ~ t => StructField(n, t) 56 | } 57 | def fieldList: Parser[Seq[StructField]] = (field ~ opt("," ~ fieldList)) ^^ { 58 | case f ~ None => Seq(f) 59 | case f ~ Some("," ~ fl) => f +: fl 60 | } 61 | def struct: Parser[StructType] = ("struct<" ~ fieldList ~ ">") ^^ { 62 | case "struct<" ~ fl ~ ">" => StructType(fl) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /screwdriver.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2017, Yahoo Inc. 2 | # Licensed under the terms of the apache license. 
See the LICENSE file in the project root for terms 3 | 4 | version: 4 5 | shared: 6 | environment: 7 | PACKAGE_DIRECTORY: tensorflowonspark 8 | SPARK_HOME: ${SD_ROOT_DIR}/spark 9 | TOX_ARGS: '--verbose' 10 | TOX_ENVLIST: py37 11 | annotations: 12 | screwdriver.cd/cpu: HIGH 13 | screwdriver.cd/ram: HIGH 14 | 15 | jobs: 16 | validate_test: 17 | template: python/validate_unittest 18 | requires: [~commit] 19 | steps: 20 | - prevalidate_code: | 21 | source scripts/install_spark.sh 22 | 23 | validate_lint: 24 | template: python/validate_lint 25 | requires: [~commit] 26 | 27 | validate_codestyle: 28 | template: python/validate_codestyle 29 | requires: [~commit] 30 | 31 | validate_dependencies: 32 | template: python/validate_dependencies 33 | requires: [~commit] 34 | 35 | # validate_security: 36 | # template: python/validate_security 37 | # requires: [~commit] 38 | 39 | validate_documentation: 40 | template: python/documentation 41 | environment: 42 | DOCUMENTATION_PUBLISH: False 43 | requires: [~commit] 44 | steps: 45 | - update_version: | 46 | echo 'using version from setup.cfg' 47 | - publish_documentation: | 48 | $BASE_PYTHON -m pip install sphinx_rtd_theme tensorflow 49 | $BASE_PYTHON -m screwdrivercd.documentation 50 | 51 | publish_test_pypi: 52 | template: python/package_python 53 | environment: 54 | PUBLISH: True 55 | TWINE_REPOSITORY_URL: https://test.pypi.org/legacy/ 56 | requires: [~tag:/^v\.*/] 57 | steps: 58 | - update_version: | 59 | echo 'using version from setup.cfg' 60 | 61 | publish_pypi: 62 | template: python/package_python 63 | environment: 64 | PUBLISH: True 65 | requires: [publish_test_pypi] 66 | steps: 67 | - update_version: | 68 | echo 'using version from setup.cfg' 69 | 70 | publish_documentation: 71 | template: python/documentation 72 | requires: [publish_pypi] 73 | steps: 74 | - update_version: | 75 | echo 'using version from setup.cfg' 76 | - publish_documentation: | 77 | $BASE_PYTHON -m pip install sphinx_rtd_theme tensorflow 78 | $BASE_PYTHON -m screwdrivercd.documentation 79 | -------------------------------------------------------------------------------- /tests/test_TFNode.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | import os 3 | import unittest 4 | from tensorflowonspark import TFManager, TFNode 5 | 6 | 7 | class TFNodeTest(unittest.TestCase): 8 | def test_hdfs_path(self): 9 | """Normalization of absolution & relative string paths depending on filesystem""" 10 | cwd = os.getcwd() 11 | user = getpass.getuser() 12 | fs = ["file://", "hdfs://", "viewfs://"] 13 | paths = { 14 | "hdfs://foo/bar": ["hdfs://foo/bar", "hdfs://foo/bar", "hdfs://foo/bar"], 15 | "viewfs://foo/bar": ["viewfs://foo/bar", "viewfs://foo/bar", "viewfs://foo/bar"], 16 | "file://foo/bar": ["file://foo/bar", "file://foo/bar", "file://foo/bar"], 17 | "/foo/bar": ["file:///foo/bar", "hdfs:///foo/bar", "viewfs:///foo/bar"], 18 | "foo/bar": ["file://{}/foo/bar".format(cwd), "hdfs:///user/{}/foo/bar".format(user), "viewfs:///user/{}/foo/bar".format(user)], 19 | } 20 | 21 | for i in range(len(fs)): 22 | ctx = type('MockContext', (), {'defaultFS': fs[i], 'working_dir': cwd}) 23 | for path, expected in paths.items(): 24 | final_path = TFNode.hdfs_path(ctx, path) 25 | self.assertEqual(final_path, expected[i], "fs({}) + path({}) => {}, expected {}".format(fs[i], path, final_path, expected[i])) 26 | 27 | def test_datafeed(self): 28 | """TFNode.DataFeed basic operations""" 29 | mgr = TFManager.start('abc'.encode('utf-8'), ['input', 'output'], 
'local') 30 | 31 | # insert 10 numbers followed by an end-of-feed marker 32 | q = mgr.get_queue('input') 33 | for i in range(10): 34 | q.put(i) 35 | q.put(None) # end-of-feed marker 36 | 37 | feed = TFNode.DataFeed(mgr) 38 | 39 | # [0,1] 40 | self.assertFalse(feed.done_feeding) 41 | batch = feed.next_batch(2) 42 | self.assertEqual(len(batch), 2) 43 | self.assertEqual(sum(batch), 1) 44 | 45 | # [2,3,4,5] 46 | self.assertFalse(feed.done_feeding) 47 | batch = feed.next_batch(4) 48 | self.assertEqual(len(batch), 4) 49 | self.assertEqual(sum(batch), 14) 50 | 51 | # [6,7,8,9] 52 | self.assertFalse(feed.done_feeding) 53 | batch = feed.next_batch(10) # ask for more than available 54 | self.assertEqual(len(batch), 4) 55 | self.assertEqual(sum(batch), 30) 56 | 57 | # should be done 58 | self.assertTrue(feed.should_stop()) 59 | 60 | 61 | if __name__ == '__main__': 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /tensorflowonspark/TFManager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import nested_scopes 8 | from __future__ import print_function 9 | 10 | from multiprocessing.managers import BaseManager 11 | from multiprocessing import JoinableQueue 12 | 13 | 14 | class TFManager(BaseManager): 15 | """Python multiprocessing.Manager for distributed, multi-process communication.""" 16 | pass 17 | 18 | 19 | # global to each Spark executor's python worker 20 | mgr = None # TFManager 21 | qdict = {} # dictionary of queues 22 | kdict = {} # dictionary of key-values 23 | 24 | 25 | def _get(key): 26 | return kdict[key] 27 | 28 | 29 | def _set(key, value): 30 | kdict[key] = value 31 | 32 | 33 | def _get_queue(qname): 34 | try: 35 | return qdict[qname] 36 | except KeyError: 37 | return None 38 | 39 | 40 | def start(authkey, queues, mode='local'): 41 | """Create a new multiprocess.Manager (or return existing one). 42 | 43 | Args: 44 | :authkey: string authorization key 45 | :queues: *INTERNAL_USE* 46 | :mode: 'local' indicates that the manager will only be accessible from the same host, otherwise remotely accessible. 47 | 48 | Returns: 49 | A TFManager instance, which is also cached in local memory of the Python worker process. 50 | """ 51 | global mgr, qdict, kdict 52 | qdict.clear() 53 | kdict.clear() 54 | for q in queues: 55 | qdict[q] = JoinableQueue() 56 | 57 | TFManager.register('get_queue', callable=lambda qname: _get_queue(qname)) 58 | TFManager.register('get', callable=lambda key: _get(key)) 59 | TFManager.register('set', callable=lambda key, value: _set(key, value)) 60 | if mode == 'remote': 61 | mgr = TFManager(address=('', 0), authkey=authkey) 62 | else: 63 | mgr = TFManager(authkey=authkey) 64 | mgr.start() 65 | return mgr 66 | 67 | 68 | def connect(address, authkey): 69 | """Connect to a multiprocess.Manager. 70 | 71 | Args: 72 | :address: unique address to the TFManager, either a unique connection string for 'local', or a (host, port) tuple for remote. 73 | :authkey: string authorization key 74 | 75 | Returns: 76 | A TFManager instance referencing the remote TFManager at the supplied address. 
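Example (a minimal sketch; the address, authkey, and queue name here are placeholders and must match the values used by the corresponding start() call, as in tests/test_TFNode.py): mgr = TFManager.connect(('10.0.0.1', 41234), b'secret'); q = mgr.get_queue('input'); q.put(item)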
77 | """ 78 | TFManager.register('get_queue') 79 | TFManager.register('get') 80 | TFManager.register('set') 81 | m = TFManager(address, authkey=authkey) 82 | m.connect() 83 | return m 84 | -------------------------------------------------------------------------------- /tensorflowonspark/TFParallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Yahoo Inc 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import nested_scopes 8 | from __future__ import print_function 9 | 10 | import logging 11 | from . import TFSparkNode 12 | from . import util 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def run(sc, map_fn, tf_args, num_executors, use_barrier=True): 18 | """Runs the user map_fn as parallel, independent instances of TF on the Spark executors. 19 | 20 | Args: 21 | :sc: SparkContext 22 | :map_fun: user-supplied TensorFlow "main" function 23 | :tf_args: ``argparse`` args, or command-line ``ARGV``. These will be passed to the ``map_fun``. 24 | :num_executors: number of Spark executors. This should match your Spark job's ``--num_executors``. 25 | :use_barrier: Boolean indicating if TFParallel should use Spark barrier execution mode to wait for all executors. 26 | 27 | Returns: 28 | None 29 | """ 30 | 31 | # get default filesystem from spark 32 | defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS") 33 | # strip trailing "root" slash from "file:///" to be consistent w/ "hdfs://..." 34 | if defaultFS.startswith("file://") and len(defaultFS) > 7 and defaultFS.endswith("/"): 35 | defaultFS = defaultFS[:-1] 36 | 37 | def _run(it): 38 | from pyspark import BarrierTaskContext 39 | 40 | for i in it: 41 | worker_num = i 42 | 43 | if use_barrier: 44 | # use BarrierTaskContext to get placement of all nodes 45 | barrier_ctx = BarrierTaskContext.get() 46 | tasks = barrier_ctx.getTaskInfos() 47 | nodes = [t.address for t in tasks] 48 | num_workers = len(nodes) 49 | else: 50 | nodes = [] 51 | num_workers = num_executors 52 | 53 | # use the placement info to help allocate GPUs 54 | # note: defaults to CPU if no GPUs present 55 | num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1 56 | util.single_node_env(num_gpus=num_gpus, worker_index=worker_num, nodes=nodes) 57 | 58 | # run the user map_fn 59 | ctx = TFSparkNode.TFNodeContext() 60 | ctx.defaultFS = defaultFS 61 | ctx.worker_num = worker_num 62 | ctx.executor_id = worker_num 63 | ctx.num_workers = num_workers 64 | 65 | map_fn(tf_args, ctx) 66 | 67 | # return a dummy iterator (since we have to use mapPartitions) 68 | return [0] 69 | 70 | nodeRDD = sc.parallelize(list(range(num_executors)), num_executors) 71 | if use_barrier: 72 | nodeRDD.barrier().mapPartitions(_run).collect() 73 | else: 74 | nodeRDD.mapPartitions(_run).collect() 75 | -------------------------------------------------------------------------------- /tests/test_dfutil.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import test 4 | import unittest 5 | 6 | from tensorflowonspark import dfutil 7 | 8 | 9 | class DFUtilTest(test.SparkTest): 10 | @classmethod 11 | def setUpClass(cls): 12 | super(DFUtilTest, cls).setUpClass() 13 | 14 | # define model_dir and export_dir for tests 15 | cls.tfrecord_dir = os.getcwd() + os.sep + "test_tfr" 16 | 17 | @classmethod 18 | def 
tearDownClass(cls): 19 | super(DFUtilTest, cls).tearDownClass() 20 | 21 | def setUp(self): 22 | super(DFUtilTest, self).setUp() 23 | # remove any prior test artifacts 24 | shutil.rmtree(self.tfrecord_dir, ignore_errors=True) 25 | 26 | def tearDown(self): 27 | # Note: don't clean up artifacts after test (in case we need to view/debug) 28 | pass 29 | 30 | def test_dfutils(self): 31 | # create a DataFrame of a single row consisting of standard types (str, int, int_array, float, float_array, binary) 32 | row1 = ('text string', 1, [2, 3, 4, 5], -1.1, [-2.2, -3.3, -4.4, -5.5], bytearray(b'\xff\xfe\xfd\xfc')) 33 | rdd = self.sc.parallelize([row1]) 34 | df1 = self.spark.createDataFrame(rdd, ['a', 'b', 'c', 'd', 'e', 'f']) 35 | print("schema: {}".format(df1.schema)) 36 | 37 | # save the DataFrame as TFRecords 38 | dfutil.saveAsTFRecords(df1, self.tfrecord_dir) 39 | self.assertTrue(os.path.isdir(self.tfrecord_dir)) 40 | 41 | # reload the DataFrame from exported TFRecords 42 | df2 = dfutil.loadTFRecords(self.sc, self.tfrecord_dir, binary_features=['f']) 43 | row2 = df2.take(1)[0] 44 | 45 | print("row_saved: {}".format(row1)) 46 | print("row_loaded: {}".format(row2)) 47 | 48 | # confirm loaded values match original/saved values 49 | self.assertEqual(row1[0], row2['a']) 50 | self.assertEqual(row1[1], row2['b']) 51 | self.assertEqual(row1[2], row2['c']) 52 | self.assertAlmostEqual(row1[3], row2['d'], 6) 53 | for i in range(len(row1[4])): 54 | self.assertAlmostEqual(row1[4][i], row2['e'][i], 6) 55 | print("type(f): {}".format(type(row2['f']))) 56 | for i in range(len(row1[5])): 57 | self.assertEqual(row1[5][i], row2['f'][i]) 58 | 59 | # check origin of each DataFrame 60 | self.assertFalse(dfutil.isLoadedDF(df1)) 61 | self.assertTrue(dfutil.isLoadedDF(df2)) 62 | 63 | # references are equivalent 64 | df_ref = df2 65 | self.assertTrue(dfutil.isLoadedDF(df_ref)) 66 | 67 | # mutated DFs are not equal, even if contents are identical 68 | df3 = df2.filter(df2.a == 'string_label') 69 | self.assertFalse(dfutil.isLoadedDF(df3)) 70 | 71 | # re-used/re-assigned variables are not equal 72 | df2 = df3 73 | self.assertFalse(dfutil.isLoadedDF(df2)) 74 | 75 | 76 | if __name__ == '__main__': 77 | unittest.main() 78 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Copyright 2017, Yahoo Inc. 2 | # Licensed under the terms of the apache license. 
See the LICENSE file in the project root for terms 3 | [metadata] 4 | author = Lee Yang 5 | author_email = leewyang@gmail.com 6 | classifiers = 7 | Intended Audience :: Developers 8 | Intended Audience :: Science/Research 9 | License :: OSI Approved :: Apache Software License 10 | Topic :: Software Development :: Libraries 11 | Programming Language :: Python :: 3 :: Only 12 | Programming Language :: Python :: 3.7 13 | Programming Language :: Python :: 3.8 14 | Programming Language :: Python :: 3.9 15 | description = Deep learning with TensorFlow on Apache Spark clusters 16 | license = Apache 2.0 17 | long_description = file:README.md 18 | long_description_content_type = text/markdown 19 | name = tensorflowonspark 20 | url = https://github.com/yahoo/TensorFlowOnSpark 21 | version = 2.2.5 22 | 23 | [options] 24 | packages = 25 | tensorflowonspark 26 | 27 | # The install_requires should include abstract package dependencies 28 | # here (do not specify specific versions) 29 | 30 | install_requires = 31 | packaging 32 | setuptools>38.0 33 | 34 | # By default new packages require at minimum the current supported Python release. 35 | python_requires = >="3.6" 36 | zip_safe = True 37 | 38 | [options.extras_require] 39 | # This config section allows you to define optional dependencies. For the general case, the defaults will 40 | # work fine. So these settings aren't required. However, many of the screwdriver CI Pipeline steps 41 | # will install the appropriate extras for that step. This makes it possible to install packages that install 42 | # or enhance the functionality of the CI Pipeline step. 43 | # Such as packages that implement plugins or themes for the step in question. 44 | 45 | # Additional packages for testing (test step) 46 | # test = 47 | 48 | # Additonal packages needed for documentation generation (doc_build/doc_publish steps) 49 | # If you want to use a sphinx theme from a package, list it here. 50 | # doc_build = 51 | 52 | # Additional packages needed for mypy type checking 53 | # mypy = 54 | 55 | # Additional packages needed for pep8/pycodestyle style checking 56 | # pycodestyle = 57 | 58 | # Additional packages needed for pylint code analysis 59 | # pylint = 60 | 61 | [options.entry_points] 62 | # Console script entry points are used to create wrapper scripts that run a specific function, the resulting wrapper 63 | # is installed in the bin directory. 64 | 65 | # They are defined using the following format: 66 | # scriptname = modulename:function 67 | # console_scripts = 68 | # TFoS=ouroath.TFoS.cli:main 69 | 70 | [screwdrivercd.version] 71 | # Base the autoversion build number on the screwdriver build number 72 | # This requires the CI Pipeline to have a build step that runs before 73 | # any packaging steps. 74 | version_type = sdv4_SD_BUILD 75 | 76 | [bdist_wheel] 77 | universal = 1 78 | -------------------------------------------------------------------------------- /examples/mnist/mnist_data_setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 
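# Overview: this script loads MNIST via tensorflow_datasets, parallelizes the train/test splits into Spark RDDs, and writes each split under --output in two formats: CSV text files ("label,pixel0,...,pixel783") under --output/csv/{train,test}, and TFRecords written through the tensorflow-hadoop TFRecordFileOutputFormat under --output/tfr/{train,test}.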
4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | 10 | if __name__ == "__main__": 11 | import argparse 12 | 13 | from pyspark.context import SparkContext 14 | from pyspark.conf import SparkConf 15 | import tensorflow as tf 16 | import tensorflow_datasets as tfds 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--num_partitions", help="Number of output partitions", type=int, default=10) 20 | parser.add_argument("--output", help="HDFS directory to save examples in parallelized format", default="data/mnist") 21 | 22 | args = parser.parse_args() 23 | print("args:", args) 24 | 25 | sc = SparkContext(conf=SparkConf().setAppName("mnist_data_setup")) 26 | 27 | mnist, info = tfds.load('mnist', with_info=True) 28 | print(info.as_json) 29 | 30 | # convert to numpy, then RDDs 31 | mnist_train = tfds.as_numpy(mnist['train']) 32 | mnist_test = tfds.as_numpy(mnist['test']) 33 | 34 | train_rdd = sc.parallelize(mnist_train, args.num_partitions).cache() 35 | test_rdd = sc.parallelize(mnist_test, args.num_partitions).cache() 36 | 37 | # save as CSV (label,comma-separated-features) 38 | def to_csv(example): 39 | return str(example['label']) + ',' + ','.join([str(i) for i in example['image'].reshape(784)]) 40 | 41 | train_rdd.map(to_csv).saveAsTextFile(args.output + "/csv/train") 42 | test_rdd.map(to_csv).saveAsTextFile(args.output + "/csv/test") 43 | 44 | # save as TFRecords (numpy vs. PNG) 45 | # note: the MNIST tensorflow_dataset is already provided as TFRecords but with a PNG bytes_list 46 | # this export format is less-efficient, but easier to work with later 47 | def to_tfr(example): 48 | ex = tf.train.Example( 49 | features=tf.train.Features( 50 | feature={ 51 | 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['label'].astype("int64")])), 52 | 'image': tf.train.Feature(int64_list=tf.train.Int64List(value=example['image'].reshape(784).astype("int64"))) 53 | } 54 | ) 55 | ) 56 | return (bytearray(ex.SerializeToString()), None) 57 | 58 | train_rdd.map(to_tfr).saveAsNewAPIHadoopFile(args.output + "/tfr/train", 59 | "org.tensorflow.hadoop.io.TFRecordFileOutputFormat", 60 | keyClass="org.apache.hadoop.io.BytesWritable", 61 | valueClass="org.apache.hadoop.io.NullWritable") 62 | test_rdd.map(to_tfr).saveAsNewAPIHadoopFile(args.output + "/tfr/test", 63 | "org.tensorflow.hadoop.io.TFRecordFileOutputFormat", 64 | keyClass="org.apache.hadoop.io.BytesWritable", 65 | valueClass="org.apache.hadoop.io.NullWritable") 66 | -------------------------------------------------------------------------------- /tensorflowonspark/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import nested_scopes 8 | from __future__ import print_function 9 | 10 | import logging 11 | import os 12 | import socket 13 | import subprocess 14 | import errno 15 | from socket import error as socket_error 16 | from . 
import gpu_info 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def single_node_env(num_gpus=1, worker_index=-1, nodes=[]): 22 | """Setup environment variables for Hadoop compatibility and GPU allocation""" 23 | # ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI) 24 | if 'HADOOP_PREFIX' in os.environ and 'TFOS_CLASSPATH_UPDATED' not in os.environ: 25 | classpath = os.environ['CLASSPATH'] 26 | hadoop_path = os.path.join(os.environ['HADOOP_PREFIX'], 'bin', 'hadoop') 27 | hadoop_classpath = subprocess.check_output([hadoop_path, 'classpath', '--glob']).decode() 28 | os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath 29 | os.environ['TFOS_CLASSPATH_UPDATED'] = '1' 30 | 31 | if gpu_info.is_gpu_available() and num_gpus > 0: 32 | # reserve GPU(s), if requested 33 | if worker_index >= 0 and nodes and len(nodes) > 0: 34 | # compute my index relative to other nodes on the same host, if known 35 | my_addr = nodes[worker_index] 36 | my_host = my_addr.split(':')[0] 37 | local_peers = [n for n in nodes if n.startswith(my_host)] 38 | my_index = local_peers.index(my_addr) 39 | else: 40 | # otherwise, just use global worker index 41 | my_index = worker_index 42 | 43 | gpus_to_use = gpu_info.get_gpus(num_gpus, my_index) 44 | logger.info("Using gpu(s): {0}".format(gpus_to_use)) 45 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use 46 | else: 47 | # CPU 48 | logger.info("Using CPU") 49 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 50 | 51 | 52 | def get_ip_address(): 53 | """Simple utility to get host IP address.""" 54 | try: 55 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 56 | s.connect(("8.8.8.8", 80)) 57 | ip_address = s.getsockname()[0] 58 | except socket_error as sockerr: 59 | if sockerr.errno != errno.ENETUNREACH: 60 | raise sockerr 61 | ip_address = socket.gethostbyname(socket.getfqdn()) 62 | finally: 63 | s.close() 64 | 65 | return ip_address 66 | 67 | 68 | def find_in_path(path, file): 69 | """Find a file in a given path string.""" 70 | for p in path.split(os.pathsep): 71 | candidate = os.path.join(p, file) 72 | if os.path.exists(candidate) and os.path.isfile(candidate): 73 | return candidate 74 | return False 75 | 76 | 77 | def write_executor_id(num): 78 | """Write executor_id into a local file in the executor's current working directory""" 79 | with open("executor_id", "w") as f: 80 | f.write(str(num)) 81 | 82 | 83 | def read_executor_id(): 84 | """Read worker id from a local file in the executor's current working directory""" 85 | if os.path.isfile("executor_id"): 86 | with open("executor_id", "r") as f: 87 | return int(f.read()) 88 | else: 89 | msg = "No executor_id file found on this node, please ensure that:\n" + \ 90 | "1. Spark num_executors matches TensorFlow cluster_size\n" + \ 91 | "2. Spark tasks per executor is 1\n" + \ 92 | "3. Spark dynamic allocation is disabled\n" + \ 93 | "4. There are no other root-cause exceptions on other nodes\n" 94 | raise Exception(msg) 95 | -------------------------------------------------------------------------------- /tensorflowonspark/gpu_info.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 
4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import nested_scopes 8 | from __future__ import print_function 9 | 10 | import logging 11 | import random 12 | import subprocess 13 | import time 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | MAX_RETRIES = 3 #: Maximum retries to allocate GPUs 18 | AS_STRING = 'string' 19 | AS_LIST = 'list' 20 | 21 | 22 | def is_gpu_available(): 23 | """Determine if GPUs are available on the host""" 24 | try: 25 | subprocess.check_output(["nvidia-smi", "--list-gpus"]) 26 | return True 27 | except Exception: 28 | return False 29 | 30 | 31 | def get_gpus(num_gpu=1, worker_index=-1, format=AS_STRING): 32 | """Get list of free GPUs according to nvidia-smi. 33 | 34 | This will retry for ``MAX_RETRIES`` times until the requested number of GPUs are available. 35 | 36 | Args: 37 | :num_gpu: number of GPUs desired. 38 | :worker_index: index "hint" for allocation of available GPUs. 39 | 40 | Returns: 41 | Comma-delimited string of GPU ids, or raises an Exception if the requested number of GPUs could not be found. 42 | """ 43 | # get list of gpus (index, uuid) 44 | list_gpus = subprocess.check_output(["nvidia-smi", "--list-gpus"]).decode() 45 | logger.debug("all GPUs:\n{0}".format(list_gpus)) 46 | 47 | # parse index and guid 48 | gpus = [x for x in list_gpus.split('\n') if len(x) > 0] 49 | 50 | def parse_gpu(gpu_str): 51 | cols = gpu_str.split(' ') 52 | return cols[5].split(')')[0], cols[1].split(':')[0] 53 | 54 | gpu_list = [parse_gpu(gpu) for gpu in gpus] 55 | 56 | free_gpus = [] 57 | retries = 0 58 | while len(free_gpus) < num_gpu and retries < MAX_RETRIES: 59 | smi_output = subprocess.check_output(["nvidia-smi", "--format=csv,noheader,nounits", "--query-compute-apps=gpu_uuid"]).decode() 60 | logger.debug("busy GPUs:\n{0}".format(smi_output)) 61 | busy_uuids = [x for x in smi_output.split('\n') if len(x) > 0] 62 | for uuid, index in gpu_list: 63 | if uuid not in busy_uuids: 64 | free_gpus.append(index) 65 | 66 | if len(free_gpus) < num_gpu: 67 | logger.warn("Unable to find available GPUs: requested={0}, available={1}".format(num_gpu, len(free_gpus))) 68 | retries += 1 69 | time.sleep(30 * retries) 70 | free_gpus = [] 71 | 72 | logger.info("Available GPUs: {}".format(free_gpus)) 73 | 74 | # if still can't find available GPUs, raise exception 75 | if len(free_gpus) < num_gpu: 76 | smi_output = subprocess.check_output(["nvidia-smi", "--format=csv", "--query-compute-apps=gpu_uuid,pid,process_name,used_gpu_memory"]).decode() 77 | logger.info(": {0}".format(smi_output)) 78 | raise Exception("Unable to find {} free GPU(s)\n{}".format(num_gpu, smi_output)) 79 | 80 | # Get logical placement 81 | num_available = len(free_gpus) 82 | if worker_index == -1: 83 | # use original random placement 84 | random.shuffle(free_gpus) 85 | proposed_gpus = free_gpus[:num_gpu] 86 | else: 87 | # ordered by worker index 88 | if worker_index * num_gpu + num_gpu > num_available: 89 | worker_index = worker_index * num_gpu % num_available 90 | proposed_gpus = free_gpus[worker_index * num_gpu:(worker_index * num_gpu + num_gpu)] 91 | logger.info("Proposed GPUs: {}".format(proposed_gpus)) 92 | 93 | if format == AS_STRING: 94 | return ','.join(str(x) for x in proposed_gpus) 95 | elif format == AS_LIST: 96 | return proposed_gpus 97 | else: 98 | raise Exception("Unknown GPU format") 99 | -------------------------------------------------------------------------------- /examples/mnist/keras/mnist_inference.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2018 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 4 | 5 | # This example demonstrates how to leverage Spark for parallel inferencing from a SavedModel. 6 | # 7 | # Normally, you can use TensorFlowOnSpark to just form a TensorFlow cluster for training and inferencing. 8 | # However, in some situations, you may have a SavedModel without the original code for defining the inferencing 9 | # graph. In these situations, we can use Spark to instantiate a single-node TensorFlow instance on each executor, 10 | # where each executor can independently load the model and inference on input data. 11 | # 12 | # Note: this particular example demonstrates use of `tf.data.Dataset` to read the input data for inferencing, 13 | # but it could also be adapted to just use an RDD of TFRecords from Spark. 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import argparse 20 | import numpy as np 21 | import tensorflow as tf 22 | 23 | 24 | def inference(args, ctx): 25 | 26 | # load saved_model 27 | saved_model = tf.saved_model.load(args.export_dir, tags='serve') 28 | predict = saved_model.signatures['serving_default'] 29 | 30 | # parse function for TFRecords 31 | def parse_tfr(example_proto): 32 | feature_def = {"label": tf.io.FixedLenFeature(1, tf.int64), 33 | "image": tf.io.FixedLenFeature(784, tf.int64)} 34 | features = tf.io.parse_single_example(serialized=example_proto, features=feature_def) 35 | image = tf.cast(features['image'], dtype=tf.float32) / 255.0 36 | image = tf.reshape(image, [28, 28, 1]) 37 | label = tf.cast(features['label'], dtype=tf.float32) 38 | return (image, label) 39 | 40 | # define a new tf.data.Dataset (for inferencing) 41 | ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels), shuffle=False) 42 | ds = ds.shard(ctx.num_workers, ctx.worker_num) 43 | ds = ds.interleave(tf.data.TFRecordDataset) 44 | ds = ds.map(parse_tfr) 45 | ds = ds.batch(10) 46 | 47 | # create an output file per spark worker for the predictions 48 | tf.io.gfile.makedirs(args.output) 49 | output_file = tf.io.gfile.GFile("{}/part-{:05d}".format(args.output, ctx.worker_num), mode='w') 50 | 51 | for batch in ds: 52 | predictions = predict(conv2d_input=batch[0]) 53 | labels = np.reshape(batch[1], -1).astype(int) 54 | preds = np.argmax(predictions['dense_1'], axis=1) 55 | for x in zip(labels, preds): 56 | output_file.write("{} {}\n".format(x[0], x[1])) 57 | 58 | output_file.close() 59 | 60 | 61 | if __name__ == '__main__': 62 | from pyspark.context import SparkContext 63 | from pyspark.conf import SparkConf 64 | from tensorflowonspark import TFParallel 65 | 66 | sc = SparkContext(conf=SparkConf().setAppName("mnist_inference")) 67 | executors = sc._conf.get("spark.executor.instances") 68 | num_executors = int(executors) if executors is not None else 1 69 | 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors) 72 | parser.add_argument('--images_labels', type=str, help='Directory for input images with labels') 73 | parser.add_argument("--export_dir", help="HDFS path to export model", type=str, default="mnist_export") 74 | parser.add_argument("--output", help="HDFS path to save predictions", type=str,
default="predictions") 75 | args, _ = parser.parse_known_args() 76 | print("args: {}".format(args)) 77 | 78 | # Running single-node TF instances on each executor 79 | TFParallel.run(sc, inference, args, args.cluster_size) 80 | -------------------------------------------------------------------------------- /src/main/scala/com/yahoo/tensorflowonspark/Inference.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2018 Yahoo Inc. 3 | * Licensed under the terms of the Apache 2.0 license. 4 | * Please see LICENSE file in the project root for terms. 5 | */ 6 | package com.yahoo.tensorflowonspark 7 | 8 | import org.apache.spark.sql.SparkSession 9 | import org.apache.spark.sql.types._ 10 | import org.apache.spark.{SparkConf, SparkContext} 11 | import org.json4s._ 12 | import org.json4s.native.JsonMethods 13 | 14 | /** 15 | * Spark application that infers from a TensorFlow SavedModel. 16 | */ 17 | object Inference { 18 | 19 | case class Config(export_dir: String = "", 20 | input: String = "", 21 | schema_hint: StructType = new StructType(), 22 | input_mapping: Map[String, String] = Map.empty, 23 | output_mapping: Map[String, String] = Map.empty, 24 | output: String = "", 25 | verbose: Boolean = false) 26 | 27 | def main(args: Array[String]) { 28 | val conf = new SparkConf().setAppName("Inference") 29 | implicit val sc: SparkContext = new SparkContext(conf) 30 | val parser = new scopt.OptionParser[Config]("Inference") { 31 | opt[String]("export_dir").text("Path to exported saved_model") 32 | .action((x, conf) => conf.copy(export_dir = x)) 33 | opt[String]("input").text("Path to input TFRecords") 34 | .action((x, conf) => conf.copy(input = x)) 35 | opt[String]("schema_hint").text("schema hint (in StructType.simpleString format) for converting TFRecord features to Spark DataFrame types") 36 | .action{ case (schema, conf) => conf.copy(schema_hint = SimpleTypeParser.parse(schema)) } 37 | opt[String]("input_mapping").text("JSON mapping of input columns to input tensors") 38 | .action((x, conf) => conf.copy(input_mapping = JsonMethods.parse(x).values.asInstanceOf[Map[String, String]])) 39 | opt[String]("output_mapping").text("JSON mapping of output tensors to output columns") 40 | .action((x, conf) => conf.copy(output_mapping = JsonMethods.parse(x).values.asInstanceOf[Map[String, String]])) 41 | opt[String]("output").text("Path to write predictions").action((x, conf) => conf.copy(output = x)) 42 | opt[Unit]("verbose").text("Print input dataframe sample with schema").action((_, conf) => conf.copy(verbose = true)) 43 | } 44 | 45 | parser.parse(args, Config()) match { 46 | case Some(config) => run(sc, config) 47 | case None => System.exit(1) 48 | } 49 | sc.stop() 50 | } 51 | 52 | def run(implicit sc: SparkContext, config: Config) { 53 | 54 | implicit val spark: SparkSession = SparkSession.builder().getOrCreate() 55 | 56 | // load TFRecords as a Spark DataFrame (using a user-provided schema hint) 57 | val df = DFUtil.loadTFRecords(config.input, config.schema_hint) 58 | if (config.verbose) { 59 | df.show() 60 | df.printSchema() 61 | } 62 | 63 | // instantiate a TFModel pointing to an existing TensorFlow saved_model export 64 | // set up mappings between input DataFrame columns to input Tensors 65 | // and output Tensors to output DataFrame columns 66 | // Note: the output DataFrame column types will be inferred from the output Tensor dtypes 67 | val model = new TFModel().setModel(config.export_dir) 68 | .setInputMapping(config.input_mapping) 69 | 
.setOutputMapping(config.output_mapping) 70 | 71 | // transform the input DataFrame 72 | // Note: we're currently dropping input columns for simplicity, you can retrieve them as Tensors if needed. 73 | val predDF = model.transform(df) 74 | 75 | // write the predictions 76 | predDF.write.json(config.output) 77 | 78 | spark.stop() 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Copyright 2017, Yahoo Inc. 2 | # Licensed under the terms of the apache license. See the LICENSE file in the project root for terms 3 | [config] 4 | package_dir = tensorflowonspark 5 | package_name = tensorflowonspark 6 | 7 | [tox] 8 | envlist = py37 9 | skip_missing_interpreters = true 10 | 11 | [testenv] 12 | allowlist_externals = 13 | bash 14 | changedir = {toxinidir} 15 | commands = 16 | /bin/bash scripts/start_spark.sh 17 | python -m unittest discover -s tests 18 | # pytest --junitxml=pytest_{envname}.xml -o junit_suite_name={envname} --cov={[config]package_name} --cov-report=xml:coverage.xml --cov-report term-missing tests/ 19 | /bin/bash scripts/stop_spark.sh 20 | deps = 21 | -rrequirements.txt 22 | coverage 23 | six 24 | pytest 25 | pytest-cov 26 | install_command = {envpython} {envbindir}/pip install {opts} {packages} 27 | list_dependencies_command = {envpython} {envbindir}/pip freeze 28 | passenv = SSH_AUTH_SOCK BUILD_NUMBER HOSTNAME SPARK_HOME SPARK_LOCAL_IP MASTER SPARK_WORKER_INSTANCES SPARK_CLASSPATH CORES_PER_WORKER 29 | setenv = 30 | SPARK_LOCAL_IP = 127.0.0.1 31 | MASTER = spark://{env:HOSTNAME}:7077 32 | PYTHONPATH = /opt/spark/python 33 | SPARK_CLASSPATH = ./lib/tensorflow-hadoop-1.0-SNAPSHOT.jar 34 | SPARK_WORKER_INSTANCES = 2 35 | CORES_PER_WORKER = 1 36 | extras = 37 | test 38 | 39 | [testenv:coverage] 40 | commands = 41 | coverage combine -a 42 | coverage report -m --skip-covered 43 | deps = 44 | coverage 45 | six 46 | pytest 47 | pytest-cov 48 | skip_install = true 49 | 50 | [testenv:lint_codestyle] 51 | deps = 52 | six 53 | pycodestyle 54 | commands = {envpython} {envbindir}/pycodestyle {[config]package_dir} 55 | changedir = {toxinidir} 56 | install_command = {envpython} {envbindir}/pip install {opts} {packages} 57 | list_dependencies_command = {envpython} {envbindir}/pip freeze 58 | passenv = SSH_AUTH_SOCK BUILD_NUMBER 59 | extras = 60 | pep8 61 | 62 | [testenv:lint_pylint] 63 | deps = 64 | isort<=4.2.15 65 | six 66 | pylint 67 | commands = {envpython} {envbindir}/pylint --output-format=parseable {[config]package_dir} 68 | changedir = {toxinidir} 69 | install_command = {envpython} {envbindir}/pip install {opts} {packages} 70 | list_dependencies_command = {envpython} {envbindir}/pip freeze 71 | passenv = SSH_AUTH_SOCK BUILD_NUMBER 72 | extras = 73 | pylint 74 | 75 | [testenv:lint_mypy] 76 | deps = 77 | mypy 78 | lxml 79 | commands = 80 | {envpython} {envbindir}/mypy -p {[config]package_name} --ignore-missing-imports --txt-report artifacts/mypy 81 | changedir = {toxinidir} 82 | install_command = {envpython} {envbindir}/pip install {opts} {packages} 83 | list_dependencies_command = {envpython} {envbindir}/pip freeze 84 | passenv = SSH_AUTH_SOCK BUILD_NUMBER 85 | extras = 86 | mypy 87 | 88 | [testenv:doc_build] 89 | deps = 90 | sphinx!=1.8.0 91 | sphinx_rtd_theme 92 | guzzle_sphinx_theme 93 | recommonmark 94 | sphinx_markdown_tables 95 | commands = {envpython} {envbindir}/sphinx-build -b html doc/source build/sphinx/html 96 | changedir = {toxinidir} 
97 | install_command = {envpython} {envbindir}/pip install {opts} {packages} 98 | list_dependencies_command = {envpython} {envbindir}/pip freeze 99 | passenv = SSH_AUTH_SOCK BUILD_NUMBER 100 | extras = 101 | doc_build 102 | basepython = python3.6 103 | 104 | [testenv:add_api_docs] 105 | deps = 106 | sphinx 107 | commands = 108 | {envpython} {envbindir}/sphinx-apidoc -T -e -M -o doc/source/ src "artifacts/*" "dist/*" "screwdriver/*" "scripts/*" setup.py "tests/*" 109 | changedir = {toxinidir} 110 | extras = 111 | doc_build 112 | passenv = SSH_AUTH_SOCK BUILD_NUMBER 113 | install_command = {envpython} {envbindir}/pip install {opts} {packages} 114 | list_dependencies_command = {envpython} {envbindir}/pip freeze 115 | basepython = python3.6 116 | 117 | [pycodestyle] 118 | ignore = E1,E2,E3,E4,E5,W293 119 | max_line_length = 160 120 | -------------------------------------------------------------------------------- /examples/resnet/README_orig.md: -------------------------------------------------------------------------------- 1 | This folder contains the Keras implementation of the ResNet models. For more 2 | information about the models, please refer to this [README file](../../README.md). 3 | 4 | Similar to the [estimator implementation](../../r1/resnet), the Keras 5 | implementation has code for both CIFAR-10 data and ImageNet data. The CIFAR-10 6 | version uses a ResNet56 model implemented in 7 | [`resnet_cifar_model.py`](./resnet_cifar_model.py), and the ImageNet version 8 | uses a ResNet50 model implemented in [`resnet_model.py`](./resnet_model.py). 9 | 10 | To use 11 | either dataset, make sure that you have the latest version of TensorFlow 12 | installed and 13 | [add the models folder to your Python path](/official/#running-the-models), 14 | otherwise you may encounter an error like `ImportError: No module named 15 | official.resnet`. 16 | 17 | ## CIFAR-10 18 | 19 | Download and extract the CIFAR-10 data. You can use the following script: 20 | ```bash 21 | python ../../r1/resnet/cifar10_download_and_extract.py 22 | ``` 23 | 24 | After you download the data, you can run the program by: 25 | 26 | ```bash 27 | python resnet_cifar_main.py 28 | ``` 29 | 30 | If you did not use the default directory to download the data, specify the 31 | location with the `--data_dir` flag, like: 32 | 33 | ```bash 34 | python resnet_cifar_main.py --data_dir=/path/to/cifar 35 | ``` 36 | 37 | ## ImageNet 38 | 39 | Download the ImageNet dataset and convert it to TFRecord format. 40 | The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py) 41 | and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy) 42 | provide a few options. 43 | 44 | Once your dataset is ready, you can begin training the model as follows: 45 | 46 | ```bash 47 | python resnet_imagenet_main.py 48 | ``` 49 | 50 | Again, if you did not download the data to the default directory, specify the 51 | location with the `--data_dir` flag: 52 | 53 | ```bash 54 | python resnet_imagenet_main.py --data_dir=/path/to/imagenet 55 | ``` 56 | 57 | There are more flag options you can specify. 
Here are some examples: 58 | 59 | - `--use_synthetic_data`: when set to true, synthetic data, rather than real 60 | data, are used; 61 | - `--batch_size`: the batch size used for the model; 62 | - `--model_dir`: the directory to save the model checkpoint; 63 | - `--train_epochs`: number of epochs to run for training the model; 64 | - `--train_steps`: number of steps to run for training the model. We now only 65 | support a number that is smaller than the number of batches in an epoch. 66 | - `--skip_eval`: when set to true, evaluation as well as validation during 67 | training is skipped. 68 | 69 | For example, this is a typical command line to run with ImageNet data with 70 | batch size 128 per GPU: 71 | 72 | ```bash 73 | python -m resnet_imagenet_main \ 74 | --model_dir=/tmp/model_dir/something \ 75 | --num_gpus=2 \ 76 | --batch_size=128 \ 77 | --train_epochs=90 \ 78 | --train_steps=10 \ 79 | --use_synthetic_data=false 80 | ``` 81 | 82 | See [`common.py`](common.py) for a full list of options. 83 | 84 | ## Using multiple GPUs 85 | You can train these models on multiple GPUs using the `tf.distribute.Strategy` API. 86 | You can read more about them in this 87 | [guide](https://www.tensorflow.org/guide/distribute_strategy). 88 | 89 | In this example, we have made it easier to use with just a command line flag 90 | `--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA, 91 | and 0 otherwise. 92 | 93 | - --num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device. 94 | - --num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device. 95 | - --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous 96 | distributed training across the GPUs. 97 | 98 | If you wish to run without `tf.distribute.Strategy`, you can do so by setting 99 | `--distribution_strategy=off`. 100 | 101 | -------------------------------------------------------------------------------- /examples/mnist/estimator/mnist_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 4 | 5 | # This example demonstrates how to leverage Spark for parallel inferencing from a SavedModel. 6 | # 7 | # Normally, you can use TensorFlowOnSpark to just form a TensorFlow cluster for training and inferencing. 8 | # However, in some situations, you may have a SavedModel without the original code for defining the inferencing 9 | # graph. In these situations, we can use Spark to instantiate a single-node TensorFlow instance on each executor, 10 | # where each executor can independently load the model and inference on input data. 11 | # 12 | # Note: this particular example demonstrates use of `tf.data.Dataset` to read the input data for inferencing, 13 | # but it could also be adapted to just use an RDD of TFRecords from Spark.
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import argparse 20 | import numpy as np 21 | import tensorflow as tf 22 | 23 | 24 | def inference(it, num_workers, args): 25 | from tensorflowonspark import util 26 | 27 | # consume worker number from RDD partition iterator 28 | for i in it: 29 | worker_num = i 30 | print("worker_num: {}".format(i)) 31 | 32 | # setup env for single-node TF 33 | util.single_node_env() 34 | 35 | # load saved_model 36 | saved_model = tf.saved_model.load(args.export_dir, tags='serve') 37 | predict = saved_model.signatures['serving_default'] 38 | 39 | # parse function for TFRecords 40 | def parse_tfr(example_proto): 41 | feature_def = {"label": tf.io.FixedLenFeature(1, tf.int64), 42 | "image": tf.io.FixedLenFeature(784, tf.int64)} 43 | features = tf.io.parse_single_example(serialized=example_proto, features=feature_def) 44 | image = tf.cast(features['image'], dtype=tf.float32) / 255.0 45 | image = tf.reshape(image, [28, 28, 1]) 46 | label = tf.cast(features['label'], dtype=tf.float32) 47 | return (image, label) 48 | 49 | # define a new tf.data.Dataset (for inferencing) 50 | ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels), shuffle=False) 51 | ds = ds.shard(num_workers, worker_num) 52 | ds = ds.interleave(tf.data.TFRecordDataset) 53 | ds = ds.map(parse_tfr) 54 | ds = ds.batch(10) 55 | 56 | # create an output file per spark worker for the predictions 57 | tf.io.gfile.makedirs(args.output) 58 | output_file = tf.io.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w') 59 | 60 | for batch in ds: 61 | predictions = predict(conv2d_input=batch[0]) 62 | labels = np.reshape(batch[1], -1).astype(int) 63 | preds = np.argmax(predictions['logits'], axis=1) 64 | for x in zip(labels, preds): 65 | output_file.write("{} {}\n".format(x[0], x[1])) 66 | 67 | output_file.close() 68 | 69 | 70 | if __name__ == '__main__': 71 | from pyspark.context import SparkContext 72 | from pyspark.conf import SparkConf 73 | 74 | sc = SparkContext(conf=SparkConf().setAppName("mnist_inference")) 75 | executors = sc._conf.get("spark.executor.instances") 76 | num_executors = int(executors) if executors is not None else 1 77 | 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors) 80 | parser.add_argument('--images_labels', type=str, help='Directory for input images with labels') 81 | parser.add_argument("--export_dir", help="HDFS path to export model", type=str, default="mnist_export") 82 | parser.add_argument("--output", help="HDFS path to save predictions", type=str, default="predictions") 83 | args, _ = parser.parse_known_args() 84 | print("args: {}".format(args)) 85 | 86 | # Not using TFCluster...
just running single-node TF instances on each executor 87 | nodes = list(range(args.cluster_size)) 88 | nodeRDD = sc.parallelize(list(range(args.cluster_size)), args.cluster_size) 89 | nodeRDD.foreachPartition(lambda worker_num: inference(worker_num, args.cluster_size, args)) 90 | -------------------------------------------------------------------------------- /examples/mnist/keras/mnist_tf.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras 2 | 3 | from __future__ import absolute_import, division, print_function, unicode_literals 4 | 5 | 6 | def main_fun(args, ctx): 7 | import tensorflow_datasets as tfds 8 | import tensorflow as tf 9 | from tensorflowonspark import compat 10 | 11 | tfds.disable_progress_bar() 12 | 13 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 14 | 15 | BUFFER_SIZE = args.buffer_size 16 | BATCH_SIZE = args.batch_size 17 | NUM_WORKERS = args.cluster_size 18 | 19 | # Scaling MNIST data from (0, 255] to (0., 1.] 20 | def scale(image, label): 21 | return tf.cast(image, tf.float32) / 255, label 22 | 23 | # workaround for https://github.com/tensorflow/datasets/issues/1405 24 | datasets = tfds.load(name='mnist', split='train', as_supervised=True) 25 | options = tf.data.Options() 26 | options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA 27 | train_datasets_unbatched = datasets.with_options(options).repeat().map(scale).shuffle(BUFFER_SIZE) 28 | 29 | def build_and_compile_cnn_model(): 30 | model = tf.keras.Sequential([ 31 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 32 | tf.keras.layers.MaxPooling2D(), 33 | tf.keras.layers.Flatten(), 34 | tf.keras.layers.Dense(64, activation='relu'), 35 | tf.keras.layers.Dense(10, activation='softmax') 36 | ]) 37 | model.compile( 38 | loss=tf.keras.losses.sparse_categorical_crossentropy, 39 | optimizer=tf.keras.optimizers.SGD(learning_rate=0.001), 40 | metrics=['accuracy']) 41 | return model 42 | 43 | # single node 44 | # single_worker_model = build_and_compile_cnn_model() 45 | # single_worker_model.fit(x=train_datasets, epochs=3) 46 | 47 | # Here the batch size scales up by number of workers since 48 | # `tf.data.Dataset.batch` expects the global batch size. Previously we used 64, 49 | # and now this becomes 128. 
50 | GLOBAL_BATCH_SIZE = BATCH_SIZE * NUM_WORKERS 51 | train_datasets = train_datasets_unbatched.batch(GLOBAL_BATCH_SIZE) 52 | 53 | # this fails 54 | # callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=args.model_dir)] 55 | tf.io.gfile.makedirs(args.model_dir) 56 | filepath = args.model_dir + "/weights-{epoch:04d}" 57 | callbacks = [ 58 | tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True), 59 | tf.keras.callbacks.TensorBoard(log_dir=args.model_dir) 60 | ] 61 | 62 | with strategy.scope(): 63 | multi_worker_model = build_and_compile_cnn_model() 64 | multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=args.steps_per_epoch, callbacks=callbacks) 65 | 66 | compat.export_saved_model(multi_worker_model, args.export_dir, ctx.job_name == 'chief') 67 | 68 | 69 | if __name__ == '__main__': 70 | import argparse 71 | from pyspark.context import SparkContext 72 | from pyspark.conf import SparkConf 73 | from tensorflowonspark import TFCluster 74 | 75 | sc = SparkContext(conf=SparkConf().setAppName("mnist_keras")) 76 | executors = sc._conf.get("spark.executor.instances") 77 | num_executors = int(executors) if executors is not None else 1 78 | 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64) 81 | parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=10000) 82 | parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors) 83 | parser.add_argument("--epochs", help="number of epochs", type=int, default=3) 84 | parser.add_argument("--model_dir", help="path to save model/checkpoint", default="mnist_model") 85 | parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export") 86 | parser.add_argument("--steps_per_epoch", help="number of steps per epoch", type=int, default=469) 87 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 88 | 89 | args = parser.parse_args() 90 | print("args:", args) 91 | 92 | cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief', log_dir=args.model_dir) 93 | cluster.shutdown() 94 | -------------------------------------------------------------------------------- /tests/test_reservation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | import time 4 | import unittest 5 | 6 | from tensorflowonspark import util 7 | from tensorflowonspark.reservation import Reservations, Server, Client 8 | from unittest import mock 9 | 10 | 11 | class ReservationTest(unittest.TestCase): 12 | def test_reservation_class(self): 13 | """Test core reservation class, expecting 2 reservations""" 14 | r = Reservations(2) 15 | self.assertFalse(r.done()) 16 | 17 | # add first reservation 18 | r.add({'node': 1}) 19 | self.assertFalse(r.done()) 20 | self.assertEqual(r.remaining(), 1) 21 | 22 | # add second reservation 23 | r.add({'node': 2}) 24 | self.assertTrue(r.done()) 25 | self.assertEqual(r.remaining(), 0) 26 | 27 | # get final list 28 | reservations = r.get() 29 | self.assertEqual(len(reservations), 2) 30 | 31 | def test_reservation_server(self): 32 | """Test reservation server, expecting 1 reservation""" 33 | s = Server(1) 34 | addr = s.start() 35 | 36 | # add first reservation 37 | c = Client(addr) 38 | resp = 
c.register({'node': 1}) 39 | self.assertEqual(resp, 'OK') 40 | 41 | # get list of reservations 42 | reservations = c.get_reservations() 43 | self.assertEqual(len(reservations), 1) 44 | 45 | # should return immediately with list of reservations 46 | reservations = c.await_reservations() 47 | self.assertEqual(len(reservations), 1) 48 | 49 | # request server stop 50 | c.request_stop() 51 | time.sleep(1) 52 | self.assertEqual(s.done, True) 53 | 54 | def test_reservation_environment_exists_get_server_ip_return_environment_value(self): 55 | tfos_server = Server(5) 56 | with mock.patch.dict(os.environ, {'TFOS_SERVER_HOST': 'my_host_ip'}): 57 | assert tfos_server.get_server_ip() == "my_host_ip" 58 | 59 | def test_reservation_environment_not_exists_get_server_ip_return_actual_host_ip(self): 60 | tfos_server = Server(5) 61 | assert tfos_server.get_server_ip() == util.get_ip_address() 62 | 63 | def test_reservation_environment_exists_start_listening_socket_return_socket_listening_to_environment_port_value(self): 64 | tfos_server = Server(1) 65 | with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}): 66 | assert tfos_server.start_listening_socket().getsockname()[1] == 9999 67 | 68 | def test_reservation_environment_not_exists_start_listening_socket_return_socket(self): 69 | tfos_server = Server(1) 70 | print(tfos_server.start_listening_socket().getsockname()[1]) 71 | assert type(tfos_server.start_listening_socket().getsockname()[1]) == int 72 | 73 | def test_reservation_environment_exists_port_spec(self): 74 | tfos_server = Server(1) 75 | with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9999'}): 76 | self.assertEqual(tfos_server.get_server_ports(), [9999]) 77 | 78 | with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9997-9999'}): 79 | self.assertEqual(tfos_server.get_server_ports(), [9997, 9998, 9999]) 80 | 81 | def test_reservation_environment_exists_start_listening_socket_return_socket_listening_to_environment_port_range(self): 82 | tfos_server1 = Server(1) 83 | tfos_server2 = Server(1) 84 | tfos_server3 = Server(1) 85 | with mock.patch.dict(os.environ, {'TFOS_SERVER_PORT': '9998-9999'}): 86 | s1 = tfos_server1.start_listening_socket() 87 | self.assertEqual(s1.getsockname()[1], 9998) 88 | s2 = tfos_server2.start_listening_socket() 89 | self.assertEqual(s2.getsockname()[1], 9999) 90 | with self.assertRaises(Exception): 91 | tfos_server3.start_listening_socket() 92 | tfos_server1.stop() 93 | tfos_server2.stop() 94 | 95 | def test_reservation_server_multi(self): 96 | """Test reservation server, expecting multiple reservations""" 97 | num_clients = 4 98 | s = Server(num_clients) 99 | addr = s.start() 100 | 101 | def reserve(num): 102 | c = Client(addr) 103 | # time.sleep(random.randint(0,5)) # simulate varying start times 104 | resp = c.register({'node': num}) 105 | self.assertEqual(resp, 'OK') 106 | c.await_reservations() 107 | c.close() 108 | 109 | # start/register clients 110 | threads = [None] * num_clients 111 | for i in range(num_clients): 112 | threads[i] = threading.Thread(target=reserve, args=(i,)) 113 | threads[i].start() 114 | 115 | # wait for clients to complete 116 | for i in range(num_clients): 117 | threads[i].join() 118 | print("all done") 119 | 120 | # get list of reservations 121 | c = Client(addr) 122 | reservations = c.get_reservations() 123 | self.assertEqual(len(reservations), num_clients) 124 | 125 | # request server stop 126 | c.request_stop() 127 | time.sleep(1) 128 | self.assertEqual(s.done, True) 129 | 130 | 131 | if __name__ == '__main__': 132 | 
unittest.main() 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 6 | # TensorFlowOnSpark 7 | > _TensorFlowOnSpark brings scalable deep learning to Apache Hadoop and Apache Spark 8 | clusters._ 9 | 10 | [![Build Status](https://cd.screwdriver.cd/pipelines/6384/badge)](https://cd.screwdriver.cd/pipelines/6384) 11 | [![Package](https://img.shields.io/badge/package-pypi-blue.svg)](https://pypi.org/project/tensorflowonspark/) 12 | [![Downloads](https://img.shields.io/pypi/dm/tensorflowonspark.svg)](https://img.shields.io/pypi/dm/tensorflowonspark.svg) 13 | [![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://yahoo.github.io/TensorFlowOnSpark/) 14 | 15 | By combining salient features from the [TensorFlow](https://www.tensorflow.org) deep learning framework with [Apache Spark](http://spark.apache.org) and [Apache Hadoop](http://hadoop.apache.org), TensorFlowOnSpark enables distributed 16 | deep learning on a cluster of GPU and CPU servers. 17 | 18 | It enables both distributed TensorFlow training and 19 | inferencing on Spark clusters, with a goal to minimize the amount 20 | of code changes required to run existing TensorFlow programs on a 21 | shared grid. Its Spark-compatible API helps manage the TensorFlow 22 | cluster with the following steps: 23 | 24 | 1. **Startup** - launches the Tensorflow main function on the executors, along with listeners for data/control messages. 25 | 1. **Data ingestion** 26 | - **InputMode.TENSORFLOW** - leverages TensorFlow's built-in APIs to read data files directly from HDFS. 27 | - **InputMode.SPARK** - sends Spark RDD data to the TensorFlow nodes via a `TFNode.DataFeed` class. Note that we leverage the [Hadoop Input/Output Format](https://github.com/tensorflow/ecosystem/tree/master/hadoop) to access TFRecords on HDFS. 28 | 1. **Shutdown** - shuts down the Tensorflow workers and PS nodes on the executors. 29 | 30 | ## Table of Contents 31 | 32 | - [Background](#background) 33 | - [Install](#install) 34 | - [Usage](#usage) 35 | - [API](#api) 36 | - [Contribute](#contribute) 37 | - [License](#license) 38 | 39 | ## Background 40 | 41 | TensorFlowOnSpark was developed by Yahoo for large-scale distributed 42 | deep learning on our Hadoop clusters in Yahoo's private cloud. 43 | 44 | TensorFlowOnSpark provides some important benefits (see [our 45 | blog](https://developer.yahoo.com/blogs/157196317141/)) 46 | over alternative deep learning solutions. 47 | * Easily migrate existing TensorFlow programs with <10 lines of code change. 48 | * Support all TensorFlow functionalities: synchronous/asynchronous training, model/data parallelism, inferencing and TensorBoard. 49 | * Server-to-server direct communication achieves faster learning when available. 50 | * Allow datasets on HDFS and other sources pushed by Spark or pulled by TensorFlow. 51 | * Easily integrate with your existing Spark data processing pipelines. 52 | * Easily deployed on cloud or on-premise and on CPUs or GPUs. 
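As a rough illustration of the startup / data ingestion / shutdown flow described above, here is a minimal `InputMode.SPARK` sketch. It is illustrative only: the application name, input path, executor count, and the trivial `main_fun` body are placeholders; see the `examples/mnist` programs in this repository for complete, tested code.

```
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFCluster, TFNode

def main_fun(args, ctx):
  # Runs on every executor; ctx describes this node's role (job_name, worker_num, ...)
  tf_feed = TFNode.DataFeed(ctx.mgr, False)   # receives RDD partitions pushed by Spark
  while not tf_feed.should_stop():
    batch = tf_feed.next_batch(64)
    # ... train on (or run inference over) the batch here ...

sc = SparkContext(conf=SparkConf().setAppName("tfos_sketch"))
rdd = sc.textFile("data/train/csv").map(lambda ln: [float(x) for x in ln.split(',')])

# Startup: launch one TensorFlow node per Spark executor
cluster = TFCluster.run(sc, main_fun, tf_args={}, num_executors=2, num_ps=0,
                        input_mode=TFCluster.InputMode.SPARK, master_node='chief')
cluster.train(rdd, 1)   # Data ingestion: feed RDD partitions to the TF nodes (1 epoch)
cluster.shutdown()      # Shutdown: stop the TensorFlow nodes on the executors
```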
53 | 54 | ## Install 55 | 56 | TensorFlowOnSpark is provided as a pip package, which can be installed on single machines via: 57 | ``` 58 | # for tensorflow>=2.0.0 59 | pip install tensorflowonspark 60 | 61 | # for tensorflow<2.0.0 62 | pip install tensorflowonspark==1.4.4 63 | ``` 64 | 65 | For distributed clusters, please see our [wiki site](../../wiki) for detailed documentation for specific environments, such as our getting started guides for [single-node Spark Standalone](https://github.com/yahoo/TensorFlowOnSpark/wiki/GetStarted_Standalone), [YARN clusters](../../wiki/GetStarted_YARN) and [AWS EC2](../../wiki/GetStarted_EC2). Note: the Windows operating system is not currently supported due to [this issue](https://github.com/yahoo/TensorFlowOnSpark/issues/36). 66 | 67 | ## Usage 68 | 69 | To use TensorFlowOnSpark with an existing TensorFlow application, you can follow our [Conversion Guide](../../wiki/Conversion-Guide) to describe the required changes. Additionally, our [wiki site](../../wiki) has pointers to some presentations which provide an overview of the platform. 70 | 71 | **Note: since TensorFlow 2.x breaks API compatibility with TensorFlow 1.x, the examples have been updated accordingly. If you are using TensorFlow 1.x, you will need to checkout the `v1.4.4` tag for compatible examples and instructions.** 72 | 73 | ## API 74 | 75 | [API Documentation](https://yahoo.github.io/TensorFlowOnSpark/) is automatically generated from the code. 76 | 77 | ## Contribute 78 | 79 | Please join the [TensorFlowOnSpark user group](https://groups.google.com/forum/#!forum/TensorFlowOnSpark-users) for discussions and questions. If you have a question, please review our [FAQ](../../wiki/Frequently-Asked-Questions) before posting. 80 | 81 | Contributions are always welcome. For more information, please see our [guide for getting involved](Contributing.md). 82 | 83 | ## License 84 | 85 | The use and distribution terms for this software are covered by the Apache 2.0 license. 86 | See [LICENSE](LICENSE) file for terms. 
87 | -------------------------------------------------------------------------------- /examples/mnist/keras/mnist_spark.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras 2 | 3 | from __future__ import absolute_import, division, print_function, unicode_literals 4 | 5 | 6 | def main_fun(args, ctx): 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflowonspark import compat, TFNode 10 | 11 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 12 | 13 | def build_and_compile_cnn_model(): 14 | model = tf.keras.Sequential([ 15 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 16 | tf.keras.layers.MaxPooling2D(), 17 | tf.keras.layers.Flatten(), 18 | tf.keras.layers.Dense(64, activation='relu'), 19 | tf.keras.layers.Dense(10, activation='softmax') 20 | ]) 21 | model.compile( 22 | loss=tf.keras.losses.sparse_categorical_crossentropy, 23 | optimizer=tf.keras.optimizers.SGD(learning_rate=0.001), 24 | metrics=['accuracy']) 25 | return model 26 | 27 | # single node 28 | # single_worker_model = build_and_compile_cnn_model() 29 | # single_worker_model.fit(x=train_datasets, epochs=3) 30 | 31 | tf_feed = TFNode.DataFeed(ctx.mgr, False) 32 | 33 | def rdd_generator(): 34 | while not tf_feed.should_stop(): 35 | batch = tf_feed.next_batch(1) 36 | if len(batch) > 0: 37 | example = batch[0] 38 | image = np.array(example[0]).astype(np.float32) / 255.0 39 | image = np.reshape(image, (28, 28, 1)) 40 | label = np.array(example[1]).astype(np.float32) 41 | label = np.reshape(label, (1,)) 42 | yield (image, label) 43 | else: 44 | return 45 | 46 | ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([28, 28, 1]), tf.TensorShape([1]))) 47 | ds = ds.batch(args.batch_size) 48 | 49 | # this fails 50 | # callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=args.model_dir)] 51 | tf.io.gfile.makedirs(args.model_dir) 52 | filepath = args.model_dir + "/weights-{epoch:04d}" 53 | callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True)] 54 | 55 | with strategy.scope(): 56 | multi_worker_model = build_and_compile_cnn_model() 57 | 58 | # Note: MultiWorkerMirroredStrategy (CollectiveAllReduceStrategy) is synchronous, 59 | # so we need to ensure that all workers complete training before any of them run out of data from the RDD. 60 | # And given that Spark RDD partitions (and partition sizes) can be non-evenly divisible by num_workers, 61 | # we'll just stop training at 90% of the total expected number of steps. 
62 | steps_per_epoch = 60000 / args.batch_size 63 | steps_per_epoch_per_worker = steps_per_epoch / ctx.num_workers 64 | max_steps_per_worker = steps_per_epoch_per_worker * 0.9 65 | 66 | multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks) 67 | 68 | compat.export_saved_model(multi_worker_model, args.export_dir, ctx.job_name == 'chief') 69 | 70 | # terminating feed tells spark to skip processing further partitions 71 | tf_feed.terminate() 72 | 73 | 74 | if __name__ == '__main__': 75 | import argparse 76 | from pyspark.context import SparkContext 77 | from pyspark.conf import SparkConf 78 | from tensorflowonspark import TFCluster 79 | 80 | sc = SparkContext(conf=SparkConf().setAppName("mnist_keras")) 81 | executors = sc._conf.get("spark.executor.instances") 82 | num_executors = int(executors) if executors is not None else 1 83 | 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64) 86 | parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors) 87 | parser.add_argument("--epochs", help="number of epochs", type=int, default=3) 88 | parser.add_argument("--images_labels", help="path to MNIST images and labels in parallelized format") 89 | parser.add_argument("--model_dir", help="path to save checkpoint", default="mnist_model") 90 | parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export") 91 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 92 | 93 | args = parser.parse_args() 94 | print("args:", args) 95 | 96 | # create RDD of input data 97 | def parse(ln): 98 | vec = [int(x) for x in ln.split(',')] 99 | return (vec[1:], vec[0]) 100 | 101 | images_labels = sc.textFile(args.images_labels).map(parse) 102 | 103 | cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.SPARK, master_node='chief') 104 | # Note: need to feed extra data to ensure that each worker receives sufficient data to complete epochs 105 | # to compensate for variability in partition sizes and spark scheduling 106 | cluster.train(images_labels, args.epochs) 107 | cluster.shutdown() 108 | -------------------------------------------------------------------------------- /examples/mnist/estimator/mnist_tf.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator 2 | 3 | 4 | def main_fun(args, ctx): 5 | import tensorflow_datasets as tfds 6 | import tensorflow as tf 7 | 8 | BUFFER_SIZE = args.buffer_size 9 | BATCH_SIZE = args.batch_size 10 | LEARNING_RATE = args.learning_rate 11 | 12 | def input_fn(mode, input_context=None): 13 | datasets, info = tfds.load(name='mnist', 14 | with_info=True, 15 | as_supervised=True) 16 | mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else 17 | datasets['test']) 18 | 19 | def scale(image, label): 20 | image = tf.cast(image, tf.float32) 21 | image /= 255 22 | return image, label 23 | 24 | if input_context: 25 | mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines, 26 | input_context.input_pipeline_id) 27 | return mnist_dataset.repeat(args.epochs).map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) 28 | 29 | def serving_input_receiver_fn(): 30 | features = 
tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='features') 31 | receiver_tensors = {'conv2d_input': features} 32 | return tf.estimator.export.ServingInputReceiver(receiver_tensors, receiver_tensors) 33 | 34 | def model_fn(features, labels, mode): 35 | model = tf.keras.Sequential([ 36 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 37 | tf.keras.layers.MaxPooling2D(), 38 | tf.keras.layers.Flatten(), 39 | tf.keras.layers.Dense(64, activation='relu'), 40 | tf.keras.layers.Dense(10) 41 | ]) 42 | logits = model(features, training=False) 43 | 44 | if mode == tf.estimator.ModeKeys.PREDICT: 45 | predictions = {'logits': logits} 46 | return tf.estimator.EstimatorSpec(mode, predictions=predictions) 47 | 48 | optimizer = tf.compat.v1.train.GradientDescentOptimizer( 49 | learning_rate=LEARNING_RATE) 50 | loss = tf.keras.losses.SparseCategoricalCrossentropy( 51 | from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits) 52 | loss = tf.reduce_sum(input_tensor=loss) * (1. / BATCH_SIZE) 53 | if mode == tf.estimator.ModeKeys.EVAL: 54 | return tf.estimator.EstimatorSpec(mode, loss=loss) 55 | 56 | return tf.estimator.EstimatorSpec( 57 | mode=mode, 58 | loss=loss, 59 | train_op=optimizer.minimize( 60 | loss, tf.compat.v1.train.get_or_create_global_step())) 61 | 62 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 63 | config = tf.estimator.RunConfig(train_distribute=strategy, save_checkpoints_steps=100) 64 | 65 | classifier = tf.estimator.Estimator( 66 | model_fn=model_fn, model_dir=args.model_dir, config=config) 67 | 68 | # exporter = tf.estimator.FinalExporter("serving", serving_input_receiver_fn=serving_input_receiver_fn) 69 | 70 | tf.estimator.train_and_evaluate( 71 | classifier, 72 | train_spec=tf.estimator.TrainSpec(input_fn=input_fn), 73 | eval_spec=tf.estimator.EvalSpec(input_fn=input_fn) 74 | # eval_spec=tf.estimator.EvalSpec(input_fn=input_fn, exporters=exporter) 75 | ) 76 | 77 | if ctx.job_name == 'chief': 78 | print("========== exporting saved_model to {}".format(args.export_dir)) 79 | classifier.export_saved_model(args.export_dir, serving_input_receiver_fn) 80 | 81 | 82 | if __name__ == "__main__": 83 | # tf.app.run() 84 | 85 | from pyspark.context import SparkContext 86 | from pyspark.conf import SparkConf 87 | from tensorflowonspark import TFCluster 88 | import argparse 89 | 90 | sc = SparkContext(conf=SparkConf().setAppName("mnist_estimator")) 91 | executors = sc._conf.get("spark.executor.instances") 92 | num_executors = int(executors) if executors is not None else 1 93 | 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64) 96 | parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=10000) 97 | parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors) 98 | parser.add_argument("--epochs", help="number of epochs", type=int, default=3) 99 | parser.add_argument("--learning_rate", help="learning rate", type=float, default=1e-4) 100 | parser.add_argument("--model_dir", help="path to save checkpoint", default="mnist_model") 101 | parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export") 102 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 103 | 104 | args = parser.parse_args() 105 | print("args:", args) 106 | 107 | cluster = TFCluster.run(sc, main_fun, args, 
args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, log_dir=args.model_dir, master_node='chief', eval_node=True) 108 | cluster.shutdown(grace_secs=60) 109 | -------------------------------------------------------------------------------- /tests/test_TFCluster.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import test 3 | import time 4 | from tensorflowonspark import TFCluster, TFNode 5 | 6 | 7 | class TFClusterTest(test.SparkTest): 8 | @classmethod 9 | def setUpClass(cls): 10 | super(TFClusterTest, cls).setUpClass() 11 | 12 | @classmethod 13 | def tearDownClass(cls): 14 | super(TFClusterTest, cls).tearDownClass() 15 | 16 | def test_basic_tf(self): 17 | """Single-node TF graph (w/ args) running independently on multiple executors.""" 18 | def _map_fun(args, ctx): 19 | import tensorflow as tf 20 | x = tf.constant(args['x']) 21 | y = tf.constant(args['y']) 22 | sum = tf.math.add(x, y) 23 | assert sum.numpy() == 3 24 | 25 | args = {'x': 1, 'y': 2} 26 | cluster = TFCluster.run(self.sc, _map_fun, tf_args=args, num_executors=self.num_workers, num_ps=0) 27 | cluster.shutdown() 28 | 29 | def test_inputmode_spark(self): 30 | """Distributed TF cluster w/ InputMode.SPARK""" 31 | def _map_fun(args, ctx): 32 | import tensorflow as tf 33 | 34 | tf_feed = TFNode.DataFeed(ctx.mgr, False) 35 | while not tf_feed.should_stop(): 36 | batch = tf_feed.next_batch(batch_size=10) 37 | print("batch: {}".format(batch)) 38 | squares = tf.math.square(batch) 39 | print("squares: {}".format(squares)) 40 | tf_feed.batch_results(squares.numpy()) 41 | 42 | input = [[x] for x in range(1000)] # set up input as tensors of shape [1] to match placeholder 43 | rdd = self.sc.parallelize(input, 10) 44 | cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.SPARK) 45 | rdd_out = cluster.inference(rdd) 46 | rdd_sum = rdd_out.sum() 47 | self.assertEqual(rdd_sum, sum([x * x for x in range(1000)])) 48 | cluster.shutdown() 49 | 50 | def test_inputmode_spark_exception(self): 51 | """Distributed TF cluster w/ InputMode.SPARK and exception during feeding""" 52 | def _map_fun(args, ctx): 53 | import tensorflow as tf 54 | 55 | tf_feed = TFNode.DataFeed(ctx.mgr, False) 56 | while not tf_feed.should_stop(): 57 | batch = tf_feed.next_batch(10) 58 | if len(batch) > 0: 59 | squares = tf.math.square(batch) 60 | tf_feed.batch_results(squares.numpy()) 61 | raise Exception("FAKE exception during feeding") 62 | 63 | input = [[x] for x in range(1000)] # set up input as tensors of shape [1] to match placeholder 64 | rdd = self.sc.parallelize(input, 10) 65 | with self.assertRaises(Exception): 66 | cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.SPARK) 67 | cluster.inference(rdd, feed_timeout=1).count() 68 | cluster.shutdown() 69 | 70 | def test_inputmode_spark_late_exception(self): 71 | """Distributed TF cluster w/ InputMode.SPARK and exception after feeding""" 72 | def _map_fun(args, ctx): 73 | import tensorflow as tf 74 | 75 | tf_feed = TFNode.DataFeed(ctx.mgr, False) 76 | while not tf_feed.should_stop(): 77 | batch = tf_feed.next_batch(10) 78 | if len(batch) > 0: 79 | squares = tf.math.square(batch) 80 | tf_feed.batch_results(squares.numpy()) 81 | 82 | # simulate post-feed actions that raise an exception 83 | time.sleep(2) 84 | raise Exception("FAKE exception after feeding") 85 | 86 | input 
= [[x] for x in range(1000)] # set up input as tensors of shape [1] to match placeholder 87 | rdd = self.sc.parallelize(input, 10) 88 | with self.assertRaises(Exception): 89 | cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.SPARK) 90 | cluster.inference(rdd).count() 91 | cluster.shutdown(grace_secs=5) # note: grace_secs must be larger than the time needed for post-feed actions 92 | 93 | def test_port_released(self): 94 | """Test that temporary socket/port is released prior to invoking user map_fun.""" 95 | def _map_fun(args, ctx): 96 | assert ctx.tmp_socket is None 97 | 98 | cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief') 99 | cluster.shutdown() 100 | 101 | def test_port_unreleased(self): 102 | """Test that temporary socket/port is unreleased prior to invoking user map_fun.""" 103 | def _map_fun(args, ctx): 104 | import socket 105 | assert ctx.tmp_socket is not None 106 | reserved_port = ctx.tmp_socket.getsockname()[1] 107 | 108 | # socket bind to tmp port should fail 109 | try: 110 | my_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 111 | my_sock.bind(('0.0.0.0', reserved_port)) 112 | assert False, "should never hit this assert statement" 113 | except socket.error as e: 114 | print(e) 115 | assert True, "should raise an exception" 116 | 117 | ctx.release_port() 118 | assert ctx.tmp_socket is None 119 | 120 | cluster = TFCluster.run(self.sc, _map_fun, tf_args={}, num_executors=self.num_workers, num_ps=0, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief', release_port=False) 121 | cluster.shutdown() 122 | 123 | 124 | if __name__ == '__main__': 125 | unittest.main() 126 | -------------------------------------------------------------------------------- /src/test/scala/com/yahoo/tensorflowonspark/DFUtilTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2018 Yahoo Inc. 3 | * Licensed under the terms of the Apache 2.0 license. 4 | * Please see LICENSE file in the project root for terms. 
5 | */ 6 | package com.yahoo.tensorflowonspark 7 | 8 | import java.io.File 9 | 10 | import org.apache.commons.io.FileUtils 11 | import org.apache.spark.{SparkConf, SparkContext} 12 | import org.apache.spark.sql.SparkSession 13 | import org.apache.spark.sql.types._ 14 | import org.scalatest.FunSuite 15 | import org.scalatest.BeforeAndAfter 16 | import org.scalatest.Matchers._ 17 | 18 | import scala.collection.JavaConversions._ 19 | 20 | class DFUtilTest extends FunSuite with BeforeAndAfter with TestData { 21 | val conf: SparkConf = new SparkConf().setAppName("SparkInferTF").setMaster("local") 22 | implicit val sc: SparkContext = new SparkContext(conf) 23 | implicit val spark: SparkSession = SparkSession.builder.getOrCreate() 24 | 25 | before { 26 | FileUtils.deleteDirectory(new File("test-data")) 27 | } 28 | 29 | test("Save DataFrame as TFRecords and reload with same schema") { 30 | val df1 = spark.createDataFrame(List(row1, row2), schema) 31 | df1.show() 32 | df1.printSchema() 33 | assert(schema == df1.schema) 34 | 35 | // save to disk 36 | DFUtil.saveAsTFRecords(df1, "test-data") 37 | assert(new File("test-data").exists()) 38 | 39 | // reload from disk 40 | val df2 = DFUtil.loadTFRecords("test-data", schema) 41 | df2.show() 42 | df2.printSchema() 43 | assert(df1.schema == df2.schema) 44 | 45 | // compare binary column 46 | val binaryIn = df1.select("binary").collect 47 | val binaryOut = df2.select("binary").collect 48 | assert(binaryOut(0).getAs[Array[Byte]](0) === binaryIn(0).getAs[Array[Byte]](0)) 49 | assert(binaryOut(1).getAs[Array[Byte]](0) === binaryIn(1).getAs[Array[Byte]](0)) 50 | 51 | // compare scalar columns 52 | val scalarsIn = df1.select("bool", "int", "long", "float", "double", "string").collect 53 | val scalarsOut = df2.select("bool", "int", "long", "float", "double", "string").collect 54 | 55 | assert(scalarsOut(0).toSeq == scalarsIn(0).toSeq) 56 | assert(scalarsOut(1).toSeq == scalarsIn(1).toSeq) 57 | 58 | // compare binary array column 59 | val binArraysIn = df1.select("arrayBinary").collect 60 | val binArraysOut = df2.select("arrayBinary").collect 61 | for (row <- 0 to 1) { 62 | val out = binArraysOut(row).getList[Array[Byte]](0) 63 | val in = binArraysIn(row).getList[Array[Byte]](0) 64 | for (i <- 0 to 2) { 65 | assert(out(i) === in(i)) 66 | } 67 | } 68 | 69 | // compare array columms 70 | val arraysIn = df1.select("arrayBool", "arrayInt", "arrayLong", "arrayFloat", "arrayString").collect 71 | val arraysOut = df2.select("arrayBool", "arrayInt", "arrayLong", "arrayFloat", "arrayString").collect 72 | 73 | assert(arraysOut(0).toSeq == arraysIn(0).toSeq) 74 | assert(arraysOut(1).toSeq == arraysIn(1).toSeq) 75 | 76 | assert(arraysOut(0).getList[Boolean](0) === arraysIn(0).getList[Boolean](0)) 77 | assert(arraysOut(0).getList[Int](1) === arraysIn(0).getList[Int](1)) 78 | assert(arraysOut(0).getList[Long](2) === arraysIn(0).getList[Long](2)) 79 | assert(arraysOut(0).getList[Float](3) === arraysIn(0).getList[Float](3)) 80 | 81 | // compare arrayDouble columns 82 | // Note: there is loss of precision since we convert double => float when saving, 83 | // and then convert float => double when loading. So need to use epsilon comparison. 
84 | val arrayDoubleIn = df1.select("arrayDouble").collect 85 | val arrayDoubleOut = df2.select("arrayDouble").collect 86 | for (row <- 0 to 1) { 87 | val out = arrayDoubleOut(row).getList[Double](0) 88 | val in = arrayDoubleIn(row).getList[Double](0) 89 | for (i <- 0 to 2) { 90 | assert(out(i) === in(i) +- 1e-6) 91 | } 92 | } 93 | } 94 | 95 | test("Save DataFrame as TFRecords and reload without schema") { 96 | val df1 = spark.createDataFrame(List(row1, row2), schema) 97 | df1.show() 98 | df1.printSchema() 99 | assert(schema == df1.schema) 100 | 101 | // save to disk 102 | DFUtil.saveAsTFRecords(df1, "test-data") 103 | assert(new File("test-data").exists()) 104 | 105 | // reload from disk w/o schema hint 106 | val df2 = DFUtil.loadTFRecords("test-data") 107 | df2.show() 108 | df2.printSchema() 109 | 110 | // convert schema to list of StructFields, sorted by name 111 | val actual = df2.schema.fields.sortBy(_.name) 112 | 113 | // this is the expected inferred StructFields, sorted by name 114 | val expected = Array( 115 | StructField("binary", StringType), 116 | StructField("bool", LongType), 117 | StructField("int", LongType), 118 | StructField("long", LongType), 119 | StructField("float", FloatType), 120 | StructField("double", FloatType), 121 | StructField("string", StringType), 122 | StructField("arrayBinary", ArrayType(StringType)), 123 | StructField("arrayBool", ArrayType(LongType)), 124 | StructField("arrayInt", ArrayType(LongType)), 125 | StructField("arrayLong", ArrayType(LongType)), 126 | StructField("arrayFloat", ArrayType(FloatType)), 127 | StructField("arrayDouble", ArrayType(FloatType)), 128 | StructField("arrayString", ArrayType(StringType)) 129 | ).sortBy(_.name) 130 | 131 | assert(actual === expected) 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /examples/resnet/README.md: -------------------------------------------------------------------------------- 1 | # ResNet Image Classification 2 | 3 | Original Source: https://github.com/tensorflow/models/tree/master/official/benchmark/models 4 | 5 | This code is based on the Image Classification model from the official [TensorFlow Models](https://github.com/tensorflow/models) repository. This example already supports different forms of distribution via the `DistributionStrategy` API, so there isn't much additional work to convert it to TensorFlowOnSpark. 6 | 7 | Notes: 8 | - This example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed. 9 | - For simplicity, this just uses a single-node Spark Standalone installation. 
10 | 11 | #### Run the Single-Node Application 12 | 13 | First, make sure that you can run the original example, as follows: 14 | ``` 15 | # clone the TensorFlow models repository 16 | git clone https://github.com/tensorflow/models 17 | cd models 18 | 19 | # checkout the specific revision that this example was based upon 20 | git checkout c25c3e882e398d287240f619d7f56ac5b2973b6e 21 | 22 | # download the CIFAR10 dataset to /tmp/cifar10_data 23 | python official/r1/resnet/cifar10_download_and_extract.py 24 | 25 | # run the example 26 | export TENSORFLOW_MODELS=$(pwd) 27 | export CIFAR_DATA=/tmp/cifar10_data/cifar-10-batches-bin 28 | export PYTHONPATH=${TENSORFLOW_MODELS}:$PYTHONPATH 29 | 30 | # pip install tensorflow==2.1.1 tensorflow_model_optimization==0.3.0 31 | python ${TENSORFLOW_MODELS}/official/benchmark/models/resnet_cifar_main.py --data_dir=${CIFAR_DATA} --num_gpus=0 --train_epochs=1 32 | ``` 33 | 34 | If you have GPUs available, just set `--num_gpus` to the number of GPUs on your machine. 35 | 36 | #### Run as a Distributed TensorFlow Application 37 | 38 | Next, confirm that this application is capable of being distributed. We can test this on a single CPU machine by using two different terminal/shell sessions, as follows: 39 | ``` 40 | # in one shell/window 41 | export TFoS_HOME=/path/to/TensorFlowOnSpark 42 | export CIFAR_DATA=/tmp/cifar10_data/cifar-10-batches-bin 43 | export PYTHONPATH=${PYTHONPATH}:${TENSORFLOW_MODELS} 44 | export TF_CONFIG='{"cluster": { "chief": ["localhost:2222"], "worker": ["localhost:2223"]}, "task": {"type": "chief", "index": 0}}' 45 | python ${TFoS_HOME}/examples/resnet/resnet_cifar_main.py --data_dir=${CIFAR_DATA} --num_gpus=0 --ds=multi_worker_mirrored --train_epochs=1 46 | 47 | # in another shell/window 48 | # cd /path/to/tensorflow/models 49 | export TFoS_HOME=/path/to/TensorFlowOnSpark 50 | export CIFAR_DATA=/tmp/cifar10_data/cifar-10-batches-bin 51 | export PYTHONPATH=${PYTHONPATH}:${TENSORFLOW_MODELS} 52 | export TF_CONFIG='{"cluster": { "chief": ["localhost:2222"], "worker": ["localhost:2223"]}, "task": {"type": "worker", "index": 0}}' 53 | python ${TFoS_HOME}/examples/resnet/resnet_cifar_main.py --data_dir=${CIFAR_DATA} --num_gpus=0 --ds=multi_worker_mirrored --train_epochs=1 54 | ``` 55 | 56 | Note that we now configure the code to use the `MultiWorkerMirroredStrategy`. Also note that training will not begin until both nodes have started. 
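For reference, here is a minimal sketch (not part of the original example code) of how each process derives its role from the `TF_CONFIG` environment variable set above. `MultiWorkerMirroredStrategy` reads this variable when it is constructed, which is why training only proceeds once every member of the cluster is up and reachable:

```python
# Minimal illustration only -- the actual example relies on TensorFlow to parse
# TF_CONFIG internally; nothing here is required by resnet_cifar_main.py.
import json
import os

import tensorflow as tf

tf_config = json.loads(os.environ["TF_CONFIG"])
print("cluster spec:", tf_config["cluster"])   # {"chief": [...], "worker": [...]}
print("this task:   ", tf_config["task"])      # {"type": "chief"|"worker", "index": 0}

# The strategy picks up the cluster definition from TF_CONFIG; collective
# operations (and therefore training) wait until all listed tasks are reachable.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
print("replicas in sync:", strategy.num_replicas_in_sync)
```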
57 | 58 | ### Run as a TensorFlowOnSpark Application 59 | 60 | Finally, we can run the converted application as follows: 61 | ```bash 62 | export TFoS_HOME=/path/to/TensorFlowOnSpark 63 | export TENSORFLOW_MODELS=/path/to/tensorflow/models 64 | export CIFAR_DATA=/tmp/cifar10_data/cifar-10-batches-bin 65 | export PYTHONPATH=${PYTHONPATH}:${TENSORFLOW_MODELS} 66 | export MASTER=spark://$(hostname):7077 67 | export SPARK_WORKER_INSTANCES=2 68 | export CORES_PER_WORKER=1 69 | export TOTAL_CORES=$((${CORES_PER_WORKER}*${SPARK_WORKER_INSTANCES})) 70 | 71 | # start spark cluster 72 | ${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c $CORES_PER_WORKER -m 3G ${MASTER} 73 | 74 | # train and evaluate 75 | ${SPARK_HOME}/bin/spark-submit \ 76 | --master ${MASTER} \ 77 | --conf spark.cores.max=${TOTAL_CORES} \ 78 | --conf spark.task.cpus=${CORES_PER_WORKER} \ 79 | --py-files ${TFoS_HOME}/examples/resnet/resnet_cifar_dist.py \ 80 | ${TFoS_HOME}/examples/resnet/resnet_cifar_spark.py \ 81 | --cluster_size ${SPARK_WORKER_INSTANCES} \ 82 | --epochs 1 \ 83 | --data_dir ${CIFAR_DATA} \ 84 | --num_gpus=0 \ 85 | --ds=multi_worker_mirrored \ 86 | --train_epochs 1 87 | 88 | # shutdown spark 89 | ${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh 90 | ``` 91 | 92 | Notes: 93 | - Most of the original TensorFlow code from `resnet_cifar_main.py` has been copied into `resnet_cifar_dist.py`, so you can diff the changes required for TensorFlowOnSpark. 94 | - The `def main(_)` function was changed to `def main_fun(argv, ctx)`. 95 | - The `absl_app.run(main)` invocation was replaced by the Spark "main" function in `resnet_cifar_spark.py`. This file mostly contains the Spark application boilerplate along with the TensorFlowOnSpark calls to setup the TensorFlow cluster. Note that having the separate Spark and TensorFlow files can help isolate code and avoid Spark serialization issues. 96 | - The Spark "main" function uses `argparse` to parse TensorFlowOnSpark-specific command line arguments, but it passes the remaining argments (in the `rem` variable) to the TensorFlow `main_fun`, which then parses those arguments via `define_cifar_flags()` and `flags.FLAGS(argv)`. 97 | - In a truly distributed environment, you would need: 98 | - A distributed file system to store the dataset, so that each executor/node is able to read the data. 99 | - The dependencies from the `tensorflow/models` to be available on the executors, either installed locally or bundled with the Spark application. 100 | -------------------------------------------------------------------------------- /examples/segmentation/segmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # 4 | #@title Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from __future__ import absolute_import, division, print_function, unicode_literals 17 | 18 | from tensorflow_examples.models.pix2pix import pix2pix 19 | import tensorflow_datasets as tfds 20 | import tensorflow as tf 21 | 22 | dataset, info = tfds.load('oxford_iiit_pet:3.2.0', with_info=True) 23 | 24 | 25 | def normalize(input_image, input_mask): 26 | input_image = tf.cast(input_image, tf.float32)/128.0 - 1 27 | input_mask -= 1 28 | return input_image, input_mask 29 | 30 | 31 | @tf.function 32 | def load_image_train(datapoint): 33 | input_image = tf.image.resize(datapoint['image'], (128, 128)) 34 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128)) 35 | 36 | if tf.random.uniform(()) > 0.5: 37 | input_image = tf.image.flip_left_right(input_image) 38 | input_mask = tf.image.flip_left_right(input_mask) 39 | 40 | input_image, input_mask = normalize(input_image, input_mask) 41 | 42 | return input_image, input_mask 43 | 44 | 45 | def load_image_test(datapoint): 46 | input_image = tf.image.resize(datapoint['image'], (128, 128)) 47 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128)) 48 | input_image, input_mask = normalize(input_image, input_mask) 49 | return input_image, input_mask 50 | 51 | 52 | TRAIN_LENGTH = info.splits['train'].num_examples 53 | BATCH_SIZE = 64 54 | BUFFER_SIZE = 1000 55 | STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE 56 | 57 | train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) 58 | test = dataset['test'].map(load_image_test) 59 | 60 | train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat() 61 | train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) 62 | test_dataset = test.batch(BATCH_SIZE) 63 | 64 | OUTPUT_CHANNELS = 3 65 | 66 | base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False) 67 | 68 | # Use the activations of these layers 69 | layer_names = [ 70 | 'block_1_expand_relu', # 64x64 71 | 'block_3_expand_relu', # 32x32 72 | 'block_6_expand_relu', # 16x16 73 | 'block_13_expand_relu', # 8x8 74 | 'block_16_project', # 4x4 75 | ] 76 | layers = [base_model.get_layer(name).output for name in layer_names] 77 | 78 | # Create the feature extraction model 79 | down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers) 80 | 81 | down_stack.trainable = False 82 | 83 | up_stack = [ 84 | pix2pix.upsample(512, 3), # 4x4 -> 8x8 85 | pix2pix.upsample(256, 3), # 8x8 -> 16x16 86 | pix2pix.upsample(128, 3), # 16x16 -> 32x32 87 | pix2pix.upsample(64, 3), # 32x32 -> 64x64 88 | ] 89 | 90 | 91 | def unet_model(output_channels): 92 | 93 | # This is the last layer of the model 94 | last = tf.keras.layers.Conv2DTranspose( 95 | output_channels, 3, strides=2, 96 | padding='same', activation='softmax') # 64x64 -> 128x128 97 | 98 | inputs = tf.keras.layers.Input(shape=[128, 128, 3]) 99 | x = inputs 100 | 101 | # Downsampling through the model 102 | skips = down_stack(x) 103 | x = skips[-1] 104 | skips = reversed(skips[:-1]) 105 | 106 | # Upsampling and establishing the skip connections 107 | for up, skip in zip(up_stack, skips): 108 | x = up(x) 109 | concat = tf.keras.layers.Concatenate() 110 | x = concat([x, skip]) 111 | 112 | x = last(x) 113 | 114 | return tf.keras.Model(inputs=inputs, outputs=x) 115 | 116 | 117 | model = unet_model(OUTPUT_CHANNELS) 118 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', 119 | metrics=['accuracy']) 120 | 121 | # Training only (since we're using 
command-line) 122 | # def create_mask(pred_mask): 123 | # pred_mask = tf.argmax(pred_mask, axis=-1) 124 | # pred_mask = pred_mask[..., tf.newaxis] 125 | # return pred_mask[0] 126 | # 127 | # 128 | # def show_predictions(dataset=None, num=1): 129 | # if dataset: 130 | # for image, mask in dataset.take(num): 131 | # pred_mask = model.predict(image) 132 | # display([image[0], mask[0], create_mask(pred_mask)]) 133 | # else: 134 | # display([sample_image, sample_mask, 135 | # create_mask(model.predict(sample_image[tf.newaxis, ...]))]) 136 | # 137 | # 138 | # class DisplayCallback(tf.keras.callbacks.Callback): 139 | # def on_epoch_end(self, epoch, logs=None): 140 | # clear_output(wait=True) 141 | # show_predictions() 142 | # print ('\nSample Prediction after epoch {}\n'.format(epoch+1)) 143 | # 144 | 145 | # EPOCHS = 20 146 | EPOCHS = 1 147 | VAL_SUBSPLITS = 5 148 | VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS 149 | 150 | model_history = model.fit(train_dataset, epochs=EPOCHS, 151 | steps_per_epoch=STEPS_PER_EPOCH, 152 | validation_steps=VALIDATION_STEPS, 153 | validation_data=test_dataset) 154 | 155 | model.save_weights("keras_weights") 156 | -------------------------------------------------------------------------------- /src/test/scala/com/yahoo/tensorflowonspark/TFModelTest.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2018 Yahoo Inc. 3 | * Licensed under the terms of the Apache 2.0 license. 4 | * Please see LICENSE file in the project root for terms. 5 | */ 6 | package com.yahoo.tensorflowonspark 7 | 8 | import java.nio._ 9 | 10 | import org.apache.spark.sql.Row 11 | import org.scalatest.FunSuite 12 | import org.tensorflow.{DataType, Tensor} 13 | 14 | 15 | class TFModelTest extends FunSuite with TestData { 16 | val model = new TFModel("test") 17 | 18 | test("Convert Rows to Tensors") { 19 | val tensors = model.batch2tensors(listRows, schema) 20 | 21 | // given 2 rows of M columns in listRows 22 | // expect M tensors with 2 rows each, with ArrayType tensors having 3 cols each 23 | assert(tensors.size == listRows.head.size) 24 | assert(tensors.forall { case (name, tensor) => 25 | val expectedShape = if (name.startsWith("array")) Array(2L, 3L) else Array(2L) 26 | tensor.shape() sameElements expectedShape }) 27 | 28 | // check "sum" of columns for numeric scalar types 29 | assert(tensors("bool").copyTo(Array.ofDim[Boolean](2)) === Array(true, false)) 30 | assert(tensors("int").copyTo(Array.ofDim[Int](2)).sum === 3) 31 | assert(tensors("long").copyTo(Array.ofDim[Long](2)).sum === 3L) 32 | assert(tensors("float").copyTo(Array.ofDim[Float](2)).sum === 3.0f) 33 | assert(tensors("double").copyTo(Array.ofDim[Double](2)).sum === 3.0) 34 | 35 | // check binary/string types 36 | assert(tensors("binary").copyTo(Array.ofDim[Array[Byte]](2)) === Array("one".getBytes, "foo".getBytes)) 37 | assert(tensors("string").copyTo(Array.ofDim[Array[Byte]](2)).map(new String(_)) === Array("one", "foo")) 38 | 39 | // check sum of rows for numeric array types 40 | assert(tensors("arrayBool").copyTo(Array.ofDim[Boolean](2,3)).map(row => row.reduce((x,y) => x && y)) === Array(true, false)) 41 | assert(tensors("arrayInt").copyTo(Array.ofDim[Int](2,3)).map(_.sum) === Array(6, 15)) 42 | assert(tensors("arrayLong").copyTo(Array.ofDim[Long](2,3)).map(_.sum) === Array(6L, 15L)) 43 | assert(tensors("arrayFloat").copyTo(Array.ofDim[Float](2,3)).map(_.sum) === Array(3.3f, 6.3f)) 44 | 
assert(tensors("arrayDouble").copyTo(Array.ofDim[Double](2,3)).map(_.sum) === Array(3.3, 6.3)) 45 | 46 | // check binary/string array types 47 | assert(tensors("arrayBinary").copyTo(Array.ofDim[Array[Byte]](2, 3)) === 48 | Array( 49 | Array("one".getBytes, "two".getBytes, "three".getBytes), 50 | Array("foo".getBytes, "bar".getBytes, "baz".getBytes) 51 | )) 52 | assert(tensors("arrayString").copyTo(Array.ofDim[Array[Byte]](2, 3)).map(row => 53 | row.map(new String(_))) === 54 | Array( 55 | Array("one", "two", "three"), 56 | Array("foo", "bar", "baz") 57 | ) 58 | ) 59 | } 60 | 61 | test("Convert 1D Tensor to Rows") { 62 | val floatBuf = FloatBuffer.wrap(Array(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f)) 63 | val t8 = Tensor.create(Array(8L), floatBuf) 64 | val rows: List[Row] = model.tensors2batch(Seq(t8)) 65 | assert(rows(0).getAs[Float](0) == 0.0f) 66 | assert(rows(1).getAs[Float](0) == 1.0f) 67 | assert(rows(2).getAs[Float](0) == 2.0f) 68 | assert(rows(3).getAs[Float](0) == 3.0f) 69 | } 70 | 71 | test("Convert 2D Tensor to Rows") { 72 | // 8 x 1 tensor 73 | val floatBuf = FloatBuffer.wrap(Array(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f)) 74 | val t8_1 = Tensor.create(Array(8L, 1L), floatBuf) 75 | var rows: List[Row] = model.tensors2batch(Seq(t8_1)) 76 | assert(rows(0).getAs[Array[Float]](0).sum == 0.0f) 77 | assert(rows(1).getAs[Array[Float]](0).sum == 1.0f) 78 | assert(rows(2).getAs[Array[Float]](0).sum == 2.0f) 79 | assert(rows(3).getAs[Array[Float]](0).sum == 3.0f) 80 | 81 | // 4 x 2 tensor 82 | floatBuf.rewind() 83 | val t4_2 = Tensor.create(Array(4L, 2L), floatBuf) 84 | rows = model.tensors2batch(Seq(t4_2)) 85 | assert(rows(0).getAs[Array[Float]](0).sum == 1.0f) 86 | assert(rows(1).getAs[Array[Float]](0).sum == 5.0f) 87 | assert(rows(2).getAs[Array[Float]](0).sum == 9.0f) 88 | assert(rows(3).getAs[Array[Float]](0).sum == 13.0f) 89 | 90 | // 2 x 4 tensor 91 | floatBuf.rewind() 92 | val t2_4 = Tensor.create(Array(2L, 4L), floatBuf) 93 | rows = model.tensors2batch(Seq(t2_4)) 94 | assert(rows(0).getAs[Array[Float]](0).sum == 6.0f) 95 | assert(rows(1).getAs[Array[Float]](0).sum == 22.0f) 96 | 97 | // 1 x 8 tensor 98 | floatBuf.rewind() 99 | val t1_8 = Tensor.create(Array(1L, 8L), floatBuf) 100 | rows = model.tensors2batch(Seq(t1_8)) 101 | assert(rows(0).getAs[Array[Float]](0).sum == 28.0f) 102 | } 103 | 104 | test("Convert multiple Tensors to Rows") { 105 | // 1D tensor 106 | val longBuf = LongBuffer.wrap(Array(0L, 1L, 2L, 3L)) 107 | val t4 = Tensor.create(Array(4L), longBuf) 108 | 109 | // 4 x 2 tensor 110 | val floatBuf = FloatBuffer.wrap(Array(0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f)) 111 | val t4_2 = Tensor.create(Array(4L, 2L), floatBuf) 112 | 113 | val rows = model.tensors2batch(Seq(t4, t4_2)) 114 | 115 | // expect Rows: 116 | // 0L, (0.0f, 1.0f) 117 | // 1L, (2.0f, 3.0f) 118 | // 2L, (4.0f, 5.0f) 119 | // 3L, (6.0f, 7.0f) 120 | assert(rows(0).getLong(0) == 0L) 121 | assert(rows(0).getAs[Array[Float]](1).sum == 1.0f) 122 | assert(rows(1).getLong(0) == 1L) 123 | assert(rows(1).getAs[Array[Float]](1).sum == 5.0f) 124 | assert(rows(2).getLong(0) == 2L) 125 | assert(rows(2).getAs[Array[Float]](1).sum == 9.0f) 126 | assert(rows(3).getLong(0) == 3L) 127 | assert(rows(3).getAs[Array[Float]](1).sum == 13.0f) 128 | } 129 | } -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the 
Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sphinx_rtd_theme 17 | import sys 18 | 19 | _pysrc = os.path.abspath(os.path.join(os.path.abspath(__file__), '..', '..', '..')) 20 | sys.path.insert(0, _pysrc) 21 | 22 | autodoc_mock_imports = ["pyspark", "tensorflow"] 23 | 24 | # -- Project information ----------------------------------------------------- 25 | 26 | project = 'TensorFlowOnSpark' 27 | copyright = '2020, Yahoo Inc' 28 | author = 'Yahoo Inc' 29 | 30 | # The short X.Y version 31 | version = '2.2.5' 32 | # The full version, including alpha/beta/rc tags 33 | release = '2.2.5' 34 | 35 | 36 | # -- General configuration --------------------------------------------------- 37 | 38 | # If your documentation needs a minimal Sphinx version, state it here. 39 | # 40 | # needs_sphinx = '1.0' 41 | 42 | # Add any Sphinx extension module names here, as strings. They can be 43 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 44 | # ones. 45 | extensions = [ 46 | 'sphinx.ext.autodoc', 47 | 'sphinx.ext.viewcode', 48 | 'sphinx.ext.githubpages', 49 | 'sphinx_rtd_theme' 50 | ] 51 | 52 | # Add any paths that contain templates here, relative to this directory. 53 | templates_path = ['_templates'] 54 | 55 | # The suffix(es) of source filenames. 56 | # You can specify multiple suffix as a list of string: 57 | # 58 | # source_suffix = ['.rst', '.md'] 59 | source_suffix = '.rst' 60 | 61 | # The master toctree document. 62 | master_doc = 'index' 63 | 64 | # The language for content autogenerated by Sphinx. Refer to documentation 65 | # for a list of supported languages. 66 | # 67 | # This is also used if you do content translation via gettext catalogs. 68 | # Usually you set "language" from the command line for these cases. 69 | language = 'en' 70 | 71 | # List of patterns, relative to source directory, that match files and 72 | # directories to ignore when looking for source files. 73 | # This pattern also affects html_static_path and html_extra_path . 74 | exclude_patterns = [] 75 | 76 | # The name of the Pygments (syntax highlighting) style to use. 77 | pygments_style = 'sphinx' 78 | 79 | # If true, the current module name will be prepended to all description 80 | # unit titles (such as .. function::). 81 | # 82 | add_module_names = False 83 | 84 | # -- Options for HTML output ------------------------------------------------- 85 | 86 | # The theme to use for HTML and HTML Help pages. See the documentation for 87 | # a list of builtin themes. 88 | # 89 | html_theme = 'sphinx_rtd_theme' 90 | 91 | # Theme options are theme-specific and customize the look and feel of a theme 92 | # further. For a list of options available for each theme, see the 93 | # documentation. 94 | # 95 | # html_theme_options = {} 96 | 97 | # Add any paths that contain custom static files (such as style sheets) here, 98 | # relative to this directory. They are copied after the builtin static files, 99 | # so a file named "default.css" will overwrite the builtin "default.css". 
100 | html_static_path = ['_static'] 101 | 102 | # Custom sidebar templates, must be a dictionary that maps document names 103 | # to template names. 104 | # 105 | # The default sidebars (for documents that don't match any pattern) are 106 | # defined by theme itself. Builtin themes are using these templates by 107 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 108 | # 'searchbox.html']``. 109 | # 110 | # html_sidebars = {} 111 | 112 | 113 | # -- Options for HTMLHelp output --------------------------------------------- 114 | 115 | # Output file base name for HTML help builder. 116 | htmlhelp_basename = 'TensorFlowOnSparkdoc' 117 | 118 | 119 | # -- Options for LaTeX output ------------------------------------------------ 120 | 121 | latex_elements = { 122 | # The paper size ('letterpaper' or 'a4paper'). 123 | # 124 | # 'papersize': 'letterpaper', 125 | 126 | # The font size ('10pt', '11pt' or '12pt'). 127 | # 128 | # 'pointsize': '10pt', 129 | 130 | # Additional stuff for the LaTeX preamble. 131 | # 132 | # 'preamble': '', 133 | 134 | # Latex figure (float) alignment 135 | # 136 | # 'figure_align': 'htbp', 137 | } 138 | 139 | # Grouping the document tree into LaTeX files. List of tuples 140 | # (source start file, target name, title, 141 | # author, documentclass [howto, manual, or own class]). 142 | latex_documents = [ 143 | (master_doc, 'TensorFlowOnSpark.tex', 'TensorFlowOnSpark Documentation', 144 | 'Lee Yang', 'manual'), 145 | ] 146 | 147 | 148 | # -- Options for manual page output ------------------------------------------ 149 | 150 | # One entry per manual page. List of tuples 151 | # (source start file, name, description, authors, manual section). 152 | man_pages = [ 153 | (master_doc, 'tensorflowonspark', 'TensorFlowOnSpark Documentation', 154 | [author], 1) 155 | ] 156 | 157 | 158 | # -- Options for Texinfo output ---------------------------------------------- 159 | 160 | # Grouping the document tree into Texinfo files. List of tuples 161 | # (source start file, target name, title, author, 162 | # dir menu entry, description, category) 163 | texinfo_documents = [ 164 | (master_doc, 'TensorFlowOnSpark', 'TensorFlowOnSpark Documentation', 165 | author, 'TensorFlowOnSpark', 'One line description of project.', 166 | 'Miscellaneous'), 167 | ] 168 | 169 | 170 | # -- Extension configuration ------------------------------------------------- 171 | -------------------------------------------------------------------------------- /examples/mnist/keras/mnist_tf_ds.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras 2 | 3 | from __future__ import absolute_import, division, print_function, unicode_literals 4 | 5 | 6 | def main_fun(args, ctx): 7 | """Example demonstrating loading TFRecords directly from disk (e.g. 
HDFS) without tensorflow_datasets.""" 8 | import tensorflow as tf 9 | from tensorflowonspark import compat 10 | 11 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 12 | 13 | BUFFER_SIZE = args.buffer_size 14 | BATCH_SIZE = args.batch_size 15 | NUM_WORKERS = args.cluster_size 16 | 17 | # parser for TFRecords downloaded by tensorflow_datasets 18 | # these are images + labels, where the images are just serialized PNGs 19 | def parse_tfds(x): 20 | feature_def = {"label": tf.io.FixedLenFeature(1, tf.int64), "image": tf.io.VarLenFeature(tf.string)} 21 | example = tf.io.parse_single_example(x, feature_def) 22 | image = tf.io.decode_image(example['image'].values[0]) / 255 23 | image.set_shape([28, 28, 1]) # fix for https://github.com/tensorflow/tensorflow/issues/24520 24 | label = example['label'] 25 | return (image, label) 26 | 27 | # parser for TFRecords generated by ${TFoS_HOME}/examples/mnist/mnist_data_setup.py 28 | # these are images + labels, where the images are a flattened arrays of ints 29 | def parse_tfos(example_proto): 30 | feature_def = {"label": tf.io.FixedLenFeature(10, tf.int64), 31 | "image": tf.io.FixedLenFeature(28 * 28 * 1, tf.int64)} 32 | features = tf.io.parse_single_example(example_proto, feature_def) 33 | image = tf.cast(features['image'], tf.float32) / 255 34 | image = tf.reshape(image, (28, 28, 1)) 35 | label = tf.math.argmax(features['label'], output_type=tf.int32) 36 | return (image, label) 37 | 38 | # Dataset for input data 39 | # tfds: /path/to/tensorflow_datasets/mnist/1.0.0/mnist-train.tfrecord* 40 | # tfos: /path/to/mnist/tfr/train/part-r-* 41 | image_pattern = ctx.absolute_path(args.images_labels) 42 | 43 | ds = tf.data.Dataset.list_files(image_pattern) 44 | ds = ds.repeat(args.epochs).shuffle(BUFFER_SIZE) 45 | ds = ds.interleave(tf.data.TFRecordDataset) 46 | 47 | if args.data_format == 'tfds': 48 | train_datasets_unbatched = ds.map(parse_tfds) 49 | else: # 'tfos' 50 | train_datasets_unbatched = ds.map(parse_tfos) 51 | 52 | def build_and_compile_cnn_model(): 53 | model = tf.keras.Sequential([ 54 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 55 | tf.keras.layers.MaxPooling2D(), 56 | tf.keras.layers.Flatten(), 57 | tf.keras.layers.Dense(64, activation='relu'), 58 | tf.keras.layers.Dense(10, activation='softmax') 59 | ]) 60 | model.compile( 61 | loss=tf.keras.losses.sparse_categorical_crossentropy, 62 | optimizer=tf.keras.optimizers.SGD(learning_rate=0.001), 63 | metrics=['accuracy']) 64 | return model 65 | 66 | # single node 67 | # single_worker_model = build_and_compile_cnn_model() 68 | # single_worker_model.fit(x=train_datasets, epochs=3) 69 | 70 | # Here the batch size scales up by number of workers since 71 | # `tf.data.Dataset.batch` expects the global batch size. Previously we used 64, 72 | # and now this becomes 128. 73 | GLOBAL_BATCH_SIZE = BATCH_SIZE * NUM_WORKERS 74 | train_datasets = train_datasets_unbatched.batch(GLOBAL_BATCH_SIZE) 75 | 76 | # this fails 77 | # callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=args.model_dir)] 78 | tf.io.gfile.makedirs(args.model_dir) 79 | filepath = args.model_dir + "/weights-{epoch:04d}" 80 | callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True)] 81 | 82 | # Note: if you part files have an uneven number of records, you may see an "Out of Range" exception 83 | # at less than the expected number of steps_per_epoch, because the executor with least amount of records will finish first. 
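  # One possible mitigation (an illustrative addition, not used by the original example):
  # stop each epoch slightly short of a full pass so the executor holding the smallest
  # part file does not exhaust its records first. The InputMode.SPARK examples in this
  # repository use a similar 90% heuristic for the same reason.
  # steps_per_epoch = int(0.9 * 60000 / GLOBAL_BATCH_SIZE)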
84 | steps_per_epoch = 60000 / GLOBAL_BATCH_SIZE 85 | 86 | with strategy.scope(): 87 | multi_worker_model = build_and_compile_cnn_model() 88 | multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks) 89 | 90 | compat.export_saved_model(multi_worker_model, args.export_dir, ctx.job_name == 'chief') 91 | 92 | 93 | if __name__ == '__main__': 94 | import argparse 95 | from pyspark.context import SparkContext 96 | from pyspark.conf import SparkConf 97 | from tensorflowonspark import TFCluster 98 | 99 | sc = SparkContext(conf=SparkConf().setAppName("mnist_keras")) 100 | executors = sc._conf.get("spark.executor.instances") 101 | num_executors = int(executors) if executors is not None else 1 102 | 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64) 105 | parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=10000) 106 | parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors) 107 | parser.add_argument("--data_format", help="data format (tfos|tfds)", type=str, choices=["tfos", "tfds"], default="tfos") 108 | parser.add_argument("--epochs", help="number of epochs", type=int, default=3) 109 | parser.add_argument("--images_labels", help="HDFS path to MNIST image_label files in parallelized format") 110 | parser.add_argument("--model_dir", help="path to save model/checkpoint", default="mnist_model") 111 | parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export") 112 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 113 | 114 | args = parser.parse_args() 115 | print("args:", args) 116 | 117 | cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief') 118 | cluster.shutdown() 119 | -------------------------------------------------------------------------------- /examples/segmentation/README.md: -------------------------------------------------------------------------------- 1 | # Image Segmentation 2 | 3 | Original Source: https://www.tensorflow.org/tutorials/images/segmentation 4 | 5 | This code is based on the [Image Segmentation](https://www.tensorflow.org/tutorials/images/segmentation) notebook example, converted to a single-node TensorFlow python app, then converted into a distributed TensorFlow app using the `MultiWorkerMirroredStrategy`, and then finally adapted for TensorFlowOnSpark. Compare the different versions to see the conversion steps involved at each stage. 6 | 7 | Notes: 8 | - this example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed. 9 | 10 | #### Train via Single-Node 11 | 12 | The [segmentation.py](segmentation.py) file contains the bulk of the code from the example notebook, minus any code for interactively visualizing the images and masks, since the end goal will be a non-interactive job in Spark. 13 | 14 | Run the single-node example to ensure that your environment is set up correctly. For brevity, this example only trains a single epoch (vs. the original 20 epochs), but you can modify the source to run more epochs, if desired. 15 | ```bash 16 | # Run the following, if you see: "Failed to construct dataset oxford_iiit_petDataset oxford_iiit_pet cannot be loaded at version 3.2.0, only: 3.1.0." 
17 | # pip uninstall tensorflow-datasets 18 | # pip install tfds-nightly 19 | 20 | # train 21 | python ${TFoS_HOME}/examples/segmentation/segmentation.py 22 | ``` 23 | 24 | This will save the model weights as `keras_weights.*` files, which you can re-use in the original notebook as follows: 25 | ``` 26 | # create a new empty model 27 | model = unet_model(OUTPUT_CHANNELS) 28 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', 29 | metrics=['accuracy']) 30 | show_predictions() 31 | 32 | # load the weights 33 | model.load_weights("/path/to/keras_weights") 34 | show_predictions() 35 | ``` 36 | 37 | #### Train via Distributed TensorFlow 38 | 39 | Next, the [segmentation_dist.py](segmentation_dist.py) file adds a `MultiWorkerMirroredStrategy` to enable distributed training. For simplicity, we can simulate two different machines by using separate shell windows. If you have multiple nodes available, you can run these commands on the separate machines (using the cluster host names instead of `localhost`). 40 | ``` 41 | # on one node/shell 42 | export TF_CONFIG='{"cluster": { "worker": ["localhost:2222", "localhost:2223"]}, "task": {"type": "worker", "index": 0}}' 43 | python ${TFoS_HOME}/examples/segmentation/segmentation_dist.py 44 | 45 | # on another node/shell 46 | export TF_CONFIG='{"cluster": { "worker": ["localhost:2222", "localhost:2223"]}, "task": {"type": "worker", "index": 1}}' 47 | python ${TFoS_HOME}/examples/segmentation/segmentation_dist.py 48 | ``` 49 | 50 | Note that training will not start until all nodes are running and connected to the cluster. Also note that the `MultiWorkerMirroredStrategy` is a synchronous training strategy, so each node will train a batch of data and update the model weights in lock-step with each of the other nodes. This has implications that are beyond the scope of this tutorial. For more information, you can read the [TensorFlow distributed training documentation](https://www.tensorflow.org/beta/tutorials/distribute/keras). Notably, you should shard the data across the workers and adjust the per-worker batch_size to account for additional nodes in the cluster. However, in order to minimize code changes here, this is left as an exercise for the reader. 51 | 52 | #### Train via TensorFlowOnSpark 53 | 54 | Next, we convert the `segmentation_dist.py` file to TensorFlowOnSpark, resulting in the [segmentation_spark.py](segmentation_spark.py) file. 
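The conversion follows the same pattern as the other examples in this repository: the training code is wrapped in a `main_fun(args, ctx)` function, and a small Spark "main" launches it on the executors via `TFCluster.run`. A simplified sketch of that pattern is shown below (argument handling is abbreviated; see `segmentation_spark.py` for the actual code):

```python
# Simplified sketch of the TensorFlowOnSpark wrapper pattern; not the full
# segmentation_spark.py file.
import argparse

from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from tensorflowonspark import TFCluster


def main_fun(args, ctx):
  # The model definition and model.fit(...) from segmentation_dist.py go here.
  # TF_CONFIG is set up by TensorFlowOnSpark before this function is invoked.
  pass


if __name__ == '__main__':
  sc = SparkContext(conf=SparkConf().setAppName("segmentation"))

  parser = argparse.ArgumentParser()
  parser.add_argument("--cluster_size", type=int, default=2)
  parser.add_argument("--epochs", type=int, default=1)
  parser.add_argument("--model_dir", default="segmentation_model")
  parser.add_argument("--export_dir", default="segmentation_export")
  args = parser.parse_args()

  cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0,
                          input_mode=TFCluster.InputMode.TENSORFLOW,
                          master_node='chief')
  cluster.shutdown()
```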
Then, run in a local Spark standalone cluster as follows: 55 | ```bash 56 | # Start a local standalone Spark cluster 57 | export MASTER=spark://$(hostname):7077 58 | export SPARK_WORKER_INSTANCES=3 59 | export CORES_PER_WORKER=1 60 | export TOTAL_CORES=$((${CORES_PER_WORKER}*${SPARK_WORKER_INSTANCES})) 61 | export TFoS_HOME= 62 | 63 | ${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c $CORES_PER_WORKER -m 3G ${MASTER} 64 | 65 | # remove any old artifacts 66 | rm -Rf ${TFoS_HOME}/segmentation_model.h5 ${TFoS_HOME}/segmentation_model ${TFoS_HOME}/segmentation_export 67 | 68 | # train 69 | ${SPARK_HOME}/bin/spark-submit \ 70 | --master ${MASTER} \ 71 | --conf spark.cores.max=${TOTAL_CORES} \ 72 | --conf spark.task.cpus=${CORES_PER_WORKER} \ 73 | ${TFoS_HOME}/examples/segmentation/segmentation_spark.py \ 74 | --cluster_size ${SPARK_WORKER_INSTANCES} \ 75 | --model_dir ${TFoS_HOME}/segmentation_model \ 76 | --export_dir ${TFoS_HOME}/segmentation_export \ 77 | --epochs 1 78 | 79 | # confirm model 80 | ls -lR ${TFoS_HOME}/segmentation_model 81 | ls -lR ${TFoS_HOME}/segmentation_export 82 | 83 | # Shutdown the Spark Standalone cluster 84 | ${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh 85 | ``` 86 | 87 | Once again, this only trains a single epoch and doesn't adjust for the increased cluster size. Feel free to experiment on your own. 88 | 89 | This example will save the model in several different formats: 90 | - TensorFlow/Keras checkpoint (`segmentation_model`) 91 | - Keras HDF5 file (`segmentation_model.h5`) 92 | - TensorFlow saved_model (`segmentation_export`) 93 | 94 | You can re-load these into the original notebook example (for visualization of the segmentation masks) with the following code: 95 | ``` 96 | # segmentation_model 97 | model = unet_model(OUTPUT_CHANNELS) 98 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', 99 | metrics=['accuracy']) 100 | model.load_weights("/path/to/segmentation_model/weights-0001") 101 | show_predictions(test_dataset) 102 | 103 | # segmentation_model.h5 104 | model = tf.keras.models.load_model("/path/to/segmentation_model.h5") 105 | show_predictions(test_dataset) 106 | 107 | # segmentation_export 108 | model = tf.keras.experimental.load_from_saved_model("/path/to/segmentation_export") 109 | show_predictions(test_dataset) 110 | ``` 111 | -------------------------------------------------------------------------------- /examples/segmentation/segmentation_dist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # 4 | #@title Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from __future__ import absolute_import, division, print_function, unicode_literals 17 | 18 | from tensorflow_examples.models.pix2pix import pix2pix 19 | import json 20 | import os 21 | import tensorflow_datasets as tfds 22 | import tensorflow as tf 23 | 24 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 25 | 26 | tf_config = json.loads(os.environ.get('TF_CONFIG')) 27 | print("tf_config = ", tf_config) 28 | print("I'm {}:{}".format(tf_config['task']['type'], tf_config['task']['index'])) 29 | 30 | dataset, info = tfds.load('oxford_iiit_pet:3.2.0', with_info=True) 31 | 32 | 33 | def normalize(input_image, input_mask): 34 | input_image = tf.cast(input_image, tf.float32)/128.0 - 1 35 | input_mask -= 1 36 | return input_image, input_mask 37 | 38 | 39 | @tf.function 40 | def load_image_train(datapoint): 41 | input_image = tf.image.resize(datapoint['image'], (128, 128)) 42 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128)) 43 | 44 | if tf.random.uniform(()) > 0.5: 45 | input_image = tf.image.flip_left_right(input_image) 46 | input_mask = tf.image.flip_left_right(input_mask) 47 | 48 | input_image, input_mask = normalize(input_image, input_mask) 49 | 50 | return input_image, input_mask 51 | 52 | 53 | def load_image_test(datapoint): 54 | input_image = tf.image.resize(datapoint['image'], (128, 128)) 55 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128)) 56 | input_image, input_mask = normalize(input_image, input_mask) 57 | return input_image, input_mask 58 | 59 | 60 | TRAIN_LENGTH = info.splits['train'].num_examples 61 | BATCH_SIZE = 64 62 | BUFFER_SIZE = 1000 63 | STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE 64 | 65 | train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) 66 | test = dataset['test'].map(load_image_test) 67 | 68 | train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat() 69 | train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) 70 | test_dataset = test.batch(BATCH_SIZE) 71 | 72 | OUTPUT_CHANNELS = 3 73 | 74 | with strategy.scope(): 75 | base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False) 76 | 77 | # Use the activations of these layers 78 | layer_names = [ 79 | 'block_1_expand_relu', # 64x64 80 | 'block_3_expand_relu', # 32x32 81 | 'block_6_expand_relu', # 16x16 82 | 'block_13_expand_relu', # 8x8 83 | 'block_16_project', # 4x4 84 | ] 85 | layers = [base_model.get_layer(name).output for name in layer_names] 86 | 87 | # Create the feature extraction model 88 | down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers) 89 | 90 | down_stack.trainable = False 91 | 92 | up_stack = [ 93 | pix2pix.upsample(512, 3), # 4x4 -> 8x8 94 | pix2pix.upsample(256, 3), # 8x8 -> 16x16 95 | pix2pix.upsample(128, 3), # 16x16 -> 32x32 96 | pix2pix.upsample(64, 3), # 32x32 -> 64x64 97 | ] 98 | 99 | def unet_model(output_channels): 100 | 101 | # This is the last layer of the model 102 | last = tf.keras.layers.Conv2DTranspose( 103 | output_channels, 3, strides=2, 104 | padding='same', activation='softmax') # 64x64 -> 128x128 105 | 106 | inputs = tf.keras.layers.Input(shape=[128, 128, 3]) 107 | x = inputs 108 | 109 | # Downsampling through the model 110 | skips = down_stack(x) 111 | x = skips[-1] 112 | skips = reversed(skips[:-1]) 113 | 114 | # Upsampling and establishing the skip connections 115 | for up, skip in zip(up_stack, skips): 116 | x = up(x) 117 | concat = tf.keras.layers.Concatenate() 
118 | x = concat([x, skip]) 119 | 120 | x = last(x) 121 | 122 | return tf.keras.Model(inputs=inputs, outputs=x) 123 | 124 | model = unet_model(OUTPUT_CHANNELS) 125 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', 126 | metrics=['accuracy']) 127 | 128 | # Training only (since we're using command-line) 129 | # def create_mask(pred_mask): 130 | # pred_mask = tf.argmax(pred_mask, axis=-1) 131 | # pred_mask = pred_mask[..., tf.newaxis] 132 | # return pred_mask[0] 133 | # 134 | # 135 | # def show_predictions(dataset=None, num=1): 136 | # if dataset: 137 | # for image, mask in dataset.take(num): 138 | # pred_mask = model.predict(image) 139 | # display([image[0], mask[0], create_mask(pred_mask)]) 140 | # else: 141 | # display([sample_image, sample_mask, 142 | # create_mask(model.predict(sample_image[tf.newaxis, ...]))]) 143 | # 144 | # 145 | # class DisplayCallback(tf.keras.callbacks.Callback): 146 | # def on_epoch_end(self, epoch, logs=None): 147 | # clear_output(wait=True) 148 | # show_predictions() 149 | # print ('\nSample Prediction after epoch {}\n'.format(epoch+1)) 150 | # 151 | 152 | # EPOCHS = 20 153 | EPOCHS = 1 154 | VAL_SUBSPLITS = 5 155 | VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS 156 | 157 | model_history = model.fit(train_dataset, epochs=EPOCHS, 158 | steps_per_epoch=STEPS_PER_EPOCH, 159 | validation_steps=VALIDATION_STEPS, 160 | validation_data=test_dataset) 161 | 162 | if tf_config['task']['index'] == 0: 163 | model.save_weights("keras_weights", save_format='h5') 164 | -------------------------------------------------------------------------------- /examples/mnist/keras/mnist_pipeline.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras 2 | 3 | from __future__ import absolute_import, division, print_function, unicode_literals 4 | 5 | 6 | def main_fun(args, ctx): 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflowonspark import compat, TFNode 10 | 11 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 12 | 13 | def build_and_compile_cnn_model(): 14 | model = tf.keras.Sequential([ 15 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 16 | tf.keras.layers.MaxPooling2D(), 17 | tf.keras.layers.Flatten(), 18 | tf.keras.layers.Dense(64, activation='relu'), 19 | tf.keras.layers.Dense(10, activation='softmax') 20 | ]) 21 | model.compile( 22 | loss=tf.keras.losses.sparse_categorical_crossentropy, 23 | optimizer=tf.keras.optimizers.SGD(learning_rate=0.001), 24 | metrics=['accuracy']) 25 | return model 26 | 27 | # single node 28 | # single_worker_model = build_and_compile_cnn_model() 29 | # single_worker_model.fit(x=train_datasets, epochs=3) 30 | 31 | tf_feed = TFNode.DataFeed(ctx.mgr, False) 32 | 33 | def rdd_generator(): 34 | while not tf_feed.should_stop(): 35 | batch = tf_feed.next_batch(1) 36 | if len(batch) > 0: 37 | example = batch[0] 38 | image = np.array(example[0]).astype(np.float32) / 255.0 39 | image = np.reshape(image, (28, 28, 1)) 40 | label = np.array(example[1]).astype(np.float32) 41 | label = np.reshape(label, (1,)) 42 | yield (image, label) 43 | else: 44 | return 45 | 46 | ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([28, 28, 1]), tf.TensorShape([1]))) 47 | ds = ds.batch(args.batch_size) 48 | 49 | # this fails 50 | # callbacks = 
[tf.keras.callbacks.ModelCheckpoint(filepath=args.model_dir)] 51 | tf.io.gfile.makedirs(args.model_dir) 52 | filepath = args.model_dir + "/weights-{epoch:04d}" 53 | callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True)] 54 | 55 | with strategy.scope(): 56 | multi_worker_model = build_and_compile_cnn_model() 57 | 58 | # Note: MultiWorkerMirroredStrategy (CollectiveAllReduceStrategy) is synchronous, 59 | # so we need to ensure that all workers complete training before any of them run out of data from the RDD. 60 | # And given that Spark RDD partitions (and partition sizes) can be non-evenly divisible by num_workers, 61 | # we'll just stop training at 90% of the total expected number of steps. 62 | steps_per_epoch = 60000 / args.batch_size 63 | steps_per_epoch_per_worker = steps_per_epoch / ctx.num_workers 64 | max_steps_per_worker = steps_per_epoch_per_worker * 0.9 65 | 66 | multi_worker_model.fit(x=ds, epochs=args.epochs, steps_per_epoch=max_steps_per_worker, callbacks=callbacks) 67 | 68 | compat.export_saved_model(multi_worker_model, args.export_dir, ctx.job_name == 'chief') 69 | 70 | # terminating feed tells spark to skip processing further partitions 71 | tf_feed.terminate() 72 | 73 | 74 | if __name__ == '__main__': 75 | import argparse 76 | from pyspark.context import SparkContext 77 | from pyspark.conf import SparkConf 78 | from pyspark.sql import SparkSession 79 | from pyspark.sql.functions import udf 80 | from pyspark.sql.types import IntegerType 81 | from tensorflowonspark import dfutil 82 | from tensorflowonspark.pipeline import TFEstimator, TFModel 83 | 84 | sc = SparkContext(conf=SparkConf().setAppName("mnist_keras")) 85 | spark = SparkSession(sc) 86 | 87 | executors = sc._conf.get("spark.executor.instances") 88 | num_executors = int(executors) if executors is not None else 1 89 | 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64) 92 | parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors) 93 | parser.add_argument("--epochs", help="number of epochs", type=int, default=3) 94 | parser.add_argument("--format", help="example format: (csv|tfr)", choices=["csv", "tfr"], default="csv") 95 | parser.add_argument("--images_labels", help="path to MNIST images and labels in parallelized format") 96 | parser.add_argument("--mode", help="train|inference", choices=["train", "inference"], default="train") 97 | parser.add_argument("--model_dir", help="path to save checkpoint", default="mnist_model") 98 | parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export") 99 | parser.add_argument("--output", help="HDFS path to save predictions", type=str, default="predictions") 100 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 101 | 102 | args = parser.parse_args() 103 | print("args:", args) 104 | 105 | if args.format == 'tfr': 106 | # load TFRecords as a DataFrame 107 | df = dfutil.loadTFRecords(sc, args.images_labels) 108 | else: # args.format == 'csv': 109 | # create RDD of input data 110 | def parse(ln): 111 | vec = [int(x) for x in ln.split(',')] 112 | return (vec[1:], vec[0]) 113 | 114 | images_labels = sc.textFile(args.images_labels).map(parse) 115 | df = spark.createDataFrame(images_labels, ['image', 'label']) 116 | 117 | df.show() 118 | 119 | if args.mode == 'train': 120 | estimator = TFEstimator(main_fun, args) \ 121 | 
.setInputMapping({'image': 'image', 'label': 'label'}) \ 122 | .setModelDir(args.model_dir) \ 123 | .setExportDir(args.export_dir) \ 124 | .setClusterSize(args.cluster_size) \ 125 | .setTensorboard(args.tensorboard) \ 126 | .setEpochs(args.epochs) \ 127 | .setBatchSize(args.batch_size) \ 128 | .setGraceSecs(60) 129 | model = estimator.fit(df) 130 | else: # args.mode == 'inference': 131 | # using a trained/exported model 132 | model = TFModel(args) \ 133 | .setInputMapping({'image': 'conv2d_input'}) \ 134 | .setOutputMapping({'dense_1': 'prediction'}) \ 135 | .setSignatureDefKey('serving_default') \ 136 | .setExportDir(args.export_dir) \ 137 | .setBatchSize(args.batch_size) 138 | 139 | def argmax_fn(l): 140 | return max(range(len(l)), key=lambda i: l[i]) 141 | 142 | argmax = udf(argmax_fn, IntegerType()) 143 | 144 | preds = model.transform(df).withColumn('argmax', argmax('prediction')) 145 | preds.show() 146 | preds.write.json(args.output) 147 | -------------------------------------------------------------------------------- /examples/mnist/estimator/mnist_spark_streaming.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator 2 | 3 | 4 | def main_fun(args, ctx): 5 | import numpy as np 6 | import tensorflow as tf 7 | import tensorflow_datasets as tfds 8 | from tensorflowonspark import TFNode 9 | 10 | tfds.disable_progress_bar() 11 | 12 | BUFFER_SIZE = args.buffer_size 13 | BATCH_SIZE = args.batch_size 14 | LEARNING_RATE = args.learning_rate 15 | 16 | tf_feed = TFNode.DataFeed(ctx.mgr) 17 | 18 | def rdd_generator(): 19 | while not tf_feed.should_stop(): 20 | batch = tf_feed.next_batch(1) 21 | if len(batch) > 0: 22 | example = batch[0] 23 | image = np.array(example[0]).astype(np.float32) / 255.0 24 | image = np.reshape(image, (28, 28, 1)) 25 | label = np.array(example[1]).astype(np.float32) 26 | label = np.reshape(label, (1,)) 27 | yield (image, label) 28 | else: 29 | return 30 | 31 | def input_fn(mode, input_context=None): 32 | if mode == tf.estimator.ModeKeys.TRAIN: 33 | # Note: Spark is responsible for feeding data via streaming RDD 34 | ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([28, 28, 1]), tf.TensorShape([1]))) 35 | return ds.batch(BATCH_SIZE) 36 | else: 37 | raise Exception("I'm evaluating: mode={}, input_context={}".format(mode, input_context)) 38 | 39 | def scale(image, label): 40 | image = tf.cast(image, tf.float32) / 255.0 41 | return image, label 42 | 43 | mnist = tfds.load(name='mnist', with_info=True, as_supervised=True) 44 | ds = mnist['test'] 45 | if input_context: 46 | ds = ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id) 47 | return ds.map(scale).batch(BATCH_SIZE) 48 | 49 | def serving_input_receiver_fn(): 50 | features = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='features') 51 | receiver_tensors = {'conv2d_input': features} 52 | return tf.estimator.export.ServingInputReceiver(receiver_tensors, receiver_tensors) 53 | 54 | def model_fn(features, labels, mode): 55 | model = tf.keras.Sequential([ 56 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 57 | tf.keras.layers.MaxPooling2D(), 58 | tf.keras.layers.Flatten(), 59 | tf.keras.layers.Dense(64, activation='relu'), 60 | tf.keras.layers.Dense(10, activation='softmax') 61 | ]) 62 | logits = model(features, training=False) 63 | 64 | if mode == 
tf.estimator.ModeKeys.PREDICT: 65 | predictions = {'logits': logits} 66 | return tf.estimator.EstimatorSpec(mode, predictions=predictions) 67 | 68 | optimizer = tf.compat.v1.train.GradientDescentOptimizer( 69 | learning_rate=LEARNING_RATE) 70 | loss = tf.keras.losses.SparseCategoricalCrossentropy( 71 | from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits) 72 | loss = tf.reduce_sum(input_tensor=loss) * (1. / BATCH_SIZE) 73 | if mode == tf.estimator.ModeKeys.EVAL: 74 | return tf.estimator.EstimatorSpec(mode, loss=loss) 75 | 76 | return tf.estimator.EstimatorSpec( 77 | mode=mode, 78 | loss=loss, 79 | train_op=optimizer.minimize( 80 | loss, tf.compat.v1.train.get_or_create_global_step())) 81 | 82 | # Note: the original example used MultiWorkerMirroredStrategy which is a synchronous training strategy. 83 | # Since streaming data arrives irregularly, we must use the asynchronous ParameterServerStrategy 84 | # to allow data to be processed as it arrives and to avoid deadlocks. 85 | # strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 86 | strategy = tf.distribute.experimental.ParameterServerStrategy() 87 | config = tf.estimator.RunConfig(train_distribute=strategy, save_checkpoints_steps=100) 88 | 89 | classifier = tf.estimator.Estimator( 90 | model_fn=model_fn, model_dir=args.model_dir, config=config) 91 | 92 | # exporter = tf.estimator.FinalExporter("serving", serving_input_receiver_fn=serving_input_receiver_fn) 93 | 94 | tf.estimator.train_and_evaluate( 95 | classifier, 96 | train_spec=tf.estimator.TrainSpec(input_fn=input_fn), 97 | eval_spec=tf.estimator.EvalSpec(input_fn=input_fn) 98 | # eval_spec=tf.estimator.EvalSpec(input_fn=input_fn, exporters=exporter) 99 | ) 100 | 101 | if ctx.job_name == 'chief': 102 | print("Exporting saved_model to {}".format(args.export_dir)) 103 | classifier.export_saved_model(args.export_dir, serving_input_receiver_fn) 104 | 105 | 106 | if __name__ == "__main__": 107 | 108 | from pyspark.context import SparkContext 109 | from pyspark.conf import SparkConf 110 | from pyspark.streaming import StreamingContext 111 | from tensorflowonspark import TFCluster 112 | import argparse 113 | 114 | sc = SparkContext(conf=SparkConf().setAppName("mnist_estimator")) 115 | ssc = StreamingContext(sc, 60) # group data into intervals of one minute 116 | executors = sc._conf.get("spark.executor.instances") 117 | num_executors = int(executors) if executors is not None else 1 118 | 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64) 121 | parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=10000) 122 | parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors) 123 | parser.add_argument("--images_labels", help="path to MNIST images and labels in parallelized format") 124 | parser.add_argument("--learning_rate", help="learning rate", type=float, default=1e-3) 125 | parser.add_argument("--model_dir", help="path to save checkpoint", default="mnist_model") 126 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 127 | 128 | args = parser.parse_args() 129 | print("args:", args) 130 | 131 | # create RDD of input data 132 | def parse(ln): 133 | vec = [int(x) for x in ln.split(',')] 134 | return (vec[1:], vec[0]) 135 | 136 | stream = ssc.textFileStream(args.images_labels) 137 | images_labels = stream.map(parse) 138 | 139 | cluster = 
TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=1, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='chief') 140 | cluster.train(images_labels, feed_timeout=86400) # extend feed timeout to 24hrs for streaming data to arrive 141 | ssc.start() 142 | cluster.shutdown(ssc) 143 | -------------------------------------------------------------------------------- /examples/mnist/estimator/mnist_spark.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator 2 | 3 | 4 | def main_fun(args, ctx): 5 | import numpy as np 6 | import tensorflow as tf 7 | import tensorflow_datasets as tfds 8 | from tensorflowonspark import TFNode 9 | 10 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 11 | 12 | tfds.disable_progress_bar() 13 | 14 | class StopFeedHook(tf.estimator.SessionRunHook): 15 | """SessionRunHook to terminate InputMode.SPARK RDD feeding if the training loop exits before the entire RDD is consumed.""" 16 | 17 | def __init__(self, feed): 18 | self.feed = feed 19 | 20 | def end(self, session): 21 | self.feed.terminate() 22 | self.feed.next_batch(1) 23 | 24 | BUFFER_SIZE = args.buffer_size 25 | BATCH_SIZE = args.batch_size 26 | LEARNING_RATE = args.learning_rate 27 | 28 | tf_feed = TFNode.DataFeed(ctx.mgr) 29 | 30 | def rdd_generator(): 31 | while not tf_feed.should_stop(): 32 | batch = tf_feed.next_batch(1) 33 | if len(batch) > 0: 34 | example = batch[0] 35 | image = np.array(example[0]).astype(np.float32) / 255.0 36 | image = np.reshape(image, (28, 28, 1)) 37 | label = np.array(example[1]).astype(np.float32) 38 | label = np.reshape(label, (1,)) 39 | yield (image, label) 40 | else: 41 | return 42 | 43 | def input_fn(mode, input_context=None): 44 | if mode == tf.estimator.ModeKeys.TRAIN: 45 | # Note: Spark is responsible for sharding/repeating/shuffling the data via RDD 46 | ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([28, 28, 1]), tf.TensorShape([1]))) 47 | return ds.batch(BATCH_SIZE) 48 | else: 49 | raise Exception("I'm evaluating: mode={}, input_context={}".format(mode, input_context)) 50 | 51 | def scale(image, label): 52 | image = tf.cast(image, tf.float32) / 255.0 53 | return image, label 54 | 55 | mnist = tfds.load(name='mnist', with_info=True, as_supervised=True) 56 | ds = mnist['test'] 57 | if input_context: 58 | ds = ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id) 59 | return ds.map(scale).batch(BATCH_SIZE) 60 | 61 | def serving_input_receiver_fn(): 62 | features = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='features') 63 | receiver_tensors = {'conv2d_input': features} 64 | return tf.estimator.export.ServingInputReceiver(receiver_tensors, receiver_tensors) 65 | 66 | def model_fn(features, labels, mode): 67 | model = tf.keras.Sequential([ 68 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 69 | tf.keras.layers.MaxPooling2D(), 70 | tf.keras.layers.Flatten(), 71 | tf.keras.layers.Dense(64, activation='relu'), 72 | tf.keras.layers.Dense(10, activation='softmax') 73 | ]) 74 | logits = model(features, training=False) 75 | 76 | if mode == tf.estimator.ModeKeys.PREDICT: 77 | predictions = {'logits': logits} 78 | return tf.estimator.EstimatorSpec(mode, predictions=predictions) 79 | 80 | optimizer = tf.compat.v1.train.GradientDescentOptimizer( 81 | 
learning_rate=LEARNING_RATE) 82 | loss = tf.keras.losses.SparseCategoricalCrossentropy( 83 | from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits) 84 | loss = tf.reduce_sum(input_tensor=loss) * (1. / BATCH_SIZE) 85 | if mode == tf.estimator.ModeKeys.EVAL: 86 | return tf.estimator.EstimatorSpec(mode, loss=loss) 87 | 88 | return tf.estimator.EstimatorSpec( 89 | mode=mode, 90 | loss=loss, 91 | train_op=optimizer.minimize( 92 | loss, tf.compat.v1.train.get_or_create_global_step())) 93 | 94 | config = tf.estimator.RunConfig(train_distribute=strategy, save_checkpoints_steps=100) 95 | 96 | classifier = tf.estimator.Estimator( 97 | model_fn=model_fn, model_dir=args.model_dir, config=config) 98 | 99 | # exporter = tf.estimator.FinalExporter("serving", serving_input_receiver_fn=serving_input_receiver_fn) 100 | 101 | # Note: MultiWorkerMirroredStrategy (CollectiveAllReduceStrategy) is synchronous, 102 | # so we need to ensure that all workers complete training before any of them run out of data from the RDD. 103 | # And given that Spark RDD partitions (and partition sizes) can be non-evenly divisible by num_workers, 104 | # we'll just stop training at 90% of the total expected number of steps. 105 | steps = 60000 * args.epochs / args.batch_size 106 | steps_per_worker = steps / ctx.num_workers 107 | max_steps_per_worker = steps_per_worker * 0.9 108 | 109 | tf.estimator.train_and_evaluate( 110 | classifier, 111 | train_spec=tf.estimator.TrainSpec(input_fn=input_fn, max_steps=max_steps_per_worker, hooks=[StopFeedHook(tf_feed)]), 112 | eval_spec=tf.estimator.EvalSpec(input_fn=input_fn) 113 | # eval_spec=tf.estimator.EvalSpec(input_fn=input_fn, exporters=exporter) 114 | ) 115 | 116 | if ctx.job_name == 'chief': 117 | print("Exporting saved_model to {}".format(args.export_dir)) 118 | classifier.export_saved_model(args.export_dir, serving_input_receiver_fn) 119 | 120 | 121 | if __name__ == "__main__": 122 | 123 | from pyspark.context import SparkContext 124 | from pyspark.conf import SparkConf 125 | from tensorflowonspark import TFCluster 126 | import argparse 127 | 128 | sc = SparkContext(conf=SparkConf().setAppName("mnist_estimator")) 129 | executors = sc._conf.get("spark.executor.instances") 130 | num_executors = int(executors) if executors is not None else 1 131 | 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64) 134 | parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=10000) 135 | parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors) 136 | parser.add_argument("--epochs", help="number of epochs", type=int, default=3) 137 | parser.add_argument("--images_labels", help="path to MNIST images and labels in parallelized format") 138 | parser.add_argument("--learning_rate", help="learning rate", type=float, default=1e-3) 139 | parser.add_argument("--model_dir", help="path to save checkpoint", default="mnist_model") 140 | parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export") 141 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 142 | 143 | args = parser.parse_args() 144 | print("args:", args) 145 | 146 | # create RDD of input data 147 | def parse(ln): 148 | vec = [int(x) for x in ln.split(',')] 149 | return (vec[1:], vec[0]) 150 | 151 | images_labels = sc.textFile(args.images_labels).map(parse) 152 | 153 | cluster = 
TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.SPARK, log_dir=args.model_dir, master_node='chief') 154 | cluster.train(images_labels, args.epochs) 155 | cluster.shutdown(grace_secs=60) # allow time for the chief to export model after data feeding 156 | -------------------------------------------------------------------------------- /Code-of-Conduct.md: -------------------------------------------------------------------------------- 1 | # Yahoo Open Source Code of Conduct 2 | 3 | ## Summary 4 | This Code of Conduct is our way to encourage good behavior and discourage bad behavior in our open source projects. We invite participation from many people to bring different perspectives to our projects. We will do our part to foster a welcoming and professional environment free of harassment. We expect participants to communicate professionally and thoughtfully during their involvement with this project. 5 | 6 | Participants may lose their good standing by engaging in misconduct. For example: insulting, threatening, or conveying unwelcome sexual content. We ask participants who observe conduct issues to report the incident directly to the project's Response Team at opensource-conduct@yahooinc.com. Yahoo will assign a respondent to address the issue. We may remove harassers from this project. 7 | 8 | This code does not replace the terms of service or acceptable use policies of the websites used to support this project. We acknowledge that participants may be subject to additional conduct terms based on their employment which may govern their online expressions. 9 | 10 | ## Details 11 | This Code of Conduct makes our expectations of participants in this community explicit. 12 | * We forbid harassment and abusive speech within this community. 13 | * We request participants to report misconduct to the project’s Response Team. 14 | * We urge participants to refrain from using discussion forums to play out a fight. 15 | 16 | ### Expected Behaviors 17 | We expect participants in this community to conduct themselves professionally. Since our primary mode of communication is text on an online forum (e.g. issues, pull requests, comments, emails, or chats) devoid of vocal tone, gestures, or other context that is often vital to understanding, it is important that participants are attentive to their interaction style. 18 | 19 | * **Assume positive intent.** We ask community members to assume positive intent on the part of other people’s communications. We may disagree on details, but we expect all suggestions to be supportive of the community goals. 20 | * **Respect participants.** We expect occasional disagreements. Open Source projects are learning experiences. Ask, explore, challenge, and then _respectfully_ state if you agree or disagree. If your idea is rejected, be more persuasive not bitter. 21 | * **Welcoming to new members.** New members bring new perspectives. Some ask questions that have been addressed before. _Kindly_ point to existing discussions. Everyone is new to every project once. 22 | * **Be kind to beginners.** Beginners use open source projects to get experience. They might not be talented coders yet, and projects should not accept poor quality code. But we were all beginners once, and we need to engage kindly. 23 | * **Consider your impact on others.** Your work will be used by others, and you depend on the work of others. 
We expect community members to be considerate and establish a balance their self-interest with communal interest. 24 | * **Use words carefully.** We may not understand intent when you say something ironic. Often, people will misinterpret sarcasm in online communications. We ask community members to communicate plainly. 25 | * **Leave with class.** When you wish to resign from participating in this project for any reason, you are free to fork the code and create a competitive project. Open Source explicitly allows this. Your exit should not be dramatic or bitter. 26 | 27 | ### Unacceptable Behaviors 28 | Participants remain in good standing when they do not engage in misconduct or harassment (some examples follow). We do not list all forms of harassment, nor imply some forms of harassment are not worthy of action. Any participant who *feels* harassed or *observes* harassment, should report the incident to the Response Team. 29 | * **Don't be a bigot.** Calling out project members by their identity or background in a negative or insulting manner. This includes, but is not limited to, slurs or insinuations related to protected or suspect classes e.g. race, color, citizenship, national origin, political belief, religion, sexual orientation, gender identity and expression, age, size, culture, ethnicity, genetic features, language, profession, national minority status, mental or physical ability. 30 | * **Don't insult.** Insulting remarks about a person’s lifestyle practices. 31 | * **Don't dox.** Revealing private information about other participants without explicit permission. 32 | * **Don't intimidate.** Threats of violence or intimidation of any project member. 33 | * **Don't creep.** Unwanted sexual attention or content unsuited for the subject of this project. 34 | * **Don't inflame.** We ask that victim of harassment not address their grievances in the public forum, as this often intensifies the problem. Report it, and let us address it off-line. 35 | * **Don't disrupt.** Sustained disruptions in a discussion. 36 | 37 | ### Reporting Issues 38 | If you experience or witness misconduct, or have any other concerns about the conduct of members of this project, please report it by contacting our Response Team at opensource-conduct@yahooinc.com who will handle your report with discretion. Your report should include: 39 | * Your preferred contact information. We cannot process anonymous reports. 40 | * Names (real or usernames) of those involved in the incident. 41 | * Your account of what occurred, and if the incident is ongoing. Please provide links to or transcripts of the publicly available records (e.g. a mailing list archive or a public IRC logger), so that we can review it. 42 | * Any additional information that may be helpful to achieve resolution. 43 | 44 | After filing a report, a representative will contact you directly to review the incident and ask additional questions. If a member of the Yahoo Response Team is named in an incident report, that member will be recused from handling your incident. If the complaint originates from a member of the Response Team, it will be addressed by a different member of the Response Team. We will consider reports to be confidential for the purpose of protecting victims of abuse. 45 | 46 | ### Scope 47 | Yahoo will assign a Response Team member with admin rights on the project and legal rights on the project copyright. The Response Team is empowered to restrict some privileges to the project as needed. 
Since this project is governed by an open source license, any participant may fork the code under the terms of the project license. The Response Team’s goal is to preserve the project if possible, and will restrict or remove participation from those who disrupt the project. 48 | 49 | This code does not replace the terms of service or acceptable use policies that are provided by the websites used to support this community. Nor does this code apply to communications or actions that take place outside of the context of this community. Many participants in this project are also subject to codes of conduct based on their employment. This code is a social-contract that informs participants of our social expectations. It is not a terms of service or legal contract. 50 | 51 | ## License and Acknowledgment. 52 | This text is shared under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0/). This code is based on a study conducted by the [TODO Group](https://todogroup.org/) of many codes used in the open source community. If you have feedback about this code, contact our Response Team at the address listed above. 53 | -------------------------------------------------------------------------------- /tests/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import shutil 4 | import test 5 | import unittest 6 | 7 | from tensorflowonspark import compat 8 | from tensorflowonspark.pipeline import HasBatchSize, HasSteps, Namespace, TFEstimator, TFParams 9 | from tensorflow.keras import Sequential 10 | from tensorflow.keras.layers import Dense 11 | 12 | 13 | class PipelineTest(test.SparkTest): 14 | @classmethod 15 | def setUpClass(cls): 16 | super(PipelineTest, cls).setUpClass() 17 | 18 | # create an artificial training dataset of two features with labels computed from known weights 19 | np.random.seed(1234) 20 | cls.features = np.random.rand(1000, 2) 21 | cls.weights = np.array([3.14, 1.618]) 22 | cls.labels = np.matmul(cls.features, cls.weights) 23 | # convert to Python types for use with Spark DataFrames 24 | cls.train_examples = [(cls.features[i].tolist(), [cls.labels[i].item()]) for i in range(1000)] 25 | # create a simple test dataset 26 | cls.test_examples = [([1.0, 1.0], [0.0])] 27 | 28 | # define model_dir and export_dir for tests 29 | cls.model_dir = os.getcwd() + os.sep + "test_model" 30 | cls.export_dir = os.getcwd() + os.sep + "test_export" 31 | cls.tfrecord_dir = os.getcwd() + os.sep + "test_tfr" 32 | 33 | @classmethod 34 | def tearDownClass(cls): 35 | super(PipelineTest, cls).tearDownClass() 36 | 37 | def setUp(self): 38 | super(PipelineTest, self).setUp() 39 | # remove any prior test artifacts 40 | shutil.rmtree(self.model_dir, ignore_errors=True) 41 | shutil.rmtree(self.export_dir, ignore_errors=True) 42 | shutil.rmtree(self.tfrecord_dir, ignore_errors=True) 43 | 44 | def tearDown(self): 45 | # Note: don't clean up artifacts after test (in case we need to view/debug) 46 | pass 47 | 48 | def test_namespace(self): 49 | """Namespace class initializers""" 50 | # from dictionary 51 | d = {'string': 'foo', 'integer': 1, 'float': 3.14, 'array': [1, 2, 3], 'map': {'a': 1, 'b': 2}} 52 | n1 = Namespace(d) 53 | self.assertEqual(n1.string, 'foo') 54 | self.assertEqual(n1.integer, 1) 55 | self.assertEqual(n1.float, 3.14) 56 | self.assertEqual(n1.array, [1, 2, 3]) 57 | self.assertEqual(n1.map, {'a': 1, 'b': 2}) 58 | self.assertTrue('string' in n1) 59 | self.assertFalse('extra' in n1) 60 | 61 | # from 
namespace 62 | n2 = Namespace(n1) 63 | self.assertEqual(n2.string, 'foo') 64 | self.assertEqual(n2.integer, 1) 65 | self.assertEqual(n2.float, 3.14) 66 | self.assertEqual(n2.array, [1, 2, 3]) 67 | self.assertEqual(n2.map, {'a': 1, 'b': 2}) 68 | self.assertTrue('string' in n2) 69 | self.assertFalse('extra' in n2) 70 | 71 | # from argv list 72 | argv = ["--foo", "1", "--bar", "test", "--baz", "3.14"] 73 | n3 = Namespace(argv) 74 | self.assertEqual(n3.argv, argv) 75 | 76 | def test_TFParams(self): 77 | """Merging namespace args w/ ML Params""" 78 | class Foo(TFParams, HasBatchSize, HasSteps): 79 | def __init__(self, args): 80 | super(Foo, self).__init__() 81 | self.args = args 82 | 83 | n = Namespace({'a': 1, 'b': 2}) 84 | f = Foo(n).setBatchSize(10).setSteps(100) 85 | combined_args = f.merge_args_params() 86 | expected_args = Namespace({'a': 1, 'b': 2, 'batch_size': 10, 'steps': 100}) 87 | self.assertEqual(combined_args, expected_args) 88 | 89 | def test_spark_saved_model(self): 90 | """InputMode.SPARK TFEstimator w/ explicit saved_model export for TFModel inferencing""" 91 | 92 | def _spark_train(args, ctx): 93 | """Basic linear regression in a distributed TF cluster using InputMode.SPARK""" 94 | import tensorflow as tf 95 | from tensorflowonspark import TFNode 96 | 97 | tf.compat.v1.reset_default_graph() 98 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 99 | 100 | with strategy.scope(): 101 | model = Sequential() 102 | model.add(Dense(1, activation='linear', input_shape=[2])) 103 | model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.2), loss='mse', metrics=['mse']) 104 | model.summary() 105 | 106 | tf_feed = TFNode.DataFeed(ctx.mgr, input_mapping=args.input_mapping) 107 | 108 | def rdd_generator(): 109 | while not tf_feed.should_stop(): 110 | batch = tf_feed.next_batch(1) 111 | if len(batch['x']) > 0: 112 | features = batch['x'][0] 113 | label = batch['y_'][0] 114 | yield (features, label) 115 | else: 116 | return 117 | 118 | ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([2]), tf.TensorShape([1]))) 119 | # disable auto-sharding since we're feeding from an RDD generator 120 | options = tf.data.Options() 121 | compat.disable_auto_shard(options) 122 | ds = ds.with_options(options) 123 | ds = ds.batch(args.batch_size) 124 | 125 | # only train 90% of each epoch to account for uneven RDD partition sizes 126 | steps_per_epoch = 1000 * 0.9 // (args.batch_size * ctx.num_workers) 127 | 128 | tf.io.gfile.makedirs(args.model_dir) 129 | filepath = args.model_dir + "/weights-{epoch:04d}" 130 | callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, load_weights_on_restart=True, save_weights_only=True)] 131 | 132 | model.fit(ds, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks) 133 | 134 | # This fails with: "NotImplementedError: `fit_generator` is not supported for models compiled with tf.distribute.Strategy" 135 | # model.fit_generator(ds, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks) 136 | 137 | if args.export_dir: 138 | print("exporting model to: {}".format(args.export_dir)) 139 | compat.export_saved_model(model, args.export_dir, ctx.job_name == 'chief') 140 | 141 | tf_feed.terminate() 142 | 143 | # create a Spark DataFrame of training examples (features, labels) 144 | rdd = self.sc.parallelize(self.train_examples, 2) 145 | trainDF = rdd.toDF(['col1', 'col2']) 146 | 147 | # train and export model 148 | args = {} 149 | estimator = TFEstimator(_spark_train, 
args) \ 150 | .setInputMapping({'col1': 'x', 'col2': 'y_'}) \ 151 | .setModelDir(self.model_dir) \ 152 | .setExportDir(self.export_dir) \ 153 | .setClusterSize(self.num_workers) \ 154 | .setMasterNode("chief") \ 155 | .setNumPS(0) \ 156 | .setBatchSize(1) \ 157 | .setEpochs(1) 158 | model = estimator.fit(trainDF) 159 | self.assertTrue(os.path.isdir(self.export_dir)) 160 | 161 | # create a Spark DataFrame of test examples (features, labels) 162 | testDF = self.spark.createDataFrame(self.test_examples, ['c1', 'c2']) 163 | 164 | # test saved_model using exported signature 165 | model.setTagSet('serve') \ 166 | .setSignatureDefKey('serving_default') \ 167 | .setInputMapping({'c1': 'dense_input'}) \ 168 | .setOutputMapping({'dense': 'cout'}) 169 | preds = model.transform(testDF).head() # take first/only result 170 | pred = preds.cout[0] # unpack scalar from tensor 171 | expected = np.sum(self.weights) 172 | self.assertAlmostEqual(pred, expected, 2) 173 | 174 | 175 | if __name__ == '__main__': 176 | unittest.main() 177 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | com.yahoo.ml 7 | tensorflowonspark 8 | 2.2.5-SNAPSHOT 9 | jar 10 | tensorflowonspark 11 | Spark Scala inferencing for TensorFlowOnSpark 12 | 13 | 14 | 1.8 15 | 1.8 16 | 17 | 2.10.1 18 | 3.5.3 19 | 2.4.3 20 | 2.20.1 21 | 3.0.1 22 | 2.12.10 23 | 3.2.1 24 | 1.1.2 25 | 3.0.5 26 | 1.0 27 | 3.7.0 28 | 1.15.0 29 | 30 | 31 | 32 | org.apache.spark 33 | spark-core_2.12 34 | ${spark.version} 35 | provided 36 | 37 | 38 | org.apache.spark 39 | spark-sql_2.12 40 | ${spark.version} 41 | provided 42 | 43 | 44 | org.apache.spark 45 | spark-mllib_2.12 46 | ${spark.version} 47 | provided 48 | 49 | 50 | org.scala-lang 51 | scala-library 52 | ${scala.version} 53 | 54 | 55 | org.scala-lang.modules 56 | scala-parser-combinators_2.12 57 | ${scala-parser-combinators.version} 58 | 59 | 60 | org.tensorflow 61 | tensorflow 62 | ${tensorflow.version} 63 | 64 | 65 | org.tensorflow 66 | tensorflow-hadoop 67 | ${tensorflow.version} 68 | 69 | 70 | com.google.protobuf 71 | protobuf-java 72 | 3.16.1 73 | 74 | 75 | org.scalatest 76 | scalatest_2.12 77 | ${scalatest.version} 78 | test 79 | 80 | 81 | com.github.scopt 82 | scopt_2.12 83 | ${scopt.version} 84 | 85 | 86 | org.json4s 87 | json4s-native_2.12 88 | ${json4s-native.version} 89 | 90 | 91 | 92 | 93 | 94 | 95 | net.alchim31.maven 96 | scala-maven-plugin 97 | ${scala-maven-plugin.version} 98 | 99 | 100 | 101 | compile 102 | testCompile 103 | 104 | 105 | 106 | 107 | ${scala.version} 108 | 109 | 110 | 111 | 112 | org.apache.maven.plugins 113 | maven-surefire-plugin 114 | ${maven-surefire-plugin.version} 115 | 116 | true 117 | 118 | 119 | 120 | 121 | org.scalatest 122 | scalatest-maven-plugin 123 | ${scalatest-maven-plugin.version} 124 | 125 | ${project.build.directory}/surefire-reports 126 | . 
127 | WDF TestSuite.txt 128 | 129 | 130 | 131 | 132 | test 133 | 134 | 135 | 136 | 137 | 138 | org.apache.maven.plugins 139 | maven-shade-plugin 140 | ${maven-shade-plugin.version} 141 | 142 | 143 | package 144 | 145 | shade 146 | 147 | 148 | 149 | 150 | org.apache.hadoop:* 151 | 152 | 153 | 154 | 155 | com.google.protobuf 156 | org.tensorflow.hadoop.shaded.protobuf 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /examples/segmentation/segmentation_spark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # 4 | # @title Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import absolute_import, division, print_function, unicode_literals 17 | 18 | 19 | def main_fun(args, ctx): 20 | from tensorflow_examples.models.pix2pix import pix2pix 21 | import tensorflow_datasets as tfds 22 | import tensorflow as tf 23 | 24 | print("TensorFlow version: ", tf.__version__) 25 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 26 | 27 | dataset, info = tfds.load('oxford_iiit_pet:3.2.0', with_info=True) 28 | 29 | def normalize(input_image, input_mask): 30 | input_image = tf.cast(input_image, tf.float32)/128.0 - 1 31 | input_mask -= 1 32 | return input_image, input_mask 33 | 34 | @tf.function 35 | def load_image_train(datapoint): 36 | input_image = tf.image.resize(datapoint['image'], (128, 128)) 37 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128)) 38 | 39 | if tf.random.uniform(()) > 0.5: 40 | input_image = tf.image.flip_left_right(input_image) 41 | input_mask = tf.image.flip_left_right(input_mask) 42 | 43 | input_image, input_mask = normalize(input_image, input_mask) 44 | 45 | return input_image, input_mask 46 | 47 | def load_image_test(datapoint): 48 | input_image = tf.image.resize(datapoint['image'], (128, 128)) 49 | input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128)) 50 | input_image, input_mask = normalize(input_image, input_mask) 51 | return input_image, input_mask 52 | 53 | TRAIN_LENGTH = info.splits['train'].num_examples 54 | BATCH_SIZE = args.batch_size 55 | BUFFER_SIZE = args.buffer_size 56 | STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE 57 | 58 | train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) 59 | test = dataset['test'].map(load_image_test) 60 | 61 | train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat() 62 | train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) 63 | test_dataset = test.batch(BATCH_SIZE) 64 | 65 | OUTPUT_CHANNELS = 3 66 | 67 | with strategy.scope(): 68 | base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False) 69 | 70 | # Use the activations of these layers 71 | layer_names = [ 72 | 
'block_1_expand_relu', # 64x64 73 | 'block_3_expand_relu', # 32x32 74 | 'block_6_expand_relu', # 16x16 75 | 'block_13_expand_relu', # 8x8 76 | 'block_16_project', # 4x4 77 | ] 78 | layers = [base_model.get_layer(name).output for name in layer_names] 79 | 80 | # Create the feature extraction model 81 | down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers) 82 | 83 | down_stack.trainable = False 84 | 85 | up_stack = [ 86 | pix2pix.upsample(512, 3), # 4x4 -> 8x8 87 | pix2pix.upsample(256, 3), # 8x8 -> 16x16 88 | pix2pix.upsample(128, 3), # 16x16 -> 32x32 89 | pix2pix.upsample(64, 3), # 32x32 -> 64x64 90 | ] 91 | 92 | def unet_model(output_channels): 93 | 94 | # This is the last layer of the model 95 | last = tf.keras.layers.Conv2DTranspose( 96 | output_channels, 3, strides=2, 97 | padding='same', activation='softmax') # 64x64 -> 128x128 98 | 99 | inputs = tf.keras.layers.Input(shape=[128, 128, 3]) 100 | x = inputs 101 | 102 | # Downsampling through the model 103 | skips = down_stack(x) 104 | x = skips[-1] 105 | skips = reversed(skips[:-1]) 106 | 107 | # Upsampling and establishing the skip connections 108 | for up, skip in zip(up_stack, skips): 109 | x = up(x) 110 | concat = tf.keras.layers.Concatenate() 111 | x = concat([x, skip]) 112 | 113 | x = last(x) 114 | 115 | return tf.keras.Model(inputs=inputs, outputs=x) 116 | 117 | model = unet_model(OUTPUT_CHANNELS) 118 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', 119 | metrics=['accuracy']) 120 | 121 | # Training only (since we're using command-line) 122 | # def create_mask(pred_mask): 123 | # pred_mask = tf.argmax(pred_mask, axis=-1) 124 | # pred_mask = pred_mask[..., tf.newaxis] 125 | # return pred_mask[0] 126 | # 127 | # 128 | # def show_predictions(dataset=None, num=1): 129 | # if dataset: 130 | # for image, mask in dataset.take(num): 131 | # pred_mask = model.predict(image) 132 | # display([image[0], mask[0], create_mask(pred_mask)]) 133 | # else: 134 | # display([sample_image, sample_mask, 135 | # create_mask(model.predict(sample_image[tf.newaxis, ...]))]) 136 | # 137 | # 138 | # class DisplayCallback(tf.keras.callbacks.Callback): 139 | # def on_epoch_end(self, epoch, logs=None): 140 | # clear_output(wait=True) 141 | # show_predictions() 142 | # print ('\nSample Prediction after epoch {}\n'.format(epoch+1)) 143 | # 144 | 145 | EPOCHS = args.epochs 146 | VAL_SUBSPLITS = 5 147 | VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS 148 | 149 | tf.io.gfile.makedirs(args.model_dir) 150 | filepath = args.model_dir + "/weights-{epoch:04d}" 151 | ckpt_callback = tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True) 152 | 153 | model_history = model.fit(train_dataset, epochs=EPOCHS, 154 | steps_per_epoch=STEPS_PER_EPOCH, 155 | callbacks=[ckpt_callback], 156 | validation_steps=VALIDATION_STEPS, 157 | validation_data=test_dataset) 158 | 159 | if tf.__version__ == '2.0.0': 160 | # Workaround for: https://github.com/tensorflow/tensorflow/issues/30251 161 | # Save model locally as h5py and reload it w/o distribution strategy 162 | if ctx.job_name == 'chief': 163 | model.save(args.model_dir + ".h5") 164 | new_model = tf.keras.models.load_model(args.model_dir + ".h5") 165 | tf.keras.experimental.export_saved_model(new_model, args.export_dir) 166 | else: 167 | model.save(args.export_dir, save_format='tf') 168 | 169 | 170 | if __name__ == '__main__': 171 | import argparse 172 | from pyspark.context import SparkContext 173 | from pyspark.conf import 
SparkConf 174 | from tensorflowonspark import TFCluster 175 | 176 | sc = SparkContext(conf=SparkConf().setAppName("segmentation")) 177 | executors = sc._conf.get("spark.executor.instances") 178 | num_executors = int(executors) if executors is not None else 1 179 | 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64) 182 | parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=1000) 183 | parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors) 184 | parser.add_argument("--epochs", help="number of epochs", type=int, default=3) 185 | parser.add_argument("--model_dir", help="path to save model/checkpoint", default="segmentation_model") 186 | parser.add_argument("--export_dir", help="path to export saved_model", default="segmentation_export") 187 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 188 | 189 | args = parser.parse_args() 190 | print("args:", args) 191 | 192 | cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief') 193 | cluster.shutdown(grace_secs=30) 194 | -------------------------------------------------------------------------------- /examples/mnist/estimator/mnist_pipeline.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_estimator 2 | 3 | from __future__ import absolute_import, division, print_function, unicode_literals 4 | 5 | 6 | def main_fun(args, ctx): 7 | import numpy as np 8 | import tensorflow as tf 9 | import tensorflow_datasets as tfds 10 | from tensorflowonspark import TFNode 11 | 12 | tfds.disable_progress_bar() 13 | 14 | class StopFeedHook(tf.estimator.SessionRunHook): 15 | """SessionRunHook to terminate InputMode.SPARK RDD feeding if the training loop exits before the entire RDD is consumed.""" 16 | 17 | def __init__(self, feed): 18 | self.feed = feed 19 | 20 | def end(self, session): 21 | self.feed.terminate() 22 | self.feed.next_batch(1) 23 | 24 | BATCH_SIZE = args.batch_size 25 | LEARNING_RATE = args.learning_rate 26 | 27 | tf_feed = TFNode.DataFeed(ctx.mgr) 28 | 29 | def rdd_generator(): 30 | while not tf_feed.should_stop(): 31 | batch = tf_feed.next_batch(1) 32 | if len(batch) > 0: 33 | example = batch[0] 34 | image = np.array(example[0]).astype(np.float32) / 255.0 35 | image = np.reshape(image, (28, 28, 1)) 36 | label = np.array(example[1]).astype(np.float32) 37 | label = np.reshape(label, (1,)) 38 | yield (image, label) 39 | else: 40 | return 41 | 42 | def input_fn(mode, input_context=None): 43 | if mode == tf.estimator.ModeKeys.TRAIN: 44 | # Note: Spark is responsible for sharding/repeating/shuffling the data via RDD 45 | ds = tf.data.Dataset.from_generator(rdd_generator, (tf.float32, tf.float32), (tf.TensorShape([28, 28, 1]), tf.TensorShape([1]))) 46 | return ds.batch(BATCH_SIZE) 47 | else: 48 | # read evaluation data from tensorflow_datasets directly 49 | def scale(image, label): 50 | image = tf.cast(image, tf.float32) / 255.0 51 | return image, label 52 | 53 | mnist = tfds.load(name='mnist', with_info=True, as_supervised=True) 54 | ds = mnist['test'] 55 | if input_context: 56 | ds = ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id) 57 | return ds.map(scale).batch(BATCH_SIZE) 58 | 
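# Note: the 'conv2d_input' key used by serving_input_receiver_fn below is the same tensor
# name that the inference branch of this script maps the DataFrame 'image' column onto via
# TFModel.setInputMapping({'image': 'conv2d_input'}); the two must match for serving to work.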
59 | def serving_input_receiver_fn(): 60 | features = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='conv2d_input') 61 | receiver_tensors = {'conv2d_input': features} 62 | return tf.estimator.export.ServingInputReceiver(receiver_tensors, receiver_tensors) 63 | 64 | def model_fn(features, labels, mode): 65 | model = tf.keras.Sequential([ 66 | tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)), 67 | tf.keras.layers.MaxPooling2D(), 68 | tf.keras.layers.Flatten(), 69 | tf.keras.layers.Dense(64, activation='relu'), 70 | tf.keras.layers.Dense(10, activation='softmax') 71 | ]) 72 | logits = model(features, training=False) 73 | 74 | if mode == tf.estimator.ModeKeys.PREDICT: 75 | predictions = {'logits': logits} 76 | return tf.estimator.EstimatorSpec(mode, predictions=predictions) 77 | 78 | optimizer = tf.compat.v1.train.GradientDescentOptimizer( 79 | learning_rate=LEARNING_RATE) 80 | loss = tf.keras.losses.SparseCategoricalCrossentropy( 81 | from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits) 82 | loss = tf.reduce_sum(input_tensor=loss) * (1. / BATCH_SIZE) 83 | if mode == tf.estimator.ModeKeys.EVAL: 84 | return tf.estimator.EstimatorSpec(mode, loss=loss) 85 | 86 | return tf.estimator.EstimatorSpec( 87 | mode=mode, 88 | loss=loss, 89 | train_op=optimizer.minimize( 90 | loss, tf.compat.v1.train.get_or_create_global_step())) 91 | 92 | strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 93 | config = tf.estimator.RunConfig(train_distribute=strategy, save_checkpoints_steps=100) 94 | 95 | classifier = tf.estimator.Estimator( 96 | model_fn=model_fn, model_dir=args.model_dir, config=config) 97 | 98 | # exporter = tf.estimator.FinalExporter("serving", serving_input_receiver_fn=serving_input_receiver_fn) 99 | 100 | # Note: MultiWorkerMirroredStrategy (CollectiveAllReduceStrategy) is synchronous, 101 | # so we need to ensure that all workers complete training before any of them run out of data from the RDD. 102 | # And given that Spark RDD partitions (and partition sizes) can be non-evenly divisible by num_workers, 103 | # we'll just stop training at 90% of the total expected number of steps. 
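# For example (hypothetical values, not from the original script): with the default
# epochs=3 and batch_size=64 on a 2-worker cluster, steps = 60000 * 3 / 64 ~= 2812,
# steps_per_worker ~= 1406, and max_steps_per_worker ~= 1265.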
104 | steps = 60000 * args.epochs / args.batch_size 105 | steps_per_worker = steps / ctx.num_workers 106 | max_steps_per_worker = steps_per_worker * 0.9 107 | 108 | tf.estimator.train_and_evaluate( 109 | classifier, 110 | train_spec=tf.estimator.TrainSpec(input_fn=input_fn, max_steps=max_steps_per_worker, hooks=[StopFeedHook(tf_feed)]), 111 | eval_spec=tf.estimator.EvalSpec(input_fn=input_fn) 112 | # eval_spec=tf.estimator.EvalSpec(input_fn=input_fn, exporters=exporter) 113 | ) 114 | 115 | if ctx.job_name == 'chief': 116 | print("Exporting saved_model to {}".format(args.export_dir)) 117 | classifier.export_saved_model(args.export_dir, serving_input_receiver_fn) 118 | 119 | 120 | if __name__ == "__main__": 121 | 122 | from pyspark.context import SparkContext 123 | from pyspark.conf import SparkConf 124 | from pyspark.sql import SparkSession 125 | from pyspark.sql.functions import udf 126 | from pyspark.sql.types import IntegerType 127 | from tensorflowonspark import dfutil 128 | from tensorflowonspark.pipeline import TFEstimator, TFModel 129 | import argparse 130 | 131 | sc = SparkContext(conf=SparkConf().setAppName("mnist_estimator")) 132 | spark = SparkSession(sc) 133 | 134 | executors = sc._conf.get("spark.executor.instances") 135 | num_executors = int(executors) if executors is not None else 1 136 | 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64) 139 | parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=10000) 140 | parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors) 141 | parser.add_argument("--epochs", help="number of epochs", type=int, default=3) 142 | parser.add_argument("--format", help="example format: (csv|tfr)", choices=["csv", "tfr"], default="csv") 143 | parser.add_argument("--images_labels", help="path to MNIST images and labels in parallelized format") 144 | parser.add_argument("--learning_rate", help="learning rate", type=float, default=1e-3) 145 | parser.add_argument("--mode", help="train|inference", choices=["train", "inference"], default="train") 146 | parser.add_argument("--model_dir", help="path to save checkpoint", default="mnist_model") 147 | parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export") 148 | parser.add_argument("--output", help="HDFS path to save predictions", type=str, default="predictions") 149 | parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true") 150 | 151 | args = parser.parse_args() 152 | print("args:", args) 153 | 154 | if args.format == 'tfr': 155 | # load TFRecords as a DataFrame 156 | df = dfutil.loadTFRecords(sc, args.images_labels) 157 | else: # args.format == 'csv': 158 | # create RDD of input data 159 | def parse(ln): 160 | vec = [int(x) for x in ln.split(',')] 161 | return (vec[1:], vec[0]) 162 | 163 | images_labels = sc.textFile(args.images_labels).map(parse) 164 | df = spark.createDataFrame(images_labels, ['image', 'label']) 165 | 166 | df.show() 167 | 168 | if args.mode == 'train': 169 | estimator = TFEstimator(main_fun, args) \ 170 | .setInputMapping({'image': 'image', 'label': 'label'}) \ 171 | .setModelDir(args.model_dir) \ 172 | .setExportDir(args.export_dir) \ 173 | .setClusterSize(args.cluster_size) \ 174 | .setTensorboard(args.tensorboard) \ 175 | .setEpochs(args.epochs) \ 176 | .setBatchSize(args.batch_size) \ 177 | .setGraceSecs(60) 178 | model = estimator.fit(df) 179 | 
else: # args.mode == 'inference': 180 | # using a trained/exported model 181 | model = TFModel(args) \ 182 | .setInputMapping({'image': 'conv2d_input'}) \ 183 | .setOutputMapping({'logits': 'prediction'}) \ 184 | .setSignatureDefKey('serving_default') \ 185 | .setExportDir(args.export_dir) \ 186 | .setBatchSize(args.batch_size) 187 | 188 | def argmax_fn(l): 189 | return max(range(len(l)), key=lambda i: l[i]) 190 | 191 | argmax = udf(argmax_fn, IntegerType()) 192 | 193 | preds = model.transform(df).withColumn('argmax', argmax('prediction')) 194 | preds.show() 195 | preds.write.json(args.output) 196 | -------------------------------------------------------------------------------- /tensorflowonspark/dfutil.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Yahoo Inc. 2 | # Licensed under the terms of the Apache 2.0 license. 3 | # Please see LICENSE file in the project root for terms. 4 | """A collection of utility functions for loading/saving TensorFlow TFRecords files as Spark DataFrames.""" 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import nested_scopes 9 | from __future__ import print_function 10 | 11 | import tensorflow as tf 12 | from pyspark.sql import Row 13 | from pyspark.sql.types import ArrayType, BinaryType, DoubleType, LongType, StringType, StructField, StructType 14 | 15 | loadedDF = {} # Stores origin paths of loaded DataFrames (df => path) 16 | 17 | 18 | def isLoadedDF(df): 19 | """Returns True if the input DataFrame was produced by the loadTFRecords() method. 20 | 21 | This is primarily used by the Spark ML Pipelines APIs. 22 | 23 | Args: 24 | :df: Spark Dataframe 25 | """ 26 | return df in loadedDF 27 | 28 | 29 | def saveAsTFRecords(df, output_dir): 30 | """Save a Spark DataFrame as TFRecords. 31 | 32 | This will convert the DataFrame rows to TFRecords prior to saving. 33 | 34 | Args: 35 | :df: Spark DataFrame 36 | :output_dir: Path to save TFRecords 37 | """ 38 | tf_rdd = df.rdd.mapPartitions(toTFExample(df.dtypes)) 39 | tf_rdd.saveAsNewAPIHadoopFile(output_dir, "org.tensorflow.hadoop.io.TFRecordFileOutputFormat", 40 | keyClass="org.apache.hadoop.io.BytesWritable", 41 | valueClass="org.apache.hadoop.io.NullWritable") 42 | 43 | 44 | def loadTFRecords(sc, input_dir, binary_features=[]): 45 | """Load TFRecords from disk into a Spark DataFrame. 46 | 47 | This will attempt to automatically convert the tf.train.Example features into Spark DataFrame columns of equivalent types. 48 | 49 | Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to 50 | disambiguate these types for Spark DataFrames DTypes (StringType and BinaryType), so we require a "hint" 51 | from the caller in the ``binary_features`` argument. 52 | 53 | Args: 54 | :sc: SparkContext 55 | :input_dir: location of TFRecords on disk. 56 | :binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays. 57 | 58 | Returns: 59 | A Spark DataFrame mirroring the tf.train.Example schema. 
60 | """ 61 | import tensorflow as tf 62 | 63 | tfr_rdd = sc.newAPIHadoopFile(input_dir, "org.tensorflow.hadoop.io.TFRecordFileInputFormat", 64 | keyClass="org.apache.hadoop.io.BytesWritable", 65 | valueClass="org.apache.hadoop.io.NullWritable") 66 | 67 | # infer Spark SQL types from tf.Example 68 | record = tfr_rdd.take(1)[0] 69 | example = tf.train.Example() 70 | example.ParseFromString(bytes(record[0])) 71 | schema = infer_schema(example, binary_features) 72 | 73 | # convert serialized protobuf to tf.Example to Row 74 | example_rdd = tfr_rdd.mapPartitions(lambda x: fromTFExample(x, binary_features)) 75 | 76 | # create a Spark DataFrame from RDD[Row] 77 | df = example_rdd.toDF(schema) 78 | 79 | # save reference of this dataframe 80 | loadedDF[df] = input_dir 81 | return df 82 | 83 | 84 | def toTFExample(dtypes): 85 | """mapPartition function to convert a Spark RDD of Row into an RDD of serialized tf.train.Example bytestring. 86 | 87 | Note that tf.train.Example is a fairly flat structure with limited datatypes, e.g. tf.train.FloatList, 88 | tf.train.Int64List, and tf.train.BytesList, so most DataFrame types will be coerced into one of these types. 89 | 90 | Args: 91 | :dtypes: the DataFrame.dtypes of the source DataFrame. 92 | 93 | Returns: 94 | A mapPartition function which converts the source DataFrame into tf.train.Example bytestrings. 95 | """ 96 | def _toTFExample(iter): 97 | 98 | # supported type mappings between DataFrame.dtypes and tf.train.Feature types 99 | float_dtypes = ['float', 'double'] 100 | int64_dtypes = ['boolean', 'tinyint', 'smallint', 'int', 'bigint', 'long'] 101 | bytes_dtypes = ['binary', 'string'] 102 | float_list_dtypes = ['array', 'array'] 103 | int64_list_dtypes = ['array', 'array', 'array', 'array', 'array', 'array'] 104 | 105 | def _toTFFeature(name, dtype, row): 106 | feature = None 107 | if dtype in float_dtypes: 108 | feature = (name, tf.train.Feature(float_list=tf.train.FloatList(value=[row[name]]))) 109 | elif dtype in int64_dtypes: 110 | feature = (name, tf.train.Feature(int64_list=tf.train.Int64List(value=[row[name]]))) 111 | elif dtype in bytes_dtypes: 112 | if dtype == 'binary': 113 | feature = (name, tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(row[name])]))) 114 | else: 115 | feature = (name, tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(row[name]).encode('utf-8')]))) 116 | elif dtype in float_list_dtypes: 117 | feature = (name, tf.train.Feature(float_list=tf.train.FloatList(value=row[name]))) 118 | elif dtype in int64_list_dtypes: 119 | feature = (name, tf.train.Feature(int64_list=tf.train.Int64List(value=row[name]))) 120 | else: 121 | raise Exception("Unsupported dtype: {0}".format(dtype)) 122 | return feature 123 | 124 | results = [] 125 | for row in iter: 126 | features = dict([_toTFFeature(name, dtype, row) for name, dtype in dtypes]) 127 | example = tf.train.Example(features=tf.train.Features(feature=features)) 128 | results.append((bytearray(example.SerializeToString()), None)) 129 | return results 130 | 131 | return _toTFExample 132 | 133 | 134 | def infer_schema(example, binary_features=[]): 135 | """Given a tf.train.Example, infer the Spark DataFrame schema (StructFields). 136 | 137 | Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to 138 | disambiguate these types for Spark DataFrames DTypes (StringType and BinaryType), so we require a "hint" 139 | from the caller in the ``binary_features`` argument. 
140 | 141 | Args: 142 | :example: a tf.train.Example 143 | :binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays. 144 | 145 | Returns: 146 | A DataFrame StructType schema 147 | """ 148 | def _infer_sql_type(k, v): 149 | # special handling for binary features 150 | if k in binary_features: 151 | return BinaryType() 152 | 153 | if v.int64_list.value: 154 | result = v.int64_list.value 155 | sql_type = LongType() 156 | elif v.float_list.value: 157 | result = v.float_list.value 158 | sql_type = DoubleType() 159 | else: 160 | result = v.bytes_list.value 161 | sql_type = StringType() 162 | 163 | if len(result) > 1: # represent multi-item tensors as Spark SQL ArrayType() of base types 164 | return ArrayType(sql_type) 165 | else: # represent everything else as base types (and empty tensors as StringType()) 166 | return sql_type 167 | 168 | return StructType([StructField(k, _infer_sql_type(k, v), True) for k, v in sorted(example.features.feature.items())]) 169 | 170 | 171 | def fromTFExample(iter, binary_features=[]): 172 | """mapPartition function to convert an RDD of serialized tf.train.Example bytestring into an RDD of Row. 173 | 174 | Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to 175 | disambiguate these types for Spark DataFrames DTypes (StringType and BinaryType), so we require a "hint" 176 | from the caller in the ``binary_features`` argument. 177 | 178 | Args: 179 | :iter: the RDD partition iterator 180 | :binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays. 181 | 182 | Returns: 183 | An array/iterator of DataFrame Row with features converted into columns. 184 | """ 185 | # convert from protobuf-like dict to DataFrame-friendly dict 186 | def _get_value(k, v): 187 | if v.int64_list.value: 188 | result = v.int64_list.value 189 | elif v.float_list.value: 190 | result = v.float_list.value 191 | else: # string or bytearray 192 | if k in binary_features: 193 | return bytearray(v.bytes_list.value[0]) 194 | else: 195 | return v.bytes_list.value[0].decode('utf-8') 196 | 197 | if len(result) > 1: # represent multi-item tensors as python lists 198 | return list(result) 199 | elif len(result) == 1: # extract scalars from single-item tensors 200 | return result[0] 201 | else: # represent empty tensors as python None 202 | return None 203 | 204 | results = [] 205 | for record in iter: 206 | example = tf.train.Example() 207 | example.ParseFromString(bytes(record[0])) # record is (bytestr, None) 208 | d = {k: _get_value(k, v) for k, v in sorted(example.features.feature.items())} 209 | row = Row(**d) 210 | results.append(row) 211 | 212 | return results 213 | --------------------------------------------------------------------------------
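The dfutil module above round-trips Spark DataFrames through TFRecord files of serialized tf.train.Example records. A minimal usage sketch follows; it assumes an existing SparkSession, the org.tensorflow.hadoop.io record reader/writer classes on the Spark classpath, and hypothetical paths and column names:

# Minimal round-trip sketch for tensorflowonspark.dfutil (paths and column names are
# illustrative assumptions, not taken from the repository).
from pyspark.sql import SparkSession
from tensorflowonspark import dfutil

spark = SparkSession.builder.appName("dfutil_example").getOrCreate()
sc = spark.sparkContext

# a small DataFrame with long, double, and binary columns
df = spark.createDataFrame(
    [(1, 3.14, bytearray(b"\x00\x01")), (2, 1.618, bytearray(b"\x02\x03"))],
    ["id", "value", "payload"])

# each Row is converted to a serialized tf.train.Example and written as TFRecords
dfutil.saveAsTFRecords(df, "hdfs:///tmp/dfutil_example")

# on load, BytesList features are ambiguous (string vs. raw bytes), so binary columns
# must be listed explicitly via the binary_features hint
df2 = dfutil.loadTFRecords(sc, "hdfs:///tmp/dfutil_example", binary_features=["payload"])
df2.show()

Loaded DataFrames are also recorded in dfutil.loadedDF, which is what isLoadedDF() consults when such a DataFrame is later passed to the Spark ML pipeline APIs.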