├── docs ├── .gitignore ├── img │ ├── java-sm.png │ ├── python-sm.png │ └── scala-sm.png ├── user-guide.md ├── css │ ├── api-docs.css │ ├── api-javadocs.css │ ├── main.css │ └── pygments-default.css ├── _plugins │ ├── production_tag.rb │ └── copy_api_dirs.rb ├── _config.yml ├── index.md ├── quick-start.md ├── js │ ├── api-docs.js │ ├── api-javadocs.js │ ├── main.js │ └── vendor │ │ └── anchor.min.js ├── prepare ├── README.md └── _layouts │ ├── 404.html │ └── global.html ├── python ├── .gitignore ├── setup.py ├── tests │ ├── resources │ │ ├── images │ │ │ ├── 00074201.jpg │ │ │ ├── 00081101.jpg │ │ │ ├── 00084301.png │ │ │ ├── 00093801.jpg │ │ │ └── 19207401.jpg │ │ └── images-source.txt │ ├── __init__.py │ ├── image │ │ ├── __init__.py │ │ └── test_imageIO.py │ ├── utils │ │ ├── __init__.py │ │ └── test_python_interface.py │ ├── graph │ │ ├── __init__.py │ │ ├── test_builder.py │ │ └── test_pieces.py │ ├── transformers │ │ ├── __init__.py │ │ ├── keras_image_test.py │ │ ├── image_utils.py │ │ ├── tf_image_test.py │ │ └── named_image_test.py │ └── tests.py ├── setup.cfg ├── spark-package-deps.txt ├── docs │ ├── _templates │ │ └── layout.html │ ├── sparkdl.rst │ ├── epytext.py │ ├── index.rst │ ├── static │ │ ├── pysparkdl.css │ │ └── pysparkdl.js │ ├── underscores.py │ └── Makefile ├── requirements.txt ├── MANIFEST.in ├── sparkdl │ ├── image │ │ ├── __init__.py │ │ └── imageIO.py │ ├── graph │ │ ├── __init__.py │ │ ├── pieces.py │ │ └── utils.py │ ├── transformers │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── keras_utils.py │ │ ├── param.py │ │ └── keras_image.py │ ├── utils │ │ ├── __init__.py │ │ └── jvmapi.py │ └── __init__.py └── run-tests.sh ├── project ├── build.properties └── plugins.sbt ├── Makefile ├── .gitignore ├── src ├── test │ ├── resources │ │ └── log4j.properties │ └── scala │ │ ├── org │ │ ├── tensorframes │ │ │ └── impl │ │ │ │ ├── GraphScoping.scala │ │ │ │ ├── SqlOpsSuite.scala │ │ │ │ └── TestUtils.scala │ │ └── apache │ │ │ └── spark │ │ │ └── sql │ │ │ └── sparkdl_stubs │ │ │ └── SparkDLStubsSuite.scala │ │ └── com │ │ └── databricks │ │ └── sparkdl │ │ └── TestSparkContext.scala └── main │ └── scala │ ├── com │ └── databricks │ │ └── sparkdl │ │ ├── Logging.scala │ │ └── python │ │ ├── PythonInterface.scala │ │ └── ModelFactory.scala │ └── org │ └── apache │ └── spark │ └── sql │ └── sparkdl_stubs │ └── UDFUtils.scala ├── NOTICE ├── bin └── download_travis_dependencies.sh ├── .travis.yml └── README.md /docs/.gitignore: -------------------------------------------------------------------------------- 1 | jekyll -------------------------------------------------------------------------------- /python/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | docs/_build/ 3 | build/ 4 | dist/ 5 | -------------------------------------------------------------------------------- /docs/img/java-sm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/docs/img/java-sm.png -------------------------------------------------------------------------------- /docs/img/python-sm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/docs/img/python-sm.png -------------------------------------------------------------------------------- /docs/img/scala-sm.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/docs/img/scala-sm.png -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | // This file should only contain the version of sbt to use. 2 | sbt.version=0.13.15 3 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | # Your python setup file. An example can be found at: 2 | # https://github.com/pypa/sampleproject/blob/master/setup.py 3 | -------------------------------------------------------------------------------- /python/tests/resources/images/00074201.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/00074201.jpg -------------------------------------------------------------------------------- /python/tests/resources/images/00081101.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/00081101.jpg -------------------------------------------------------------------------------- /python/tests/resources/images/00084301.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/00084301.png -------------------------------------------------------------------------------- /python/tests/resources/images/00093801.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/00093801.jpg -------------------------------------------------------------------------------- /python/tests/resources/images/19207401.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateiz/spark-deep-learning/HEAD/python/tests/resources/images/19207401.jpg -------------------------------------------------------------------------------- /python/tests/resources/images-source.txt: -------------------------------------------------------------------------------- 1 | Digital image courtesy of the Getty's Open Content Program. 2 | http://www.getty.edu/about/whatwedo/opencontent.html 3 | -------------------------------------------------------------------------------- /python/setup.cfg: -------------------------------------------------------------------------------- 1 | # This file contains the default option values to be used during setup. An 2 | # example can be found at https://github.com/pypa/sampleproject/blob/master/setup.cfg 3 | -------------------------------------------------------------------------------- /python/spark-package-deps.txt: -------------------------------------------------------------------------------- 1 | # This file should list any spark package dependencies as: 2 | # :package_name==:version e.g. 
databricks/spark-csv==0.1 3 | databricks/tensorframes==0.2.8 4 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: 2.1.0s2.10 2.0.2 2.1.0 2 | 3 | clean: 4 | rm -rf target/sparkdl_*.zip 5 | 6 | 2.0.2 2.1.0: 7 | build/sbt -Dspark.version=$@ spDist 8 | 9 | 2.1.0s2.10: 10 | build/sbt -Dspark.version=2.1.0 -Dscala.version=2.10.6 spDist assembly test 11 | -------------------------------------------------------------------------------- /python/docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% set script_files = script_files + ["_static/pysparkdl.js"] %} 3 | {% set css_files = css_files + ['_static/pysparkdl.css'] %} 4 | {% block rootrellink %} 5 | {{ super() }} 6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | # This file should list any python package dependencies. 2 | coverage==4.4.1 3 | h5py==2.7.0 4 | keras==2.0.4 5 | nose==1.3.7 # for testing 6 | numpy==1.11.2 7 | pillow==4.1.1 8 | pygments==2.2.0 9 | tensorflow==1.1.0 10 | six==1.10.0 11 | -------------------------------------------------------------------------------- /python/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # An example MANIFEST file can be found at: 2 | # https://github.com/pypa/sampleproject/blob/master/MANIFEST.in 3 | # For more details about the MANIFEST file, you may read the docs at 4 | # https://docs.python.org/2/distutils/sourcedist.html#the-manifest-in-template 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | *.pyc 4 | build/*.jar 5 | 6 | docs/_site 7 | docs/api 8 | 9 | # sbt specific 10 | .cache/ 11 | .history/ 12 | .lib/ 13 | dist/* 14 | target/ 15 | lib_managed/ 16 | src_managed/ 17 | project/boot/ 18 | project/plugins/project/ 19 | 20 | # intellij 21 | .idea/ 22 | 23 | # MacOS 24 | .DS_Store 25 | -------------------------------------------------------------------------------- /docs/user-guide.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: global 3 | displayTitle: Deep Learning Pipelines User Guide 4 | title: User Guide 5 | description: Deep Learning Pipelines SPARKDL_VERSION user guide 6 | --- 7 | 8 | This page gives examples of how to use Deep Learning Pipelines 9 | * Table of contents (This text will be scraped.) 
10 | {:toc} 11 | 12 | -------------------------------------------------------------------------------- /docs/css/api-docs.css: -------------------------------------------------------------------------------- 1 | /* Dynamically injected style for the API docs */ 2 | 3 | .developer { 4 | background-color: #44751E; 5 | } 6 | 7 | .experimental { 8 | background-color: #257080; 9 | } 10 | 11 | .alphaComponent { 12 | background-color: #bb0000; 13 | } 14 | 15 | .badge { 16 | font-family: Arial, san-serif; 17 | float: right; 18 | } 19 | -------------------------------------------------------------------------------- /docs/_plugins/production_tag.rb: -------------------------------------------------------------------------------- 1 | module Jekyll 2 | class ProductionTag < Liquid::Block 3 | 4 | def initialize(tag_name, markup, tokens) 5 | super 6 | end 7 | 8 | def render(context) 9 | if ENV['PRODUCTION'] then super else "" end 10 | end 11 | end 12 | end 13 | 14 | Liquid::Template.register_tag('production', Jekyll::ProductionTag) 15 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | // You may use this file to add plugin dependencies for sbt. 2 | resolvers += "Spark Packages repo" at "https://dl.bintray.com/spark-packages/maven/" 3 | 4 | addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.5") 5 | 6 | // scalacOptions in (Compile,doc) := Seq("-groups", "-implicits") 7 | 8 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0") 9 | -------------------------------------------------------------------------------- /src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootCategory=WARN, console 2 | log4j.appender.console=org.apache.log4j.ConsoleAppender 3 | log4j.appender.console.target=System.err 4 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 5 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 6 | 7 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR 8 | log4j.logger.org.apache.spark=WARN 9 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 10 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 11 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2017 Databricks, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | -------------------------------------------------------------------------------- /bin/download_travis_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Downloading Spark if necessary" 3 | echo "Spark version = $SPARK_VERSION" 4 | echo "Spark build = $SPARK_BUILD" 5 | echo "Spark build URL = $SPARK_BUILD_URL" 6 | mkdir -p $HOME/.cache/spark-versions 7 | filename="$HOME/.cache/spark-versions/$SPARK_BUILD.tgz" 8 | if ! [ -f $filename ]; then 9 | echo "Downloading file..." 10 | echo `which curl` 11 | curl "$SPARK_BUILD_URL" > $filename 12 | echo "Content of directory:" 13 | ls -la $HOME/.cache/spark-versions/* 14 | tar xvf $filename --directory $HOME/.cache/spark-versions > /dev/null 15 | fi 16 | -------------------------------------------------------------------------------- /python/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/sparkdl/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/tests/image/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/tests/graph/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /python/tests/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/sparkdl/graph/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /python/sparkdl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | highlighter: pygments 2 | markdown: kramdown 3 | gems: 4 | - jekyll-redirect-from 5 | 6 | # For some reason kramdown seems to behave differently on different 7 | # OS/packages wrt encoding. So we hard code this config. 8 | kramdown: 9 | entity_output: numeric 10 | 11 | include: 12 | - _static 13 | - _modules 14 | 15 | # These allow the documentation to be updated with newer releases 16 | # of Spark, Scala, and Mesos. 17 | SPARKDL_VERSION: 0.1.0 18 | #SCALA_BINARY_VERSION: "2.10" 19 | #SCALA_VERSION: "2.10.4" 20 | #MESOS_VERSION: 0.21.0 21 | #SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK 22 | #SPARK_GITHUB_URL: https://github.com/apache/spark 23 | -------------------------------------------------------------------------------- /python/docs/sparkdl.rst: -------------------------------------------------------------------------------- 1 | sparkdl package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | sparkdl.image 10 | sparkdl.transformers 11 | 12 | Submodules 13 | ---------- 14 | 15 | sparkdl\.sparkdl module 16 | -------------------------- 17 | 18 | .. automodule:: sparkdl.sparkdl 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | sparkdl\.utils module 24 | --------------------- 25 | 26 | .. automodule:: sparkdl.utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. 
automodule:: sparkdl 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /python/docs/epytext.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | RULES = ( 4 | (r"<(!BLANKLINE)[\w.]+>", r""), 5 | (r"L{([\w.()]+)}", r":class:`\1`"), 6 | (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), 7 | (r"C{([\w.()]+)}", r":class:`\1`"), 8 | (r"[IBCM]{([^}]+)}", r"`\1`"), 9 | ('pyspark.rdd.RDD', 'RDD'), 10 | ) 11 | 12 | def _convert_epytext(line): 13 | """ 14 | >>> _convert_epytext("L{A}") 15 | :class:`A` 16 | """ 17 | line = line.replace('@', ':') 18 | for p, sub in RULES: 19 | line = re.sub(p, sub, line) 20 | return line 21 | 22 | def _process_docstring(app, what, name, obj, options, lines): 23 | for i in range(len(lines)): 24 | lines[i] = _convert_epytext(lines[i]) 25 | 26 | def setup(app): 27 | app.connect("autodoc-process-docstring", _process_docstring) 28 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | 3 | dist: trusty 4 | 5 | language: python 6 | python: 7 | - "2.7" 8 | - "3.6" 9 | - "3.5" 10 | 11 | cache: 12 | directories: 13 | - $HOME/.ivy2 14 | - $HOME/.sbt/launchers/ 15 | - $HOME/.cache/spark-versions 16 | 17 | env: 18 | matrix: 19 | - SCALA_VERSION=2.11.8 SPARK_VERSION=2.1.1 SPARK_BUILD="spark-2.1.1-bin-hadoop2.7" SPARK_BUILD_URL="http://d3kbcqa49mib13.cloudfront.net/spark-2.1.1-bin-hadoop2.7.tgz" 20 | 21 | before_install: 22 | - ./bin/download_travis_dependencies.sh 23 | 24 | install: 25 | - pip install -r ./python/requirements.txt 26 | 27 | script: 28 | - ./build/sbt -Dspark.version=$SPARK_VERSION -Dscala.version=$SCALA_VERSION "set test in assembly := {}" assembly 29 | - ./build/sbt -Dspark.version=$SPARK_VERSION -Dscala.version=$SCALA_VERSION coverage test coverageReport 30 | - SPARK_HOME=$HOME/.cache/spark-versions/$SPARK_BUILD ./python/run-tests.sh 31 | 32 | after_success: 33 | - bash <(curl -s https://codecov.io/bash) -------------------------------------------------------------------------------- /src/main/scala/com/databricks/sparkdl/Logging.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.databricks.sparkdl 17 | 18 | import com.typesafe.scalalogging.slf4j.{LazyLogging, StrictLogging} 19 | 20 | private[sparkdl] trait Logging extends LazyLogging { 21 | def logDebug(s: String) = logger.debug(s) 22 | def logInfo(s: String) = logger.info(s) 23 | def logTrace(s: String) = logger.trace(s) 24 | } 25 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: global 3 | displayTitle: Deep Learning Pipelines Overview 4 | title: Overview 5 | description: Deep Learning Pipelines SPARKDL_VERSION documentation homepage 6 | --- 7 | 8 | 9 | # Downloading 10 | 11 | # Where to Go from Here 12 | 13 | **User Guides:** 14 | 15 | * [Quick Start](quick-start.html): a quick introduction to the Deep Learning Pipelines API; start here! 16 | * [Deep Learning Pipelines User Guide](user-guide.html): detailed overview of Deep Learning Pipelines 17 | in all supported languages (Scala, Python) 18 | 19 | **API Docs:** 20 | 21 | * [Deep Learning Pipelines Scala API (Scaladoc)](api/scala/index.html#com.databricks.sparkdl.package) 22 | * [Deep Learning Pipelines Python API (Sphinx)](api/python/index.html) 23 | 24 | **External Resources:** 25 | 26 | * [Apache Spark Homepage](http://spark.apache.org) 27 | * [Apache Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) 28 | * [Mailing Lists](http://spark.apache.org/mailing-lists.html): Ask questions about Spark here 29 | -------------------------------------------------------------------------------- /python/docs/index.rst: -------------------------------------------------------------------------------- 1 | .. pysparkdl documentation master file, created by 2 | sphinx-quickstart on Thu Feb 18 16:43:49 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to the Deep Learning Pipelines Python API docs! 7 | ==================================================================================== 8 | 9 | *Note that most of the Python API docs are currently stubs. The APIs are designed to match 10 | the Scala APIs as closely as reasonable, so please refer to the Scala API docs for more details 11 | on both the algorithms and APIs (particularly DataFrame schema).* 12 | 13 | Contents: 14 | 15 | .. toctree:: 16 | :maxdepth: 2 17 | 18 | sparkdl 19 | 20 | Core classes: 21 | ------------- 22 | 23 | :class:`sparkdl.OurCoolClass` 24 | 25 | Description of OurCoolClass 26 | 27 | 28 | Indices and tables 29 | ==================================================================================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /python/sparkdl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from .image.imageIO import imageSchema, imageType, readImages 17 | from .transformers.keras_image import KerasImageFileTransformer 18 | from .transformers.named_image import DeepImagePredictor, DeepImageFeaturizer 19 | from .transformers.tf_image import TFImageTransformer 20 | from .transformers.utils import imageInputPlaceholder 21 | 22 | __all__ = [ 23 | 'imageSchema', 'imageType', 'readImages', 24 | 'TFImageTransformer', 25 | 'DeepImagePredictor', 'DeepImageFeaturizer', 26 | 'KerasImageFileTransformer', 27 | 'imageInputPlaceholder'] 28 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | import tensorflow as tf 17 | 18 | from pyspark.ml.param import TypeConverters 19 | 20 | from sparkdl.image.imageIO import imageType 21 | 22 | # image stuff 23 | 24 | IMAGE_INPUT_PLACEHOLDER_NAME = "sparkdl_image_input" 25 | 26 | def imageInputPlaceholder(nChannels=None): 27 | return tf.placeholder(tf.float32, [None, None, None, nChannels], 28 | name=IMAGE_INPUT_PLACEHOLDER_NAME) 29 | 30 | class ImageNetConstants: 31 | NUM_CLASSES = 1000 32 | 33 | # probably use a separate module for each network once we have featurizers. 34 | class InceptionV3Constants: 35 | INPUT_SHAPE = (299, 299) 36 | NUM_OUTPUT_FEATURES = 131072 37 | -------------------------------------------------------------------------------- /python/tests/utils/test_python_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | import sys, traceback 16 | 17 | from pyspark import SparkContext, SQLContext 18 | from pyspark.sql.column import Column 19 | from sparkdl.utils import jvmapi as JVMAPI 20 | from ..tests import SparkDLTestCase 21 | 22 | class PythonAPITest(SparkDLTestCase): 23 | 24 | def test_using_api(self): 25 | """ Must be able to load the API """ 26 | try: 27 | print(JVMAPI.default()) 28 | except: 29 | traceback.print_exc(file=sys.stdout) 30 | self.fail("failed to load certain classes") 31 | 32 | kls_name = str(JVMAPI.forClass(javaClassName=JVMAPI.PYTHON_INTERFACE_CLASSNAME)) 33 | self.assertEqual(kls_name.split('@')[0], JVMAPI.PYTHON_INTERFACE_CLASSNAME) 34 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/keras_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | import keras.backend as K 17 | import tensorflow as tf 18 | 19 | 20 | class KSessionWrap(): 21 | """ 22 | Runs operations in Keras in an isolated manner: the current graph and the current session 23 | are not modified by anything done in this block: 24 | 25 | with KSessionWrap() as (current_session, current_graph): 26 | ... do some things that call Keras 27 | """ 28 | 29 | def __init__(self, graph = None): 30 | self.requested_graph = graph 31 | 32 | def __enter__(self): 33 | self.old_session = K.get_session() 34 | self.g = self.requested_graph or tf.Graph() 35 | self.current_session = tf.Session(graph = self.g) 36 | K.set_session(self.current_session) 37 | return (self.current_session, self.g) 38 | 39 | def __exit__(self, exc_type, exc_val, exc_tb): 40 | # Restore the previous session 41 | K.set_session(self.old_session) 42 | -------------------------------------------------------------------------------- /python/tests/tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | 16 | import sys 17 | 18 | if sys.version_info[:2] <= (2, 6): 19 | try: 20 | import unittest2 as unittest 21 | except ImportError: 22 | sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') 23 | sys.exit(1) 24 | else: 25 | import unittest 26 | 27 | from pyspark import SparkContext 28 | from pyspark.sql import SQLContext 29 | from pyspark.sql import SparkSession 30 | 31 | 32 | class SparkDLTestCase(unittest.TestCase): 33 | 34 | @classmethod 35 | def setUpClass(cls): 36 | cls.sc = SparkContext('local[*]', cls.__name__) 37 | cls.sql = SQLContext(cls.sc) 38 | cls.session = SparkSession.builder.getOrCreate() 39 | 40 | @classmethod 41 | def tearDownClass(cls): 42 | cls.session.stop() 43 | cls.session = None 44 | cls.sc.stop() 45 | cls.sc = None 46 | cls.sql = None 47 | 48 | def assertDfHasCols(self, df, cols = []): 49 | map(lambda c: self.assertIn(c, df.columns), cols) 50 | -------------------------------------------------------------------------------- /docs/css/api-javadocs.css: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /* Dynamically injected style for the API docs */ 19 | 20 | .badge { 21 | font-family: Arial, san-serif; 22 | float: right; 23 | margin: 4px; 24 | /* The following declarations are taken from the ScalaDoc template.css */ 25 | display: inline-block; 26 | padding: 2px 4px; 27 | font-size: 11.844px; 28 | font-weight: bold; 29 | line-height: 14px; 30 | color: #ffffff; 31 | text-shadow: 0 -1px 0 rgba(0, 0, 0, 0.25); 32 | white-space: nowrap; 33 | vertical-align: baseline; 34 | background-color: #999999; 35 | padding-right: 9px; 36 | padding-left: 9px; 37 | -webkit-border-radius: 9px; 38 | -moz-border-radius: 9px; 39 | border-radius: 9px; 40 | } 41 | 42 | .developer { 43 | background-color: #44751E; 44 | } 45 | 46 | .experimental { 47 | background-color: #257080; 48 | } 49 | 50 | .alphaComponent { 51 | background-color: #bb0000; 52 | } 53 | -------------------------------------------------------------------------------- /docs/quick-start.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: global 3 | displayTitle: Deep Learning Pipelines Quick-Start Guide 4 | title: Quick-Start Guide 5 | description: Deep Learning Pipelines SPARKDL_VERSION guide for getting started quickly 6 | --- 7 | 8 | This quick-start guide shows how to get started using Deep Learning Pipelines. 9 | After you work through this guide, move on to the [User Guide](user-guide.html) 10 | to learn more about the many queries and algorithms supported by Deep Learning Pipelines. 
11 | 12 | * Table of contents 13 | {:toc} 14 | 15 | # Getting started with Apache Spark and Spark packages 16 | 17 | If you are new to using Apache Spark, refer to the 18 | [Apache Spark Documentation](http://spark.apache.org/docs/latest/index.html) and its 19 | [Quick-Start Guide](http://spark.apache.org/docs/latest/quick-start.html) for more information. 20 | 21 | If you are new to using [Spark packages](http://spark-packages.org), you can find more information 22 | in the [Spark User Guide on using the interactive shell](http://spark.apache.org/docs/latest/programming-guide.html#using-the-shell). 23 | You just need to make sure your Spark shell session has the package as a dependency. 24 | 25 | The following examples show how to run the Spark shell with the Deep Learning Pipelines package. 26 | The `--packages` argument tells Spark to download the package and its dependencies automatically. 27 | 28 |
29 | 30 |
31 | 32 | {% highlight bash %} 33 | $ ./bin/spark-shell --packages spark-deep-learning 34 | {% endhighlight %} 35 | 36 |
37 | 38 |
39 | 40 | {% highlight bash %} 41 | $ ./bin/pyspark --packages spark-deep-learning 42 | {% endhighlight %} 43 | 44 |
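Note that `--packages` normally expects a full `groupId:artifactId:version` coordinate; the exact coordinate for a given release is listed on the package's [Spark Packages](http://spark-packages.org) page. Once the shell is running, a minimal end-to-end run looks like the sketch below. It uses only names exported by the `sparkdl` Python package (`readImages`, `DeepImagePredictor`); the constructor parameters and column names shown are illustrative assumptions rather than the exact API, so consult the [User Guide](user-guide.html) for details.

{% highlight python %}
from sparkdl import readImages, DeepImagePredictor

# Load a directory of images into a DataFrame, one row per image.
image_df = readImages("/path/to/images")

# Score each image with a pre-trained ImageNet model.
# The parameter and column names below are assumptions for illustration only.
predictor = DeepImagePredictor(inputCol="image", outputCol="predicted_labels",
                               modelName="InceptionV3")
predictions = predictor.transform(image_df)
predictions.show()
{% endhighlight %}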
45 | 46 | -------------------------------------------------------------------------------- /src/test/scala/org/tensorframes/impl/GraphScoping.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.tensorframes.impl 17 | 18 | import org.scalatest.FunSuite 19 | 20 | import org.tensorflow.{Graph => TFGraph, Session => TFSession} 21 | import org.tensorframes.{dsl => tf} 22 | 23 | trait GraphScoping { self: FunSuite => 24 | import tf.withGraph 25 | 26 | def testGraph(banner: String)(block: => Unit): Unit = { 27 | test(s"[tfrm:sql-udf-impl] $banner") { withGraph { block } } 28 | } 29 | 30 | // Provides both a TensoFlow Graph and Session 31 | def testIsolatedSession(banner: String)(block: (TFGraph, TFSession) => Unit): Unit = { 32 | test(s"[tf:iso-sess] $banner") { 33 | val g = new TFGraph() 34 | val sess = new TFSession(g) 35 | block(g, sess) 36 | } 37 | } 38 | 39 | // Following TensorFlow's Java API example 40 | // Reference: tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java 41 | def testGraphBuilder(banner: String)(block: GraphBuilder => Unit): Unit = { 42 | test(s"[tf:iso-sess] $banner") { 43 | val builder = new GraphBuilder(new TFGraph()) 44 | block(builder) 45 | builder.close() 46 | } 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /docs/js/api-docs.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | /* Dynamically injected post-processing code for the API docs */ 19 | 20 | $(document).ready(function() { 21 | var annotations = $("dt:contains('Annotations')").next("dd").children("span.name"); 22 | addBadges(annotations, "AlphaComponent", ":: AlphaComponent ::", 'Alpha Component'); 23 | addBadges(annotations, "DeveloperApi", ":: DeveloperApi ::", 'Developer API'); 24 | addBadges(annotations, "Experimental", ":: Experimental ::", 'Experimental'); 25 | }); 26 | 27 | function addBadges(allAnnotations, name, tag, html) { 28 | var annotations = allAnnotations.filter(":contains('" + name + "')") 29 | var tags = $(".cmt:contains(" + tag + ")") 30 | 31 | // Remove identifier tags from comments 32 | tags.each(function(index) { 33 | var oldHTML = $(this).html(); 34 | var newHTML = oldHTML.replace(tag, ""); 35 | $(this).html(newHTML); 36 | }); 37 | 38 | // Add badges to all containers 39 | tags.prevAll("h4.signature") 40 | .add(annotations.closest("div.fullcommenttop")) 41 | .add(annotations.closest("div.fullcomment").prevAll("h4.signature")) 42 | .prepend(html); 43 | } 44 | -------------------------------------------------------------------------------- /src/test/scala/com/databricks/sparkdl/TestSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package com.databricks.sparkdl 17 | 18 | import scala.reflect.runtime.universe._ 19 | 20 | import org.scalatest.{FunSuite, BeforeAndAfterAll} 21 | 22 | import org.apache.spark.{SparkConf, SparkContext} 23 | import org.apache.spark.sql.{Row, DataFrame, SQLContext, SparkSession} 24 | 25 | // This context is used for all tests in this project 26 | trait TestSparkContext extends BeforeAndAfterAll { self: FunSuite => 27 | @transient var sc: SparkContext = _ 28 | @transient var sqlContext: SQLContext = _ 29 | @transient lazy val spark: SparkSession = { 30 | val conf = new SparkConf() 31 | .setMaster("local[*]") 32 | .setAppName("Spark-Deep-Learning-Test") 33 | .set("spark.ui.port", "4079") 34 | .set("spark.sql.shuffle.partitions", "4") // makes small tests much faster 35 | 36 | SparkSession.builder().config(conf).getOrCreate() 37 | } 38 | 39 | override def beforeAll() { 40 | super.beforeAll() 41 | sc = spark.sparkContext 42 | sqlContext = spark.sqlContext 43 | import spark.implicits._ 44 | } 45 | 46 | override def afterAll() { 47 | sqlContext = null 48 | if (sc != null) { 49 | sc.stop() 50 | } 51 | sc = null 52 | super.afterAll() 53 | } 54 | 55 | def makeDF[T: TypeTag](xs: Seq[T], col: String): DataFrame = { 56 | sqlContext.createDataFrame(xs.map(Tuple1.apply)).toDF(col) 57 | } 58 | 59 | def compareRows(r1: Array[Row], r2: Seq[Row]): Unit = { 60 | val a = r1.sortBy(_.toString()) 61 | val b = r2.sortBy(_.toString()) 62 | assert(a === b) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /docs/prepare: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | _bsd_="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 4 | 5 | function help { 6 | cat << _HELP_EOF_ 7 | [[ HELP ]] 8 | $(basename ${BASH_SOURCE[0]}) -s -t 9 | options 10 | -s 11 | Spark distribution directory 12 | http://spark.apache.org/downloads.html 13 | 14 | -t 15 | Tensorframes home directory 16 | https://github.com/databricks/tensorframes.git 17 | _HELP_EOF_ 18 | } 19 | 20 | function quit_with { >&2 echo "ERROR: $@"; help; exit 1; } 21 | 22 | while getopts "s:t:" CMD_OPT; do 23 | case "${CMD_OPT}" in 24 | s) spark_home="${OPTARG}" ;; 25 | t) tensorframes_home="${OPTARG}" ;; 26 | \?) help; exit 1 ;; 27 | esac 28 | done 29 | 30 | [[ -n "${spark_home}" ]] || \ 31 | quit_with "must provide Spark home" 32 | [[ -n "${tensorframes_home}" ]] || \ 33 | quit_with "must provide Tensorframes home" 34 | 35 | set -ex 36 | _tfrm_py_pkg="$(find "${tensorframes_home}/src" -type d -name 'python' | head -n1)" 37 | find "${tensorframes_home}" \ 38 | -type f -name "requirements.txt" \ 39 | -exec pip install --user -r {} \; 40 | 41 | [[ -d "${_tfrm_py_pkg}" ]] || \ 42 | quit_with "cannot find spark package: tensorframes" 43 | [[ -f "${_tfrm_py_pkg}/tensorframes/__init__.py" ]] || \ 44 | quit_with "tensorframes directory does not point to a python package" 45 | 46 | (cd "${_bsd_}/../python" 47 | echo "Creating symlink to tensorframes" 48 | rm -f tensorframes 49 | ln -s "${_tfrm_py_pkg}/tensorframes" . 
50 | 51 | [[ -f requirements.txt ]] || \ 52 | quit_with "cannot find python requirements file" 53 | 54 | pip install --user -r requirements.txt 55 | ) 56 | 57 | 58 | # Build the wrapper script for jekyll 59 | touch "${_bsd_}/jekyll" 60 | cat << _JEKYLL_EOF_ | tee "${_bsd_}/jekyll" 61 | #!/bin/bash 62 | 63 | export SPARK_HOME=${spark_home} 64 | export PYTHONPATH=$(find ${HOME}/.local -type d -name 'site-packages' | head -n1):${_bsd_}/../python:${spark_home}/python:$(find "${spark_home}/python/lib" -name 'py4j-*-src.zip') 65 | 66 | (cd ${_bsd_}/../python && sphinx-apidoc -f -o docs sparkdl) 67 | 68 | pushd "${_bsd_}" 69 | jekyll \$@ 70 | popd 71 | _JEKYLL_EOF_ 72 | 73 | chmod +x "${_bsd_}/jekyll" 74 | -------------------------------------------------------------------------------- /python/docs/static/pysparkdl.css: -------------------------------------------------------------------------------- 1 | /* 2 | Licensed to the Apache Software Foundation (ASF) under one or more 3 | contributor license agreements. See the NOTICE file distributed with 4 | this work for additional information regarding copyright ownership. 5 | The ASF licenses this file to You under the Apache License, Version 2.0 6 | (the "License"); you may not use this file except in compliance with 7 | the License. You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | */ 17 | 18 | body { 19 | background-color: #ffffff; 20 | } 21 | 22 | div.sphinxsidebar { 23 | width: 274px; 24 | } 25 | 26 | div.bodywrapper { 27 | margin: 0 0 0 274px; 28 | } 29 | 30 | div.sphinxsidebar ul { 31 | margin-right: 10px; 32 | } 33 | 34 | div.sphinxsidebar li a { 35 | word-break: break-all; 36 | } 37 | 38 | span.pys-tag { 39 | font-size: 11px; 40 | font-weight: bold; 41 | margin: 0 0 0 2px; 42 | padding: 1px 3px 1px 3px; 43 | -moz-border-radius: 3px; 44 | -webkit-border-radius: 3px; 45 | border-radius: 3px; 46 | text-align: center; 47 | text-decoration: none; 48 | } 49 | 50 | span.pys-tag-experimental { 51 | background-color: rgb(37, 112, 128); 52 | color: rgb(255, 255, 255); 53 | } 54 | 55 | span.pys-tag-deprecated { 56 | background-color: rgb(238, 238, 238); 57 | color: rgb(62, 67, 73); 58 | } 59 | 60 | div.pys-note-experimental { 61 | background-color: rgb(88, 151, 165); 62 | border-color: rgb(59, 115, 127); 63 | color: rgb(255, 255, 255); 64 | } 65 | 66 | div.pys-note-deprecated { 67 | } 68 | 69 | .hasTooltip { 70 | position:relative; 71 | } 72 | .hasTooltip span { 73 | display:none; 74 | } 75 | 76 | .hasTooltip:hover span.tooltip { 77 | display: inline-block; 78 | -moz-border-radius: 2px; 79 | -webkit-border-radius: 2px; 80 | border-radius: 2px; 81 | background-color: rgb(250, 250, 250); 82 | color: rgb(68, 68, 68); 83 | font-weight: normal; 84 | box-shadow: 1px 1px 3px rgb(127, 127, 127); 85 | position: absolute; 86 | padding: 0 3px 0 3px; 87 | top: 1.3em; 88 | left: 14px; 89 | z-index: 9999 90 | } 91 | -------------------------------------------------------------------------------- /python/sparkdl/utils/jvmapi.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | import logging 17 | 18 | from pyspark import SparkContext, SQLContext 19 | from pyspark.sql.column import Column 20 | 21 | # pylint: disable=W0212 22 | PYTHON_INTERFACE_CLASSNAME = "com.databricks.sparkdl.python.PythonInterface" 23 | MODEL_FACTORY_CLASSNAME = "com.databricks.sparkdl.python.GraphModelFactory" 24 | 25 | logger = logging.getLogger('sparkdl') 26 | 27 | def _curr_sql_ctx(sqlCtx=None): 28 | _sql_ctx = sqlCtx if sqlCtx is not None else SQLContext._instantiatedContext 29 | logger.info("Spark SQL Context = " + str(_sql_ctx)) 30 | return _sql_ctx 31 | 32 | def _curr_sc(): 33 | return SparkContext._active_spark_context 34 | 35 | def _curr_jvm(): 36 | return _curr_sc()._jvm 37 | 38 | def forClass(javaClassName, sqlCtx=None): 39 | """ 40 | Loads the JVM API object (lazily, because the spark context needs to be initialized 41 | first). 42 | """ 43 | # (tjh) suspect the SQL context is doing crazy things at import, because I was 44 | # experiencing some issues here. 45 | # You cannot simply call the creation of the the class on the _jvm 46 | # due to classloader issues with Py4J. 47 | jvm_thread = _curr_jvm().Thread.currentThread() 48 | jvm_class = jvm_thread.getContextClassLoader().loadClass(javaClassName) 49 | return jvm_class.newInstance().sqlContext(_curr_sql_ctx(sqlCtx)._ssql_ctx) 50 | 51 | def pyUtils(): 52 | """ 53 | Exposing Spark PythonUtils 54 | spark/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala 55 | """ 56 | return _curr_jvm().PythonUtils 57 | 58 | def default(): 59 | """ Default JVM Python Interface class """ 60 | return forClass(javaClassName=PYTHON_INTERFACE_CLASSNAME) 61 | 62 | def list_to_vector_udf(col): 63 | """ Map struct column from list to MLlib vector """ 64 | return Column(default().listToMLlibVectorUDF(col._jc)) # pylint: disable=W0212 65 | -------------------------------------------------------------------------------- /docs/js/api-javadocs.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | /* Dynamically injected post-processing code for the API docs */ 19 | 20 | $(document).ready(function() { 21 | addBadges(":: AlphaComponent ::", 'Alpha Component'); 22 | addBadges(":: DeveloperApi ::", 'Developer API'); 23 | addBadges(":: Experimental ::", 'Experimental'); 24 | }); 25 | 26 | function addBadges(tag, html) { 27 | var tags = $(".block:contains(" + tag + ")") 28 | 29 | // Remove identifier tags 30 | tags.each(function(index) { 31 | var oldHTML = $(this).html(); 32 | var newHTML = oldHTML.replace(tag, ""); 33 | $(this).html(newHTML); 34 | }); 35 | 36 | // Add html badge tags 37 | tags.each(function(index) { 38 | if ($(this).parent().is('td.colLast')) { 39 | $(this).parent().prepend(html); 40 | } else if ($(this).parent('li.blockList') 41 | .parent('ul.blockList') 42 | .parent('div.description') 43 | .parent().is('div.contentContainer')) { 44 | var contentContainer = $(this).parent('li.blockList') 45 | .parent('ul.blockList') 46 | .parent('div.description') 47 | .parent('div.contentContainer') 48 | var header = contentContainer.prev('div.header'); 49 | if (header.length > 0) { 50 | header.prepend(html); 51 | } else { 52 | contentContainer.prepend(html); 53 | } 54 | } else if ($(this).parent().is('li.blockList')) { 55 | $(this).parent().prepend(html); 56 | } else { 57 | $(this).prepend(html); 58 | } 59 | }); 60 | } 61 | -------------------------------------------------------------------------------- /python/docs/underscores.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/dinoboff/github-tools/blob/master/src/github/tools/sphinx.py 2 | # 3 | # Copyright (c) 2009, Damien Lebrun 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without modification, 7 | # are permitted provided that the following conditions are met: 8 | # 9 | # * Redistributions of source code must retain the above copyright notice, 10 | # this list of conditions and the following disclaimer. 11 | # 12 | # * Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 18 | # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 23 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | """ 28 | :Description: Sphinx extension to remove leading under-scores from directories names in the html build output directory. 
29 | """ 30 | import os 31 | import shutil 32 | 33 | 34 | def setup(app): 35 | """ 36 | Add a html-page-context and a build-finished event handlers 37 | """ 38 | app.connect('html-page-context', change_pathto) 39 | app.connect('build-finished', move_private_folders) 40 | 41 | def change_pathto(app, pagename, templatename, context, doctree): 42 | """ 43 | Replace pathto helper to change paths to folders with a leading underscore. 44 | """ 45 | pathto = context.get('pathto') 46 | def gh_pathto(otheruri, *args, **kw): 47 | if otheruri.startswith('_'): 48 | otheruri = otheruri[1:] 49 | return pathto(otheruri, *args, **kw) 50 | context['pathto'] = gh_pathto 51 | 52 | def move_private_folders(app, e): 53 | """ 54 | remove leading underscore from folders in in the output folder. 55 | 56 | :todo: should only affect html built 57 | """ 58 | def join(dir): 59 | return os.path.join(app.builder.outdir, dir) 60 | 61 | for item in os.listdir(app.builder.outdir): 62 | if item.startswith('_') and os.path.isdir(join(item)): 63 | shutil.move(join(item), join(item[1:])) 64 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/sql/sparkdl_stubs/SparkDLStubsSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | package org.apache.spark.sql.sparkdl_stubs 17 | 18 | import org.scalatest.FunSuite 19 | 20 | import org.apache.spark.sql.Row 21 | import org.apache.spark.sql.functions._ 22 | 23 | import com.databricks.sparkdl.TestSparkContext 24 | 25 | /** 26 | * Testing UDF registration 27 | */ 28 | class SparkDLStubSuite extends FunSuite with TestSparkContext { 29 | 30 | test("Registered UDF must be found") { 31 | val udfName = "sparkdl-test-udf" 32 | val udfImpl = { (x: Int, y: Int) => x + y } 33 | UDFUtils.registerUDF(spark.sqlContext, udfName, udf(udfImpl)) 34 | assert(spark.catalog.functionExists(udfName)) 35 | } 36 | 37 | test("Registered piped UDF must be found") { 38 | val udfName = "sparkdl_test_piped_udf" 39 | 40 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_0", 41 | udf({ (x: Int, y: Int) => x + y})) 42 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_1", 43 | udf({ (z: Int) => z * 2})) 44 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_2", 45 | udf({ (w: Int) => w * w + 3})) 46 | 47 | UDFUtils.registerPipeline(spark.sqlContext, udfName, 48 | (0 to 2).map { idx => s"${udfName}_$idx" }) 49 | 50 | assert(spark.catalog.functionExists(udfName)) 51 | } 52 | 53 | test("Using piped UDF in SQL") { 54 | val udfName = "sparkdl_test_piped_udf" 55 | 56 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_add", 57 | udf({ (x: Int, y: Int) => x + y})) 58 | UDFUtils.registerUDF(spark.sqlContext, s"${udfName}_mul", 59 | udf({ (z: Int) => z * 2})) 60 | 61 | UDFUtils.registerPipeline(spark.sqlContext, udfName, Seq(s"${udfName}_add", s"${udfName}_mul")) 62 | 63 | import spark.implicits._ 64 | val df = Seq(1 -> 1, 2 -> 2).toDF("x", "y") 65 | df.createOrReplaceTempView("piped_udf_input_df") 66 | df.printSchema() 67 | 68 | val sqlQuery = s"select x, y, $udfName(x, y) as res from piped_udf_input_df" 69 | println(sqlQuery) 70 | val dfRes = spark.sql(sqlQuery) 71 | dfRes.printSchema() 72 | dfRes.collect().map { case Row(x: Int, y: Int, res: Int) => 73 | assert((x + y) * 2 === res) 74 | } 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /docs/_plugins/copy_api_dirs.rb: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | require 'fileutils' 19 | include FileUtils 20 | 21 | if not (ENV['SKIP_API'] == '1') 22 | if not (ENV['SKIP_SCALADOC'] == '1') 23 | # Build Scaladoc for Java/Scala 24 | 25 | puts "Moving to project root and building API docs." 26 | curr_dir = pwd 27 | cd("..") 28 | 29 | puts "Running 'build/sbt clean compile doc' from " + pwd + "; this may take a few minutes..." 
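# (The unified ScalaDoc produced by this build ends up under ../target/scala-2.11/api; it is copied into api/scala further below.)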
30 | system("build/sbt clean compile doc") || raise("Doc generation failed") 31 | 32 | puts "Moving back into docs dir." 33 | cd("docs") 34 | 35 | puts "Removing old docs" 36 | puts `rm -rf api` 37 | 38 | # Copy over the unified ScalaDoc for all projects to api/scala. 39 | # This directory will be copied over to _site when `jekyll` command is run. 40 | source = "../target/scala-2.11/api" 41 | dest = "api/scala" 42 | 43 | puts "Making directory " + dest 44 | mkdir_p dest 45 | 46 | # From the rubydoc: cp_r('src', 'dest') makes src/dest, but this doesn't. 47 | puts "cp -r " + source + "/. " + dest 48 | cp_r(source + "/.", dest) 49 | 50 | # Append custom JavaScript 51 | js = File.readlines("./js/api-docs.js") 52 | js_file = dest + "/lib/template.js" 53 | File.open(js_file, 'a') { |f| f.write("\n" + js.join()) } 54 | 55 | # Append custom CSS 56 | css = File.readlines("./css/api-docs.css") 57 | css_file = dest + "/lib/template.css" 58 | File.open(css_file, 'a') { |f| f.write("\n" + css.join()) } 59 | end 60 | 61 | if not (ENV['SKIP_PYTHONDOC'] == '1') 62 | # Build Sphinx docs for Python 63 | 64 | # Get and set release version 65 | version = File.foreach('_config.yml').grep(/^SPARKDL_VERSION: (.+)$/){$1}.first 66 | version ||= 'Unknown' 67 | 68 | puts "Moving to python/docs directory and building sphinx." 69 | cd("../python/docs") 70 | if not (ENV['SPARK_HOME']) 71 | raise("Python API docs cannot be generated if SPARK_HOME is not set.") 72 | end 73 | system({"PACKAGE_VERSION"=>version}, "make clean") || raise("Python doc clean failed") 74 | system({"PACKAGE_VERSION"=>version}, "make html") || raise("Python doc generation failed") 75 | 76 | puts "Moving back into home dir." 77 | cd("../../") 78 | 79 | puts "Making directory api/python" 80 | mkdir_p "docs/api/python" 81 | 82 | puts "cp -r python/docs/_build/html/. docs/api/python" 83 | cp_r("python/docs/_build/html/.", "docs/api/python") 84 | end 85 | end 86 | -------------------------------------------------------------------------------- /python/sparkdl/graph/pieces.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | import logging 17 | 18 | import tensorflow as tf 19 | 20 | from sparkdl.graph.builder import IsolatedSession 21 | from sparkdl.image.imageIO import SparkMode 22 | 23 | logger = logging.getLogger('sparkdl') 24 | 25 | """ 26 | Build various pieces of the function 27 | 28 | TODO: We might want to cache some of the big models in their GraphFunction format 29 | Deserializing ProtocolBuffer bytes is in general faster than directly loading Keras models. 30 | """ 31 | 32 | def buildSpImageConverter(img_dtype): 33 | """ 34 | Convert a imageIO byte encoded image into a image tensor suitable as input to ConvNets 35 | The name of the input must be a subset of those specified in `image.imageIO.imageSchema`. 
36 | 37 | :param img_dtype: the type of data the underlying image bytes represent 38 | """ 39 | with IsolatedSession() as issn: 40 | # Flat image data -> image dimensions 41 | # This has to conform to `imageIO.imageSchema` 42 | height = tf.placeholder(tf.int32, [], name="height") 43 | width = tf.placeholder(tf.int32, [], name="width") 44 | num_channels = tf.placeholder(tf.int32, [], name="nChannels") 45 | image_buffer = tf.placeholder(tf.string, [], name="data") 46 | 47 | # The image is packed into bytes with height as leading dimension 48 | # This is the default behavior of Python Image Library 49 | shape = tf.reshape(tf.stack([height, width, num_channels], axis=0), 50 | shape=(3,), name='shape') 51 | if img_dtype == SparkMode.RGB: 52 | image_uint8 = tf.decode_raw(image_buffer, tf.uint8, name="decode_raw") 53 | image_float = tf.to_float(image_uint8) 54 | else: 55 | assert img_dtype == SparkMode.RGB_FLOAT32, \ 56 | "Unsupported dtype for image: {}".format(img_dtype) 57 | image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw") 58 | 59 | image_reshaped = tf.reshape(image_float, shape, name="reshaped") 60 | image_input = tf.expand_dims(image_reshaped, 0, name="image_input") 61 | gfn = issn.asGraphFunction([height, width, image_buffer, num_channels], [image_input]) 62 | 63 | return gfn 64 | 65 | def buildFlattener(): 66 | """ 67 | Build a flattening layer to remove the extra leading tensor dimension. 68 | e.g. a tensor of shape [1, W, H, C] will have a shape [W, H, C] after applying this. 69 | """ 70 | with IsolatedSession() as issn: 71 | mat_input = tf.placeholder(tf.float32, [None, None]) 72 | mat_output = tf.identity(tf.reshape(mat_input, shape=[-1]), name='output') 73 | gfn = issn.asGraphFunction([mat_input], [mat_output]) 74 | 75 | return gfn 76 | -------------------------------------------------------------------------------- /python/tests/transformers/keras_image_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from sparkdl.image.imageIO import imageStructToArray 17 | from sparkdl.transformers.keras_image import KerasImageFileTransformer 18 | from sparkdl.transformers.utils import InceptionV3Constants 19 | from ..tests import SparkDLTestCase 20 | from .image_utils import ImageNetOutputComparisonTestCase 21 | from . 
import image_utils 22 | 23 | 24 | class KerasImageFileTransformerTest(SparkDLTestCase): 25 | 26 | def test_loadImages(self): 27 | input_col = "uri" 28 | output_col = "preds" 29 | 30 | model_path = image_utils.prepInceptionV3KerasModelFile("inceptionV3.h5") 31 | transformer = KerasImageFileTransformer(inputCol=input_col, outputCol=output_col, 32 | modelFile=model_path, 33 | imageLoader=image_utils.loadAndPreprocessKerasInceptionV3, 34 | outputMode="vector") 35 | 36 | uri_df = image_utils.getSampleImagePathsDF(self.sql, input_col) 37 | image_df = transformer._loadImages(uri_df) 38 | self.assertEqual(len(image_df.columns), 2) 39 | 40 | img_col = transformer._loadedImageCol() 41 | expected_shape = InceptionV3Constants.INPUT_SHAPE + (3,) 42 | for row in image_df.collect(): 43 | arr = imageStructToArray(row[img_col]) 44 | self.assertEqual(arr.shape, expected_shape) 45 | 46 | 47 | class KerasImageFileTransformerExamplesTest(SparkDLTestCase, ImageNetOutputComparisonTestCase): 48 | 49 | def test_inceptionV3_vs_keras(self): 50 | input_col = "uri" 51 | output_col = "preds" 52 | 53 | model_path = image_utils.prepInceptionV3KerasModelFile("inceptionV3.h5") 54 | transformer = KerasImageFileTransformer(inputCol=input_col, outputCol=output_col, 55 | modelFile=model_path, 56 | imageLoader=image_utils.loadAndPreprocessKerasInceptionV3, 57 | outputMode="vector") 58 | 59 | uri_df = image_utils.getSampleImagePathsDF(self.sql, input_col) 60 | final_df = transformer.transform(uri_df) 61 | self.assertDfHasCols(final_df, [input_col, output_col]) 62 | self.assertEqual(len(final_df.columns), 2) 63 | 64 | collected = final_df.collect() 65 | tvals, ttopK = self.transformOutputToComparables(collected, input_col, output_col) 66 | kvals, ktopK = image_utils.executeKerasInceptionV3(uri_df, uri_col=input_col) 67 | 68 | self.compareClassSets(ktopK, ttopK) 69 | self.compareClassOrderings(ktopK, ttopK) 70 | self.compareArrays(kvals, tvals) 71 | 72 | # TODO: test a workflow with ImageDataGenerator and see if it fits. (It might not.) 
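# A minimal usage sketch of the transformer exercised above (illustrative only; it assumes
# a DataFrame `uri_df` with a string column "uri" of image file paths and a Keras model
# saved at a hypothetical path "/tmp/model.h5"):
#
#   transformer = KerasImageFileTransformer(inputCol="uri", outputCol="predictions",
#                                           modelFile="/tmp/model.h5",
#                                           imageLoader=image_utils.loadAndPreprocessKerasInceptionV3,
#                                           outputMode="vector")
#   predictions_df = transformer.transform(uri_df)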
73 | -------------------------------------------------------------------------------- /docs/css/main.css: -------------------------------------------------------------------------------- 1 | /* ========================================================================== 2 | Author's custom styles 3 | ========================================================================== */ 4 | 5 | .navbar .brand { 6 | height: 50px; 7 | width: 200px; 8 | margin-left: 1px; 9 | padding: 0; 10 | font-size: 25px; 11 | } 12 | 13 | .version { 14 | line-height: 40px; 15 | vertical-align: bottom; 16 | font-size: 12px; 17 | padding: 0; 18 | margin: 0; 19 | font-weight: bold; 20 | color: #777; 21 | } 22 | 23 | .navbar-inner { 24 | padding-top: 2px; 25 | height: 50px; 26 | } 27 | 28 | .navbar-inner .nav { 29 | margin-top: 5px; 30 | font-size: 15px; 31 | } 32 | 33 | .navbar .divider-vertical { 34 | border-right-color: lightgray; 35 | } 36 | 37 | .navbar-text .version-text { 38 | color: #555555; 39 | padding: 5px; 40 | margin-left: 10px; 41 | } 42 | 43 | body #content { 44 | line-height: 1.6; /* Inspired by Github's wiki style */ 45 | } 46 | 47 | .title { 48 | font-size: 32px; 49 | } 50 | 51 | h1 { 52 | font-size: 28px; 53 | margin-top: 12px; 54 | } 55 | 56 | h2 { 57 | font-size: 24px; 58 | margin-top: 12px; 59 | } 60 | 61 | h3 { 62 | font-size: 21px; 63 | margin-top: 10px; 64 | } 65 | 66 | pre { 67 | font-family: "Menlo", "Lucida Console", monospace; 68 | } 69 | 70 | code { 71 | font-family: "Menlo", "Lucida Console", monospace; 72 | background: white; 73 | border: none; 74 | padding: 0; 75 | color: #444444; 76 | } 77 | 78 | a code { 79 | color: #0088cc; 80 | } 81 | 82 | a:hover code { 83 | color: #005580; 84 | text-decoration: underline; 85 | } 86 | 87 | .container { 88 | max-width: 914px; 89 | } 90 | 91 | .dropdown-menu { 92 | /* Remove the default 2px top margin which causes a small 93 | gap between the hover trigger area and the popup menu */ 94 | margin-top: 0; 95 | /* Avoid too much whitespace at the right for shorter menu items */ 96 | min-width: 50px; 97 | } 98 | 99 | /** 100 | * Make dropdown menus in nav bars show on hover instead of click 101 | * using solution at http://stackoverflow.com/questions/8878033/how- 102 | * to-make-twitter-bootstrap-menu-dropdown-on-hover-rather-than-click 103 | **/ 104 | ul.nav li.dropdown:hover ul.dropdown-menu{ 105 | display: block; 106 | } 107 | 108 | a.menu:after, .dropdown-toggle:after { 109 | content: none; 110 | } 111 | 112 | /** Make the submenus open on hover on the parent menu item */ 113 | ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu:hover ul.dropdown-menu { 114 | display: block; 115 | } 116 | 117 | /** Make the submenus be invisible until the parent menu item is hovered upon */ 118 | ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { 119 | display: none; 120 | } 121 | 122 | /** 123 | * Made the navigation bar buttons not grey out when clicked. 124 | * Essentially making nav bar buttons not react to clicks, only hover events. 125 | */ 126 | .navbar .nav li.dropdown.open > .dropdown-toggle { 127 | background-color: transparent; 128 | } 129 | 130 | /** 131 | * Made the active tab caption blue. Otherwise the active tab is black, and inactive tab is blue. 132 | * That looks weird. Changed the colors to active - blue, inactive - black, and 133 | * no color change on hover. 
134 | */ 135 | .nav-tabs > .active > a, .nav-tabs > .active > a:hover { 136 | color: #08c; 137 | } 138 | 139 | .nav-tabs > li > a, .nav-tabs > li > a:hover { 140 | color: #333; 141 | } 142 | 143 | /** 144 | * MathJax (embedded latex formulas) 145 | */ 146 | .MathJax .mo { color: inherit } 147 | .MathJax .mi { color: inherit } 148 | .MathJax .mf { color: inherit } 149 | .MathJax .mh { color: inherit } 150 | 151 | /** 152 | * AnchorJS (anchor links when hovering over headers) 153 | */ 154 | a.anchorjs-link:hover { text-decoration: none; } 155 | -------------------------------------------------------------------------------- /docs/js/main.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | /* Custom JavaScript code in the MarkDown docs */ 19 | 20 | // Enable language-specific code tabs 21 | function codeTabs() { 22 | var counter = 0; 23 | var langImages = { 24 | "scala": "img/scala-sm.png", 25 | "python": "img/python-sm.png", 26 | "java": "img/java-sm.png" 27 | }; 28 | $("div.codetabs").each(function() { 29 | $(this).addClass("tab-content"); 30 | 31 | // Insert the tab bar 32 | var tabBar = $(''); 33 | $(this).before(tabBar); 34 | 35 | // Add each code sample to the tab bar: 36 | var codeSamples = $(this).children("div"); 37 | codeSamples.each(function() { 38 | $(this).addClass("tab-pane"); 39 | var lang = $(this).data("lang"); 40 | var image = $(this).data("image"); 41 | var notabs = $(this).data("notabs"); 42 | var capitalizedLang = lang.substr(0, 1).toUpperCase() + lang.substr(1); 43 | var id = "tab_" + lang + "_" + counter; 44 | $(this).attr("id", id); 45 | if (image != null && langImages[lang]) { 46 | var buttonLabel = "" + capitalizedLang + ""; 47 | } else if (notabs == null) { 48 | var buttonLabel = "" + capitalizedLang + ""; 49 | } else { 50 | var buttonLabel = "" 51 | } 52 | tabBar.append( 53 | '
  • ' + buttonLabel + '
  • ' 54 | ); 55 | }); 56 | 57 | codeSamples.first().addClass("active"); 58 | tabBar.children("li").first().addClass("active"); 59 | counter++; 60 | }); 61 | $("ul.nav-tabs a").click(function (e) { 62 | // Toggling a tab should switch all tabs corresponding to the same language 63 | // while retaining the scroll position 64 | e.preventDefault(); 65 | var scrollOffset = $(this).offset().top - $(document).scrollTop(); 66 | $("." + $(this).attr('class')).tab('show'); 67 | $(document).scrollTop($(this).offset().top - scrollOffset); 68 | }); 69 | } 70 | 71 | 72 | // A script to fix internal hash links because we have an overlapping top bar. 73 | // Based on https://github.com/twitter/bootstrap/issues/193#issuecomment-2281510 74 | function maybeScrollToHash() { 75 | if (window.location.hash && $(window.location.hash).length) { 76 | var newTop = $(window.location.hash).offset().top - 57; 77 | $(window).scrollTop(newTop); 78 | } 79 | } 80 | 81 | $(function() { 82 | codeTabs(); 83 | // Display anchor links when hovering over headers. For documentation of the 84 | // configuration options, see the AnchorJS documentation. 85 | anchors.options = { 86 | placement: 'left' 87 | }; 88 | anchors.add(); 89 | 90 | $(window).bind('hashchange', function() { 91 | maybeScrollToHash(); 92 | }); 93 | 94 | // Scroll now too in case we had opened the page on a hash, but wait a bit because some browsers 95 | // will try to do *their* initial scroll after running the onReady handler. 96 | $(window).load(function() { setTimeout(function() { maybeScrollToHash(); }, 25); }); 97 | }); 98 | -------------------------------------------------------------------------------- /python/docs/static/pysparkdl.js: -------------------------------------------------------------------------------- 1 | /* 2 | Licensed to the Apache Software Foundation (ASF) under one or more 3 | contributor license agreements. See the NOTICE file distributed with 4 | this work for additional information regarding copyright ownership. 5 | The ASF licenses this file to You under the Apache License, Version 2.0 6 | (the "License"); you may not use this file except in compliance with 7 | the License. You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | */ 17 | 18 | $(function (){ 19 | 20 | function startsWith(s, prefix) { 21 | return s && s.indexOf(prefix) === 0; 22 | } 23 | 24 | function buildSidebarLinkMap() { 25 | var linkMap = {}; 26 | $('div.sphinxsidebar a.reference.internal').each(function (i,a) { 27 | var href = $(a).attr('href'); 28 | if (startsWith(href, '#module-')) { 29 | var id = href.substr(8); 30 | linkMap[id] = [$(a), null]; 31 | } 32 | }) 33 | return linkMap; 34 | }; 35 | 36 | function getAdNoteDivs(dd) { 37 | var noteDivs = {}; 38 | dd.find('> div.admonition.note > p.last').each(function (i, p) { 39 | var text = $(p).text(); 40 | if (!noteDivs.experimental && startsWith(text, 'Experimental')) { 41 | noteDivs.experimental = $(p).parent(); 42 | } 43 | if (!noteDivs.deprecated && startsWith(text, 'Deprecated')) { 44 | noteDivs.deprecated = $(p).parent(); 45 | } 46 | }); 47 | return noteDivs; 48 | } 49 | 50 | function getParentId(name) { 51 | var last_idx = name.lastIndexOf('.'); 52 | return last_idx == -1? '': name.substr(0, last_idx); 53 | } 54 | 55 | function buildTag(text, cls, tooltip) { 56 | return '' + text + '' 57 | + tooltip + '' 58 | } 59 | 60 | 61 | var sidebarLinkMap = buildSidebarLinkMap(); 62 | 63 | $('dl.class, dl.function').each(function (i,dl) { 64 | 65 | dl = $(dl); 66 | dt = dl.children('dt').eq(0); 67 | dd = dl.children('dd').eq(0); 68 | var id = dt.attr('id'); 69 | var desc = dt.find('> .descname').text(); 70 | var adNoteDivs = getAdNoteDivs(dd); 71 | 72 | if (id) { 73 | var parent_id = getParentId(id); 74 | 75 | var r = sidebarLinkMap[parent_id]; 76 | if (r) { 77 | if (r[1] === null) { 78 | r[1] = $('
      '); 79 | r[0].parent().append(r[1]); 80 | } 81 | var tags = ''; 82 | if (adNoteDivs.experimental) { 83 | tags += buildTag('E', 'pys-tag-experimental', 'Experimental'); 84 | adNoteDivs.experimental.addClass('pys-note pys-note-experimental'); 85 | } 86 | if (adNoteDivs.deprecated) { 87 | tags += buildTag('D', 'pys-tag-deprecated', 'Deprecated'); 88 | adNoteDivs.deprecated.addClass('pys-note pys-note-deprecated'); 89 | } 90 | var li = $('
    • '); 91 | var a = $('' + desc + ''); 92 | li.append(a); 93 | li.append(tags); 94 | r[1].append(li); 95 | sidebarLinkMap[id] = [a, null]; 96 | } 97 | } 98 | }); 99 | }); 100 | -------------------------------------------------------------------------------- /docs/css/pygments-default.css: -------------------------------------------------------------------------------- 1 | /* 2 | Documentation for pygments (and Jekyll for that matter) is super sparse. 3 | To generate this, I had to run 4 | `pygmentize -S default -f html > pygments-default.css` 5 | But first I had to install pygments via easy_install pygments 6 | 7 | I had to override the conflicting bootstrap style rules by linking to 8 | this stylesheet lower in the html than the bootstap css. 9 | 10 | Also, I was thrown off for a while at first when I was using markdown 11 | code block inside my {% highlight scala %} ... {% endhighlight %} tags 12 | (I was using 4 spaces for this), when it turns out that pygments will 13 | insert the code (or pre?) tags for you. 14 | */ 15 | 16 | .hll { background-color: #ffffcc } 17 | .c { color: #60a0b0; font-style: italic } /* Comment */ 18 | .err { } /* Error */ 19 | .k { color: #007020; font-weight: bold } /* Keyword */ 20 | .o { color: #666666 } /* Operator */ 21 | .cm { color: #60a0b0; font-style: italic } /* Comment.Multiline */ 22 | .cp { color: #007020 } /* Comment.Preproc */ 23 | .c1 { color: #60a0b0; font-style: italic } /* Comment.Single */ 24 | .cs { color: #60a0b0; background-color: #fff0f0 } /* Comment.Special */ 25 | .gd { color: #A00000 } /* Generic.Deleted */ 26 | .ge { font-style: italic } /* Generic.Emph */ 27 | .gr { color: #FF0000 } /* Generic.Error */ 28 | .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 29 | .gi { color: #00A000 } /* Generic.Inserted */ 30 | .go { color: #808080 } /* Generic.Output */ 31 | .gp { color: #c65d09; font-weight: bold } /* Generic.Prompt */ 32 | .gs { font-weight: bold } /* Generic.Strong */ 33 | .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 34 | .gt { color: #0040D0 } /* Generic.Traceback */ 35 | .kc { color: #007020; font-weight: bold } /* Keyword.Constant */ 36 | .kd { color: #007020; font-weight: bold } /* Keyword.Declaration */ 37 | .kn { color: #007020; font-weight: bold } /* Keyword.Namespace */ 38 | .kp { color: #007020 } /* Keyword.Pseudo */ 39 | .kr { color: #007020; font-weight: bold } /* Keyword.Reserved */ 40 | .kt { color: #902000 } /* Keyword.Type */ 41 | .m { color: #40a070 } /* Literal.Number */ 42 | .s { color: #4070a0 } /* Literal.String */ 43 | .na { color: #4070a0 } /* Name.Attribute */ 44 | .nb { color: #007020 } /* Name.Builtin */ 45 | .nc { color: #0e84b5; font-weight: bold } /* Name.Class */ 46 | .no { color: #60add5 } /* Name.Constant */ 47 | .nd { color: #555555; font-weight: bold } /* Name.Decorator */ 48 | .ni { color: #d55537; font-weight: bold } /* Name.Entity */ 49 | .ne { color: #007020 } /* Name.Exception */ 50 | .nf { color: #06287e } /* Name.Function */ 51 | .nl { color: #002070; font-weight: bold } /* Name.Label */ 52 | .nn { color: #0e84b5; font-weight: bold } /* Name.Namespace */ 53 | .nt { color: #062873; font-weight: bold } /* Name.Tag */ 54 | .nv { color: #bb60d5 } /* Name.Variable */ 55 | .ow { color: #007020; font-weight: bold } /* Operator.Word */ 56 | .w { color: #bbbbbb } /* Text.Whitespace */ 57 | .mf { color: #40a070 } /* Literal.Number.Float */ 58 | .mh { color: #40a070 } /* Literal.Number.Hex */ 59 | .mi { color: #40a070 } /* Literal.Number.Integer */ 60 | .mo { 
color: #40a070 } /* Literal.Number.Oct */ 61 | .sb { color: #4070a0 } /* Literal.String.Backtick */ 62 | .sc { color: #4070a0 } /* Literal.String.Char */ 63 | .sd { color: #4070a0; font-style: italic } /* Literal.String.Doc */ 64 | .s2 { color: #4070a0 } /* Literal.String.Double */ 65 | .se { color: #4070a0; font-weight: bold } /* Literal.String.Escape */ 66 | .sh { color: #4070a0 } /* Literal.String.Heredoc */ 67 | .si { color: #70a0d0; font-style: italic } /* Literal.String.Interpol */ 68 | .sx { color: #c65d09 } /* Literal.String.Other */ 69 | .sr { color: #235388 } /* Literal.String.Regex */ 70 | .s1 { color: #4070a0 } /* Literal.String.Single */ 71 | .ss { color: #517918 } /* Literal.String.Symbol */ 72 | .bp { color: #007020 } /* Name.Builtin.Pseudo */ 73 | .vc { color: #bb60d5 } /* Name.Variable.Class */ 74 | .vg { color: #bb60d5 } /* Name.Variable.Global */ 75 | .vi { color: #bb60d5 } /* Name.Variable.Instance */ 76 | .il { color: #40a070 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /src/main/scala/com/databricks/sparkdl/python/PythonInterface.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.databricks.sparkdl.python 17 | 18 | import java.util.ArrayList 19 | 20 | import scala.collection.JavaConverters._ 21 | import scala.collection.mutable 22 | 23 | import org.apache.spark.annotation.DeveloperApi 24 | import org.apache.spark.ml.linalg.{DenseVector, Vector} 25 | import org.apache.spark.sql.{Column, SQLContext} 26 | import org.apache.spark.sql.expressions.UserDefinedFunction 27 | import org.apache.spark.sql.functions.udf 28 | import org.apache.spark.sql.sparkdl_stubs.{PipelinedUDF, UDFUtils} 29 | import org.apache.spark.sql.types.{ArrayType, DoubleType, FloatType} 30 | 31 | /** 32 | * This file contains some interfaces with the JVM runtime: theses functions create UDFs and 33 | * transform UDFs using java code. 34 | */ 35 | // TODO: this pattern is repeated over and over again, it should be standard somewhere. 36 | @DeveloperApi 37 | class PythonInterface { 38 | private var _sqlCtx: SQLContext = null 39 | 40 | def sqlContext(ctx: SQLContext): this.type = { 41 | _sqlCtx = ctx 42 | this 43 | } 44 | 45 | /** 46 | * Takes a column, which may contain either arrays of floats or doubles, and returns the 47 | * content, cast as MLlib's vectors. 
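 * For example (illustrative), a column whose values are arrays of floats such as Array(1.0f, 2.0f)
 * comes back as a column of DenseVector(1.0, 2.0); arrays of doubles are wrapped into dense vectors the same way.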
48 | */ 49 | def listToMLlibVectorUDF(col: Column): Column = { 50 | Conversions.convertToVector(col) 51 | } 52 | 53 | /** 54 | * Create a UDF as the result of chaining multiple UDFs 55 | */ 56 | def registerPipeline(name: String, udfNames: ArrayList[String]) = { 57 | require(_sqlCtx != null, "spark session must be provided") 58 | require(udfNames.size > 0) 59 | UDFUtils.registerPipeline(_sqlCtx, name, udfNames.asScala) 60 | } 61 | } 62 | 63 | 64 | @DeveloperApi 65 | object Conversions { 66 | private def floatArrayToVector(x: Array[Float]): Vector = { 67 | new DenseVector(fromFloatArray(x)) 68 | } 69 | 70 | // This code is intrinsically bad for performance: all the elements are not stored in a contiguous 71 | // array, but they are wrapped in java.lang.Float objects (subject to garbage collection, etc.) 72 | // TODO: find a way to directly obtain an array of floats from Spark, without going through a Scala 73 | // sequence first. 74 | private def floatSeqToVector(x: Seq[Float]): Vector = x match { 75 | case wa: mutable.WrappedArray[Float] => 76 | floatArrayToVector(wa.toArray) // This might look good, but boxing is still happening!!! 77 | case _ => throw new Exception( 78 | s"Expected a WrappedArray, got class of instance ${x.getClass}: $x") 79 | } 80 | 81 | private def doubleArrayToVector(x: Array[Double]): Vector = { new DenseVector(x) } 82 | 83 | private def fromFloatArray(x: Array[Float]): Array[Double] = { 84 | val res = Array.ofDim[Double](x.length) 85 | var idx = 0 86 | while (idx < res.length) { 87 | res(idx) = x(idx) 88 | idx += 1 89 | } 90 | res 91 | } 92 | 93 | def convertToVector(col: Column): Column = { 94 | col.expr.dataType match { 95 | case ArrayType(FloatType, false) => 96 | val f = udf(floatSeqToVector _) 97 | f(col) 98 | case ArrayType(DoubleType, false) => 99 | val f = udf(doubleArrayToVector _) 100 | f(col) 101 | case dt => 102 | throw new Exception(s"convertToVector: cannot deal with type $dt") 103 | } 104 | } 105 | 106 | } 107 | -------------------------------------------------------------------------------- /src/test/scala/org/tensorframes/impl/SqlOpsSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License.
15 | */ 16 | package org.tensorframes.impl 17 | 18 | import org.scalatest.FunSuite 19 | 20 | import com.databricks.sparkdl.TestSparkContext 21 | import org.apache.spark.sql.Row 22 | import org.apache.spark.sql.types._ 23 | import org.apache.spark.sql.{functions => sqlfn} 24 | 25 | import org.tensorflow.Tensor 26 | import org.tensorframes.{Logging, Shape, ShapeDescription} 27 | import org.tensorframes.dsl.Implicits._ 28 | import org.tensorframes.dsl._ 29 | import org.tensorframes.{dsl => tf} 30 | import org.apache.spark.sql.sparkdl_stubs._ 31 | 32 | // Classes used for creating Dataset 33 | // With `import spark.implicits_` we have the encoders 34 | object SqlOpsSchema { 35 | case class InputCol(a: Double) 36 | case class DFRow(idx: Long, input: InputCol) 37 | } 38 | 39 | class SqlOpsSpec extends FunSuite with TestSparkContext with GraphScoping with Logging { 40 | lazy val sql = sqlContext 41 | import SqlOpsSchema._ 42 | 43 | import TestUtils._ 44 | import Shape.Unknown 45 | 46 | test("Must be able to register TensorFlow Graph UDF") { 47 | val p1 = tf.placeholder[Double](1) named "p1" 48 | val p2 = tf.placeholder[Double](1) named "p2" 49 | val a = p1 + p2 named "a" 50 | val g = buildGraph(a) 51 | val shapeHints = ShapeDescription( 52 | Map("p1" -> Shape(1), "p2" -> Shape(1)), 53 | Seq("p1", "p2"), 54 | Map("a" -> "a")) 55 | 56 | val udfName = "tfs-test-simple-add" 57 | val udf = SqlOps.makeUDF(udfName, g, shapeHints, false, false) 58 | UDFUtils.registerUDF(spark.sqlContext, udfName, udf) // generic UDF registeration 59 | assert(spark.catalog.functionExists(udfName)) 60 | } 61 | 62 | test("Registered tf.Graph UDF and use in SQL") { 63 | import spark.implicits._ 64 | 65 | val a = tf.placeholder[Double](Unknown) named "inputA" 66 | val z = a + 2.0 named "z" 67 | val g = buildGraph(z) 68 | 69 | val shapeHints = ShapeDescription( 70 | Map("z" -> Shape(1)), 71 | Seq("z"), 72 | Map("inputA" -> "a")) 73 | 74 | logDebug(s"graph ${g.toString}") 75 | 76 | // Build the UDF and register 77 | val udfName = "tfs_test_simple_add" 78 | val udf = SqlOps.makeUDF(udfName, g, shapeHints, false, false) 79 | UDFUtils.registerUDF(spark.sqlContext, udfName, udf) // generic UDF registeration 80 | 81 | // Create a DataFrame 82 | val inputs = (1 to 100).map(_.toDouble) 83 | 84 | val dfIn = inputs.zipWithIndex.map { case (v, idx) => 85 | new DFRow(idx.toLong, new InputCol(v)) 86 | }.toDS.toDF 87 | dfIn.printSchema() 88 | dfIn.createOrReplaceTempView("temp_input_df") 89 | 90 | // Create the query 91 | val sqlQuery = s"select ${udfName}(input) as output from temp_input_df" 92 | logDebug(sqlQuery) 93 | val dfOut = spark.sql(sqlQuery) 94 | dfOut.printSchema() 95 | 96 | // The UDF maps from StructType => StructType 97 | // Thus when iterating over the result, each record is a Row of Row 98 | val res = dfOut.select("output").collect().map { 99 | case rowOut @ Row(rowIn @ Row(t)) => 100 | //println(rowOut, rowIn, t) 101 | t.asInstanceOf[Seq[Double]].head 102 | } 103 | 104 | // Check that all the results are correct 105 | (res zip inputs).foreach { case (v, u) => 106 | assert(v === u + 2.0) 107 | } 108 | } 109 | 110 | } 111 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/sparkdl_stubs/UDFUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package org.apache.spark.sql.sparkdl_stubs 17 | 18 | import java.util.ArrayList 19 | import scala.collection.JavaConverters._ 20 | 21 | import org.apache.spark.internal.Logging 22 | import org.apache.spark.sql.{Column, Row, SQLContext} 23 | import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} 24 | import org.apache.spark.sql.expressions.UserDefinedFunction 25 | import org.apache.spark.sql.types.DataType 26 | 27 | object UDFUtils extends Logging { 28 | /** 29 | * Register a UDF with the given SQLContext, so as to expose it in Spark SQL 30 | * @param sqlCtx the SQLContext with which to register the UDF 31 | * @param name the name under which the UDF is registered 32 | * @param udf the actual body of the UDF 33 | * @return the registered UDF 34 | */ 35 | def registerUDF(sqlCtx: SQLContext, name: String, udf: UserDefinedFunction): UserDefinedFunction = { 36 | def builder(children: Seq[Expression]) = udf.apply(children.map(cx => new Column(cx)) : _*).expr 37 | val registry = sqlCtx.sessionState.functionRegistry 38 | registry.registerFunction(name, builder) 39 | udf 40 | } 41 | 42 | /** 43 | * Register a UserDefinedFunction (UDF) as a composition of several UDFs.
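 * For example (an illustrative sketch, assuming UDFs named "f" and "g" are already registered):
 * {{{
 *   UDFUtils.registerPipeline(sqlCtx, "h", Seq("f", "g"))
 *   // after this, SELECT h(x) FROM t evaluates as g(f(x))
 * }}}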
44 | * The UDFs must have already been registered 45 | * @param spark the SparkSession to which we want to register the UDF 46 | * @param name registered to the provided SparkSession 47 | * @param orderedUdfNames a sequence of UDF names in the composition order 48 | */ 49 | def registerPipeline(sqlCtx: SQLContext, name: String, orderedUdfNames: Seq[String]) = { 50 | val registry = sqlCtx.sessionState.functionRegistry 51 | val builders = orderedUdfNames.flatMap { fname => registry.lookupFunctionBuilder(fname) } 52 | require(builders.size == orderedUdfNames.size, 53 | s"all UDFs must have been registered to the SQL context: $sqlCtx") 54 | def composedBuilder(children: Seq[Expression]): Expression = { 55 | builders.foldLeft(children) { case (exprs, fb) => Seq(fb(exprs)) }.head 56 | } 57 | registry.registerFunction(name, composedBuilder) 58 | } 59 | } 60 | 61 | 62 | /** 63 | * Registering a set of UserDefinedFunctions (UDF) 64 | */ 65 | class PipelinedUDF( 66 | opName: String, 67 | udfs: Seq[UserDefinedFunction], 68 | returnType: DataType) extends UserDefinedFunction(null, returnType, None) { 69 | require(udfs.nonEmpty) 70 | 71 | override def apply(exprs: Column*): Column = { 72 | val start = udfs.head.apply(exprs: _*) 73 | var rest = start 74 | for (udf <- udfs.tail) { 75 | rest = udf.apply(rest) 76 | } 77 | val inner = exprs.toSeq.map(_.toString()).mkString(", ") 78 | val name = s"$opName($inner)" 79 | rest.alias(name) 80 | } 81 | } 82 | 83 | object PipelinedUDF { 84 | def apply(opName: String, fn: UserDefinedFunction, fns: UserDefinedFunction*): UserDefinedFunction = { 85 | if (fns.isEmpty) return fn 86 | new PipelinedUDF(opName, Seq(fn) ++ fns, fns.last.dataType) 87 | } 88 | } 89 | 90 | 91 | class RowUDF( 92 | opName: String, 93 | fun: Column => (Any => Row), 94 | returnType: DataType) extends UserDefinedFunction(null, returnType, None) { 95 | 96 | override def apply(exprs: Column*): Column = { 97 | require(exprs.size == 1, "only support one function") 98 | val f = fun(exprs.head) 99 | val inner = exprs.toSeq.map(_.toString()).mkString(", ") 100 | val name = s"$opName($inner)" 101 | new Column(ScalaUDF(f, dataType, exprs.map(_.expr), Nil)).alias(name) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | Welcome to the Deep Learning Pipelines Spark Package documentation! 2 | 3 | This readme will walk you through navigating and building the Deep Learning Pipelines documentation, which is 4 | included here with the source code. 5 | 6 | Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the 7 | documentation yourself. Why build it yourself? So that you have the docs that correspond to 8 | whichever version of Deep Learning Pipelines you currently have checked out of revision control. 9 | 10 | ## Generating the Documentation HTML 11 | 12 | We include the Deep Learning Pipelines documentation as part of the source (as opposed to using a hosted wiki, such as 13 | the github wiki, as the definitive documentation) to enable the documentation to evolve along with 14 | the source code and be captured by revision control (currently git). This way the code automatically 15 | includes the version of the documentation that is relevant regardless of which version or release 16 | you have checked out or downloaded. 17 | 18 | In this directory you will find textfiles formatted using Markdown, with an ".md" suffix. 
You can 19 | read those text files directly if you want. Start with index.md. 20 | 21 | The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com). 22 | `Jekyll` and a few dependencies must be installed for this to work. We recommend 23 | installing via the Ruby Gem dependency manager. Since the exact HTML output 24 | varies between versions of Jekyll and its dependencies, we list specific versions here 25 | in some cases (`Jekyll 3.4.3`): 26 | 27 | $ sudo gem install jekyll bundler 28 | $ sudo gem install jekyll-redirect-from pygments.rb 29 | 30 | 31 | Then run the prepare script to setup prerequisites and generate a wrapper "jekyll" script 32 | $ ./prepare -s -t 33 | 34 | Execute `./jekyll build` from the `docs/` directory to compile the site. Compiling the site with Jekyll will create a directory 35 | called `_site` containing index.html as well as the rest of the compiled files. 36 | 37 | You can modify the default Jekyll build as follows: 38 | 39 | # Skip generating API docs (which takes a while) 40 | $ SKIP_API=1 ./jekyll build 41 | # Serve content locally on port 4000 42 | $ ./jekyll serve --watch 43 | # Build the site with extra features used on the live page 44 | $ PRODUCTION=1 ./jekyll build 45 | 46 | Note that `SPARK_HOME` must be set to your local Spark installation in order to generate the docs. 47 | 48 | ## Pygments 49 | 50 | We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages, 51 | so you will also need to install that (it requires Python) by running `sudo pip install Pygments`. 52 | 53 | To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile 54 | phase, use the following sytax: 55 | 56 | {% highlight scala %} 57 | // Your scala code goes here, you can replace scala with many other 58 | // supported languages too. 59 | {% endhighlight %} 60 | 61 | ## Sphinx 62 | 63 | We use Sphinx to generate Python API docs, so you will need to install it by running 64 | `sudo pip install sphinx`. 65 | 66 | ## API Docs (Scaladoc, Sphinx) 67 | 68 | You can build just the scaladoc by running `build/sbt unidoc` from the SPARKDL_PROJECT_ROOT directory. 69 | 70 | Similarly, you can build just the Python docs by running `make html` from the 71 | SPARKDL_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as 72 | public in `__init__.py`. 73 | 74 | When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various 75 | subprojects into the `docs` directory (and then also into the `_site` directory). We use a 76 | jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it 77 | may take some time as it generates all of the scaladoc. The jekyll plugin also generates the 78 | Python docs [Sphinx](http://sphinx-doc.org/). 79 | 80 | NOTE: To skip the step of building and copying over the Scala, Python API docs, run `SKIP_API=1 81 | jekyll build`. To skip building Scala API docs, run `SKIP_SCALADOC=1 jekyll build`; to skip building Python API docs, run `SKIP_PYTHONDOC=1 jekyll build`. 82 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/param.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | """ 17 | Some parts are copied from pyspark.ml.param.shared and some are complementary 18 | to pyspark.ml.param. The copy is due to some useful pyspark fns/classes being 19 | private APIs. 20 | """ 21 | 22 | from functools import wraps 23 | 24 | import keras 25 | import tensorflow as tf 26 | 27 | from pyspark.ml.param import Param, Params, TypeConverters 28 | 29 | 30 | # From pyspark 31 | 32 | def keyword_only(func): 33 | """ 34 | A decorator that forces keyword arguments in the wrapped method 35 | and saves actual input keyword arguments in `_input_kwargs`. 36 | 37 | .. note:: Should only be used to wrap a method where first arg is `self` 38 | """ 39 | @wraps(func) 40 | def wrapper(self, *args, **kwargs): 41 | if len(args) > 0: 42 | raise TypeError("Method %s forces keyword arguments." % func.__name__) 43 | self._input_kwargs = kwargs 44 | return func(self, **kwargs) 45 | return wrapper 46 | 47 | 48 | class HasInputCol(Params): 49 | """ 50 | Mixin for param inputCol: input column name. 51 | """ 52 | 53 | inputCol = Param(Params._dummy(), "inputCol", "input column name.", typeConverter=TypeConverters.toString) 54 | 55 | def __init__(self): 56 | super(HasInputCol, self).__init__() 57 | 58 | def setInputCol(self, value): 59 | """ 60 | Sets the value of :py:attr:`inputCol`. 61 | """ 62 | return self._set(inputCol=value) 63 | 64 | def getInputCol(self): 65 | """ 66 | Gets the value of inputCol or its default value. 67 | """ 68 | return self.getOrDefault(self.inputCol) 69 | 70 | 71 | class HasOutputCol(Params): 72 | """ 73 | Mixin for param outputCol: output column name. 74 | """ 75 | 76 | outputCol = Param(Params._dummy(), "outputCol", "output column name.", typeConverter=TypeConverters.toString) 77 | 78 | def __init__(self): 79 | super(HasOutputCol, self).__init__() 80 | self._setDefault(outputCol=self.uid + '__output') 81 | 82 | def setOutputCol(self, value): 83 | """ 84 | Sets the value of :py:attr:`outputCol`. 85 | """ 86 | return self._set(outputCol=value) 87 | 88 | def getOutputCol(self): 89 | """ 90 | Gets the value of outputCol or its default value. 91 | """ 92 | return self.getOrDefault(self.outputCol) 93 | 94 | 95 | # New in sparkdl 96 | 97 | class SparkDLTypeConverters(object): 98 | 99 | @staticmethod 100 | def toStringOrTFTensor(value): 101 | if isinstance(value, tf.Tensor): 102 | return value 103 | else: 104 | try: 105 | return TypeConverters.toString(value) 106 | except TypeError: 107 | raise TypeError("Could not convert %s to tensorflow.Tensor or str" % type(value)) 108 | 109 | @staticmethod 110 | def toTFGraph(value): 111 | # TODO: we may want to support tf.GraphDef in the future instead of tf.Graph since user 112 | # is less likely to mess up using GraphDef vs Graph (e.g. constants vs variables). 
113 | if isinstance(value, tf.Graph): 114 | return value 115 | else: 116 | raise TypeError("Could not convert %s to tensorflow.Graph type" % type(value)) 117 | 118 | @staticmethod 119 | def supportedNameConverter(supportedList): 120 | def converter(value): 121 | if value in supportedList: 122 | return value 123 | else: 124 | raise TypeError("%s %s is not in the supported list." % (type(value), str(value))) 125 | -------------------------------------------------------------------------------- /python/run-tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | # Return on any failure 21 | set -e 22 | 23 | # if (got > 1 argument OR ( got 1 argument AND that argument does not exist)) then 24 | # print usage and exit. 25 | if [[ $# -gt 1 || ($# = 1 && ! -e $1) ]]; then 26 | echo "run-tests.sh [target]" 27 | echo "" 28 | echo "Run python tests for this package." 29 | echo " target -- either a test file or directory [default tests]" 30 | if [[ ($# = 1 && ! -e $1) ]]; then 31 | echo 32 | echo "ERROR: Could not find $1" 33 | fi 34 | exit 1 35 | fi 36 | 37 | # assumes run from python/ directory 38 | if [ -z "$SPARK_HOME" ]; then 39 | echo 'You need to set $SPARK_HOME to run these tests.' >&2 40 | exit 1 41 | fi 42 | 43 | # Honor the choice of python driver 44 | if [ -z "$PYSPARK_PYTHON" ]; then 45 | PYSPARK_PYTHON=`which python` 46 | fi 47 | # Override the python driver version as well to make sure we are in sync in the tests. 48 | export PYSPARK_DRIVER_PYTHON=$PYSPARK_PYTHON 49 | python_major=$($PYSPARK_PYTHON -c 'import sys; print(".".join(map(str, sys.version_info[:1])))') 50 | 51 | echo $python_major 52 | 53 | LIBS="" 54 | for lib in "$SPARK_HOME/python/lib"/*zip ; do 55 | LIBS=$LIBS:$lib 56 | done 57 | 58 | # The current directory of the script. 59 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 60 | 61 | a=( ${SCALA_VERSION//./ } ) 62 | scala_version_major_minor="${a[0]}.${a[1]}" 63 | echo "List of assembly jars found, the last one will be used:" 64 | assembly_path="$DIR/../target/scala-$scala_version_major_minor" 65 | echo `ls $assembly_path/spark-deep-learning-assembly*.jar` 66 | JAR_PATH="" 67 | for assembly in $assembly_path/spark-deep-learning-assembly*.jar ; do 68 | JAR_PATH=$assembly 69 | done 70 | 71 | # python dir ($DIR) should be before assembly so dev changes can be picked up. 72 | export PYTHONPATH=$PYTHONPATH:$DIR 73 | export PYTHONPATH=$PYTHONPATH:$assembly # same $assembly used for the JAR_PATH above 74 | export PYTHONPATH=$PYTHONPATH:$SPARK_HOME/python:$LIBS:. 75 | 76 | # This will be used when starting pyspark.
77 | export PYSPARK_SUBMIT_ARGS="--driver-memory 4g --executor-memory 4g --jars $JAR_PATH pyspark-shell" 78 | 79 | 80 | # Run test suites 81 | 82 | # TODO: make sure travis has the right version of nose 83 | if [ -f "$1" ]; then 84 | noseOptionsArr="$1" 85 | else 86 | if [ -d "$1" ]; then 87 | targetDir=$1 88 | else 89 | targetDir=$DIR/tests 90 | fi 91 | # add all python files in the test dir recursively 92 | echo "============= Searching for tests in: $targetDir =============" 93 | noseOptionsArr="$(find "$targetDir" -type f | grep "\.py" | grep -v "\.pyc" | grep -v "\.py~" | grep -v "__init__.py")" 94 | fi 95 | 96 | # Limit TensorFlow error message 97 | # https://github.com/tensorflow/tensorflow/issues/1258 98 | export TF_CPP_MIN_LOG_LEVEL=3 99 | 100 | for noseOptions in $noseOptionsArr 101 | do 102 | echo "============= Running the tests in: $noseOptions =============" 103 | # The grep below is a horrible hack for spark 1.x: we manually remove some log lines to stay below the 4MB log limit on Travis. 104 | $PYSPARK_DRIVER_PYTHON \ 105 | -m "nose" \ 106 | --with-coverage \ 107 | --nologcapture \ 108 | -v --exe "$noseOptions" \ 109 | 2>&1 | grep -vE "INFO (ParquetOutputFormat|SparkContext|ContextCleaner|ShuffleBlockFetcherIterator|MapOutputTrackerMaster|TaskSetManager|Executor|MemoryStore|CacheManager|BlockManager|DAGScheduler|PythonRDD|TaskSchedulerImpl|ZippedPartitionsRDD2)"; 110 | 111 | # Exit immediately if the tests fail. 112 | # Since we pipe to remove the output, we need to use some horrible BASH features: 113 | # http://stackoverflow.com/questions/1221833/bash-pipe-output-and-capture-exit-status 114 | test ${PIPESTATUS[0]} -eq 0 || exit 1; 115 | done 116 | 117 | 118 | # Run doc tests 119 | 120 | #$PYSPARK_PYTHON -u ./sparkdl/ourpythonfilewheneverwehaveone.py "$@" 121 | -------------------------------------------------------------------------------- /docs/js/vendor/anchor.min.js: -------------------------------------------------------------------------------- 1 | /*! 
2 | * AnchorJS - v1.1.1 - 2015-05-23 3 | * https://github.com/bryanbraun/anchorjs 4 | * Copyright (c) 2015 Bryan Braun; Licensed MIT 5 | */ 6 | function AnchorJS(A){"use strict";this.options=A||{},this._applyRemainingDefaultOptions=function(A){this.options.icon=this.options.hasOwnProperty("icon")?A.icon:"",this.options.visible=this.options.hasOwnProperty("visible")?A.visible:"hover",this.options.placement=this.options.hasOwnProperty("placement")?A.placement:"right",this.options.class=this.options.hasOwnProperty("class")?A.class:""},this._applyRemainingDefaultOptions(A),this.add=function(A){var e,t,o,n,i,s,a,l,c,r,h,g,B,Q;if(this._applyRemainingDefaultOptions(this.options),A){if("string"!=typeof A)throw new Error("The selector provided to AnchorJS was invalid.")}else A="h1, h2, h3, h4, h5, h6";if(e=document.querySelectorAll(A),0===e.length)return!1;for(this._addBaselineStyles(),t=document.querySelectorAll("[id]"),o=[].map.call(t,function(A){return A.id}),i=0;i',B=document.createElement("div"),B.innerHTML=g,Q=B.childNodes,"always"===this.options.visible&&(Q[0].style.opacity="1"),""===this.options.icon&&(Q[0].style.fontFamily="anchorjs-icons",Q[0].style.fontStyle="normal",Q[0].style.fontVariant="normal",Q[0].style.fontWeight="normal"),"left"===this.options.placement?(Q[0].style.position="absolute",Q[0].style.marginLeft="-1em",Q[0].style.paddingRight="0.5em",e[i].insertBefore(Q[0],e[i].firstChild)):(Q[0].style.paddingLeft="0.375em",e[i].appendChild(Q[0]))}return this},this.remove=function(A){for(var e,t=document.querySelectorAll(A),o=0;o JPaths} 20 | import scala.collection.JavaConverters._ 21 | 22 | import org.tensorflow.{Graph => TFGraph, Session => TFSession, Output => TFOut, Tensor} 23 | import org.tensorflow.framework.GraphDef 24 | 25 | import org.tensorframes.ShapeDescription 26 | import org.tensorframes.dsl.Implicits._ 27 | 28 | 29 | /** 30 | * Utilities for buidling graphs with TensorFlow Java API 31 | * 32 | * Reference: tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java 33 | */ 34 | class GraphBuilder(g: TFGraph) { 35 | 36 | var varIdx: Long = 0L 37 | @transient private[this] var _sess: Option[TFSession] = None 38 | lazy val sess: TFSession = { 39 | if (_sess.isEmpty) { 40 | _sess = Some(new TFSession(g)) 41 | } 42 | _sess.get 43 | } 44 | 45 | def close(): Unit = { 46 | _sess.foreach(_.close()) 47 | g.close() 48 | } 49 | 50 | def op(opType: String, name: Option[String] = None)(in0: TFOut, ins: TFOut*): TFOut = { 51 | val opName = name.getOrElse(s"$opType-${varIdx += 1}") 52 | var b = g.opBuilder(opType, opName).addInput(in0) 53 | ins.foreach { in => b = b.addInput(in) } 54 | b.build().output(0) 55 | } 56 | 57 | def const[T](name: String, value: T): TFOut = { 58 | val tnsr = Tensor.create(value) 59 | g.opBuilder("Const", name) 60 | .setAttr("dtype", tnsr.dataType()) 61 | .setAttr("value", tnsr) 62 | .build().output(0) 63 | } 64 | 65 | def run(feeds: Map[String, Any], fetch: String): Tensor = { 66 | run(feeds, Seq(fetch)).head 67 | } 68 | 69 | def run(feeds: Map[String, Any], fetches: Seq[String]): Seq[Tensor] = { 70 | var runner = sess.runner() 71 | feeds.foreach { 72 | case (name, tnsr: Tensor) => 73 | runner = runner.feed(name, tnsr) 74 | case (name, value) => 75 | runner = runner.feed(name, Tensor.create(value)) 76 | } 77 | fetches.foreach { name => runner = runner.fetch(name) } 78 | runner.run().asScala 79 | } 80 | } 81 | 82 | /** 83 | * Utilities for building graphs with TensorFrames API (with DSL) 84 | * 85 | * TODO: these are taken from TensorFrames, we 
will eventually merge them 86 | */ 87 | private[tensorframes] object TestUtils { 88 | 89 | import org.tensorframes.dsl._ 90 | 91 | def buildGraph(node: Operation, nodes: Operation*): GraphDef = { 92 | buildGraph(Seq(node) ++ nodes) 93 | } 94 | 95 | def loadGraph(file: String): GraphDef = { 96 | val byteArray = Files.readAllBytes(JPaths.get(file)) 97 | GraphDef.newBuilder().mergeFrom(byteArray).build() 98 | } 99 | 100 | def analyzeGraph(nodes: Operation*): (GraphDef, Seq[GraphNodeSummary]) = { 101 | val g = buildGraph(nodes.head, nodes.tail: _*) 102 | g -> TensorFlowOps.analyzeGraphTF(g, extraInfo(nodes)) 103 | } 104 | 105 | // Implicit type conversion 106 | implicit def op2Node(op: Operation): Node = op.asInstanceOf[Node] 107 | implicit def ops2Nodes(ops: Seq[Operation]): Seq[Node] = ops.map(op2Node) 108 | 109 | private def getClosure(node: Node, treated: Map[String, Node]): Map[String, Node] = { 110 | val explored = node.parents 111 | .filterNot(n => treated.contains(n.name)) 112 | .flatMap(getClosure(_, treated + (node.name -> node))) 113 | .toMap 114 | 115 | uniqueByName(node +: (explored.values.toSeq ++ treated.values.toSeq)) 116 | } 117 | 118 | private def uniqueByName(nodes: Seq[Node]): Map[String, Node] = { 119 | nodes.groupBy(_.name).mapValues(_.head) 120 | } 121 | 122 | def buildGraph(nodes: Seq[Operation]): GraphDef = { 123 | nodes.foreach(_.freeze()) 124 | nodes.foreach(_.freeze(everything=true)) 125 | var treated: Map[String, Node] = Map.empty 126 | nodes.foreach { node => 127 | treated = getClosure(node, treated) 128 | } 129 | val b = GraphDef.newBuilder() 130 | treated.values.flatMap(_.nodes).foreach(b.addNode) 131 | b.build() 132 | } 133 | 134 | private def extraInfo(fetches: Seq[Node]): ShapeDescription = { 135 | val m2 = fetches.map(n => n.name -> n.name).toMap 136 | ShapeDescription( 137 | fetches.map(n => n.name -> n.shape).toMap, 138 | fetches.map(_.name), 139 | m2) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /docs/_layouts/404.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Page Not Found :( 6 | 141 | 142 | 143 |
      144 |       Not found :(
      145 |       Sorry, but the page you were trying to view does not exist.
      146 |       It looks like this was the result of either:
      147 |
      148 |       • a mistyped address
      149 |       • an out-of-date link
      150 |
      151 | 154 | 155 |
      156 | 157 | 158 | -------------------------------------------------------------------------------- /python/tests/graph/test_builder.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from __future__ import print_function 18 | 19 | from glob import glob 20 | import os 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | import keras.backend as K 25 | from keras.applications import InceptionV3 26 | from keras.applications import inception_v3 as iv3 27 | from keras.preprocessing.image import load_img, img_to_array 28 | 29 | from pyspark import SparkContext 30 | from pyspark.sql import DataFrame, Row 31 | from pyspark.sql.functions import udf 32 | 33 | from sparkdl.graph.builder import IsolatedSession, GraphFunction 34 | import sparkdl.graph.utils as tfx 35 | 36 | from ..tests import SparkDLTestCase 37 | from ..transformers.image_utils import _getSampleJPEGDir, getSampleImagePathsDF 38 | 39 | 40 | class GraphFunctionWithIsolatedSessionTest(SparkDLTestCase): 41 | 42 | def test_tf_consistency(self): 43 | """ Should get the same graph as running pure tf """ 44 | 45 | x_val = 2702.142857 46 | g = tf.Graph() 47 | with tf.Session(graph=g) as sess: 48 | x = tf.placeholder(tf.double, shape=[], name="x") 49 | z = tf.add(x, 3, name='z') 50 | gdef_ref = g.as_graph_def(add_shapes=True) 51 | z_ref = sess.run(z, {x: x_val}) 52 | 53 | with IsolatedSession() as issn: 54 | x = tf.placeholder(tf.double, shape=[], name="x") 55 | z = tf.add(x, 3, name='z') 56 | gfn = issn.asGraphFunction([x], [z]) 57 | z_tgt = issn.run(z, {x: x_val}) 58 | 59 | self.assertEqual(z_ref, z_tgt) 60 | 61 | # Version texts are not essential part of the graph, ignore them 62 | gdef_ref.ClearField("versions") 63 | gfn.graph_def.ClearField("versions") 64 | 65 | # The GraphDef contained in the GraphFunction object 66 | # should be the same as that in the one exported directly from TensorFlow session 67 | self.assertEqual(str(gfn.graph_def), str(gdef_ref)) 68 | 69 | def test_get_graph_elements(self): 70 | """ Fetching graph elements by names and other graph elements """ 71 | 72 | with IsolatedSession() as issn: 73 | x = tf.placeholder(tf.double, shape=[], name="x") 74 | z = tf.add(x, 3, name='z') 75 | 76 | g = issn.graph 77 | self.assertEqual(tfx.get_tensor(g, z), z) 78 | self.assertEqual(tfx.get_tensor(g, x), x) 79 | self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(g, x)) 80 | self.assertEqual("x:0", tfx.tensor_name(g, x)) 81 | self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(g, x)) 82 | self.assertEqual("x", tfx.op_name(g, x)) 83 | self.assertEqual("z", tfx.op_name(g, z)) 84 | self.assertEqual(tfx.tensor_name(g, z), "z:0") 85 | self.assertEqual(tfx.tensor_name(g, x), "x:0") 86 | 87 | def test_import_export_graph_function(self): 88 | """ Function import and export must be consistent """ 89 | 90 | with IsolatedSession() as issn: 
91 | x = tf.placeholder(tf.double, shape=[], name="x") 92 | z = tf.add(x, 3, name='z') 93 | gfn_ref = issn.asGraphFunction([x], [z]) 94 | 95 | with IsolatedSession() as issn: 96 | feeds, fetches = issn.importGraphFunction(gfn_ref, prefix="") 97 | gfn_tgt = issn.asGraphFunction(feeds, fetches) 98 | 99 | self.assertEqual(gfn_tgt.input_names, gfn_ref.input_names) 100 | self.assertEqual(gfn_tgt.output_names, gfn_ref.output_names) 101 | self.assertEqual(str(gfn_tgt.graph_def), str(gfn_ref.graph_def)) 102 | 103 | 104 | def test_keras_consistency(self): 105 | """ Exported model in Keras should get same result as original """ 106 | 107 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) 108 | 109 | def keras_load_and_preproc(fpath): 110 | img = load_img(fpath, target_size=(299, 299)) 111 | img_arr = img_to_array(img) 112 | img_iv3_input = iv3.preprocess_input(img_arr) 113 | return np.expand_dims(img_iv3_input, axis=0) 114 | 115 | imgs_iv3_input = np.vstack([keras_load_and_preproc(fp) for fp in img_fpaths]) 116 | 117 | model_ref = InceptionV3(weights="imagenet") 118 | preds_ref = model_ref.predict(imgs_iv3_input) 119 | 120 | with IsolatedSession(using_keras=True) as issn: 121 | K.set_learning_phase(0) 122 | model = InceptionV3(weights="imagenet") 123 | gfn = issn.asGraphFunction(model.inputs, model.outputs) 124 | 125 | with IsolatedSession(using_keras=True) as issn: 126 | K.set_learning_phase(0) 127 | feeds, fetches = issn.importGraphFunction(gfn, prefix="InceptionV3") 128 | preds_tgt = issn.run(fetches[0], {feeds[0]: imgs_iv3_input}) 129 | 130 | self.assertTrue(np.all(preds_tgt == preds_ref)) 131 | -------------------------------------------------------------------------------- /python/tests/transformers/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | import os 17 | from glob import glob 18 | import tempfile 19 | import unittest 20 | from warnings import warn 21 | 22 | from keras.applications import InceptionV3 23 | from keras.applications.inception_v3 import preprocess_input, decode_predictions 24 | from keras.preprocessing.image import img_to_array, load_img 25 | import keras.backend as K 26 | import numpy as np 27 | import PIL.Image 28 | 29 | from pyspark.sql.types import StringType 30 | 31 | from sparkdl.image import imageIO 32 | from sparkdl.transformers.utils import ImageNetConstants, InceptionV3Constants 33 | 34 | 35 | # Methods for getting some test data to work with. 
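# Illustrative usage (a minimal sketch; assumes the `sqlContext` provided by the
# enclosing SparkDLTestCase, as in the other test modules):
#
#     image_df = getSampleImageDF()                             # decoded sample images
#     paths_df = getSampleImagePathsDF(sqlContext, "filePath")  # DataFrame of file URIs
#     values, topK = executeKerasInceptionV3(paths_df, uri_col="filePath")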
36 | 37 | def _getSampleJPEGDir(): 38 | cur_dir = os.path.dirname(__file__) 39 | return os.path.join(cur_dir, "../resources/images") 40 | 41 | 42 | def getSampleImageDF(): 43 | return imageIO.readImages(_getSampleJPEGDir()) 44 | 45 | def getSampleImagePathsDF(sqlContext, colName): 46 | dirpath = _getSampleJPEGDir() 47 | files = [os.path.abspath(os.path.join(dirpath, f)) for f in os.listdir(dirpath) 48 | if f.endswith('.jpg')] 49 | return sqlContext.createDataFrame(files, StringType()).toDF(colName) 50 | 51 | 52 | # Methods for making comparisons between outputs of using different frameworks. 53 | # For ImageNet. 54 | 55 | class ImageNetOutputComparisonTestCase(unittest.TestCase): 56 | 57 | def transformOutputToComparables(self, collected, uri_col, output_col): 58 | values = {} 59 | topK = {} 60 | for row in collected: 61 | uri = row[uri_col] 62 | predictions = row[output_col] 63 | self.assertEqual(len(predictions), ImageNetConstants.NUM_CLASSES) 64 | 65 | values[uri] = np.expand_dims(predictions, axis=0) 66 | topK[uri] = decode_predictions(values[uri], top=5)[0] 67 | return values, topK 68 | 69 | def compareArrays(self, values1, values2): 70 | """ 71 | values1 & values2 are {key => numpy array}. 72 | """ 73 | for k, v1 in values1.items(): 74 | v1f = v1.astype(np.float32) 75 | v2f = values2[k].astype(np.float32) 76 | np.testing.assert_array_equal(v1f, v2f) 77 | 78 | def compareClassOrderings(self, preds1, preds2): 79 | """ 80 | preds1 & preds2 are {key => (class, description, probability)}. 81 | """ 82 | for k, v1 in preds1.items(): 83 | self.assertEqual([v[1] for v in v1], [v[1] for v in preds2[k]]) 84 | 85 | def compareClassSets(self, preds1, preds2): 86 | """ 87 | values1 & values2 are {key => numpy array}. 88 | """ 89 | for k, v1 in preds1.items(): 90 | self.assertEqual(set([v[1] for v in v1]), set([v[1] for v in preds2[k]])) 91 | 92 | 93 | def getSampleImageList(): 94 | imageFiles = glob(os.path.join(_getSampleJPEGDir(), "*")) 95 | images = [] 96 | for f in imageFiles: 97 | try: 98 | img = PIL.Image.open(f) 99 | except IOError: 100 | warn("Could not read file in image directory.") 101 | images.append(None) 102 | else: 103 | images.append(img) 104 | return imageFiles, images 105 | 106 | 107 | def executeKerasInceptionV3(image_df, uri_col="filePath"): 108 | """ 109 | Apply Keras InceptionV3 Model on input DataFrame. 110 | :param image_df: Dataset. contains a column (uri_col) for where the image file lives. 111 | :param uri_col: str. name of the column indicating where each row's image file lives. 112 | :return: ({str => np.array[float]}, {str => (str, str, float)}). 113 | image file uri to prediction probability array, 114 | image file uri to top K predictions (class id, class description, probability). 
115 | """ 116 | K.set_learning_phase(0) 117 | model = InceptionV3(weights="imagenet") 118 | 119 | values = {} 120 | topK = {} 121 | for row in image_df.select(uri_col).collect(): 122 | raw_uri = row[uri_col] 123 | image = loadAndPreprocessKerasInceptionV3(raw_uri) 124 | values[raw_uri] = model.predict(image) 125 | topK[raw_uri] = decode_predictions(values[raw_uri], top=5)[0] 126 | return values, topK 127 | 128 | def loadAndPreprocessKerasInceptionV3(raw_uri): 129 | # this is the canonical way to load and prep images in keras 130 | uri = raw_uri[5:] if raw_uri.startswith("file:/") else raw_uri 131 | image = img_to_array(load_img(uri, target_size=InceptionV3Constants.INPUT_SHAPE)) 132 | image = np.expand_dims(image, axis=0) 133 | return preprocess_input(image) 134 | 135 | def prepInceptionV3KerasModelFile(fileName): 136 | model_dir_tmp = tempfile.mkdtemp("sparkdl_keras_tests", dir="/tmp") 137 | path = model_dir_tmp + "/" + fileName 138 | 139 | height, width = InceptionV3Constants.INPUT_SHAPE 140 | input_shape = (height, width, 3) 141 | model = InceptionV3(weights="imagenet", include_top=True, input_shape=input_shape) 142 | model.save(path) 143 | return path 144 | -------------------------------------------------------------------------------- /python/sparkdl/transformers/keras_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | import keras.backend as K 17 | from keras.models import load_model 18 | 19 | from pyspark.ml import Transformer 20 | from pyspark.ml.param import Param, Params, TypeConverters 21 | from pyspark.sql.functions import udf 22 | 23 | import sparkdl.graph.utils as tfx 24 | from sparkdl.image import imageIO 25 | from sparkdl.transformers.keras_utils import KSessionWrap 26 | from sparkdl.transformers.param import ( 27 | keyword_only, HasInputCol, HasOutputCol, SparkDLTypeConverters) 28 | from sparkdl.transformers.tf_image import TFImageTransformer, OUTPUT_MODES 29 | import sparkdl.transformers.utils as utils 30 | 31 | 32 | class KerasImageFileTransformer(Transformer, HasInputCol, HasOutputCol): 33 | """ 34 | Applies the Tensorflow-backed Keras model (specified by a file name) to 35 | images (specified by the URI in the inputCol column) in the DataFrame. 36 | 37 | Restrictions of the current API: 38 | * see TFImageTransformer. 39 | * Only supports Tensorflow-backed Keras models (no Theano). 40 | """ 41 | 42 | modelFile = Param(Params._dummy(), "modelFile", 43 | "h5py file containing the Keras model (architecture and weights)", 44 | typeConverter=TypeConverters.toString) 45 | # TODO :add a lambda type converter e.g callable(mylambda) 46 | imageLoader = Param(Params._dummy(), "imageLoader", 47 | "Function containing the logic for loading and pre-processing images. 
" + 48 | "The function should take in a URI string and return a 4-d numpy.array " + 49 | "with shape (batch_size (1), height, width, num_channels).") 50 | outputMode = Param(Params._dummy(), "outputMode", 51 | "How the output column should be formatted. 'vector' for a 1-d MLlib " + 52 | "Vector of floats. 'image' to format the output to work with the image " + 53 | "tools in this package.", 54 | typeConverter=SparkDLTypeConverters.supportedNameConverter(OUTPUT_MODES)) 55 | 56 | @keyword_only 57 | def __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None, 58 | outputMode="vector"): 59 | """ 60 | __init__(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None, 61 | outputMode="vector") 62 | """ 63 | super(KerasImageFileTransformer, self).__init__() 64 | kwargs = self._input_kwargs 65 | self.setParams(**kwargs) 66 | self._inputTensor = None 67 | self._outputTensor = None 68 | 69 | @keyword_only 70 | def setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None, 71 | outputMode="vector"): 72 | """ 73 | setParams(self, inputCol=None, outputCol=None, modelFile=None, imageLoader=None, 74 | outputMode="vector") 75 | """ 76 | kwargs = self._input_kwargs 77 | self._set(**kwargs) 78 | return self 79 | 80 | def setModelFile(self, value): 81 | return self._set(modelFile=value) 82 | 83 | def getModelFile(self): 84 | return self.getOrDefault(self.modelFile) 85 | 86 | def _transform(self, dataset): 87 | graph = self._loadTFGraph() 88 | image_df = self._loadImages(dataset) 89 | 90 | assert self._inputTensor is not None, "self._inputTensor must be set" 91 | assert self._outputTensor is not None, "self._outputTensor must be set" 92 | 93 | transformer = TFImageTransformer(inputCol=self._loadedImageCol(), 94 | outputCol=self.getOutputCol(), graph=graph, 95 | inputTensor=self._inputTensor, 96 | outputTensor=self._outputTensor, 97 | outputMode=self.getOrDefault(self.outputMode)) 98 | return transformer.transform(image_df).drop(self._loadedImageCol()) 99 | 100 | def _loadTFGraph(self): 101 | with KSessionWrap() as (sess, g): 102 | assert K.backend() == "tensorflow", \ 103 | "Keras backend is not tensorflow but KerasImageTransformer only supports " + \ 104 | "tensorflow-backed Keras models." 105 | with g.as_default(): 106 | K.set_learning_phase(0) # Testing phase 107 | model = load_model(self.getModelFile()) 108 | out_op_name = tfx.op_name(g, model.output) 109 | self._inputTensor = model.input.name 110 | self._outputTensor = model.output.name 111 | return tfx.strip_and_freeze_until([out_op_name], g, sess, return_graph=True) 112 | 113 | def _loadedImageCol(self): 114 | return "__sdl_img" 115 | 116 | def _loadImages(self, dataset): 117 | """ 118 | Load image files specified in dataset as image format specified in sparkdl.image.imageIO. 119 | """ 120 | # plan 1: udf(loader() + convert from np.array to imageSchema) -> call TFImageTransformer 121 | # plan 2: udf(loader()) ... we don't support np.array as a dataframe column type... 
122 | loader = self.getOrDefault(self.imageLoader) 123 | 124 | def load(uri): 125 | img = loader(uri) 126 | return imageIO.imageArrayToStruct(img) 127 | load_udf = udf(load, imageIO.imageSchema) 128 | return dataset.withColumn(self._loadedImageCol(), load_udf(dataset[self.getInputCol()])) 129 | -------------------------------------------------------------------------------- /docs/_layouts/global.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | {{ page.title }} - Deep Learning Pipelines {{site.SPARKDL_VERSION}} Documentation 10 | {% if page.description %} 11 | 12 | {% endif %} 13 | 14 | {% if page.redirect %} 15 | 16 | 17 | {% endif %} 18 | 19 | 20 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | {% production %} 35 | 36 | 47 | {% endproduction %} 48 | 49 | 50 | 51 | 54 | 55 | 56 | 57 | 86 | 87 |
      88 |       {% if page.displayTitle %}
      89 |         {{ page.displayTitle }}
      90 |       {% else %}
      91 |         {{ page.title }}
      92 |       {% endif %}
      93 |
      94 |       {{ content }}
      95 | 96 |
      97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 109 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /src/main/scala/com/databricks/sparkdl/python/ModelFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2017 Databricks, Inc. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | package com.databricks.sparkdl.python 17 | 18 | import java.nio.file.{Files, Paths} 19 | import java.util 20 | 21 | import scala.collection.JavaConverters._ 22 | 23 | import org.apache.log4j.PropertyConfigurator 24 | 25 | import org.apache.spark.annotation.DeveloperApi 26 | import org.apache.spark.sql.SQLContext 27 | import org.apache.spark.sql.sparkdl_stubs.UDFUtils 28 | import org.apache.spark.sql.expressions.UserDefinedFunction 29 | import org.tensorflow.framework.GraphDef 30 | import org.tensorframes.{Shape, ShapeDescription} 31 | import org.tensorframes.impl.{SerializedGraph, SqlOps, TensorFlowOps} 32 | 33 | import com.databricks.sparkdl.Logging 34 | 35 | /** 36 | * Taking over TensorFlow graphs handed from Python 37 | * validate the graph 38 | */ 39 | // TODO: merge TensorFlow graphs into TensorFrames eventually 40 | @DeveloperApi 41 | class GraphModelFactory() extends Logging { 42 | private var _sqlCtx: SQLContext = null 43 | 44 | private var _shapeHints: ShapeDescription = ShapeDescription.empty 45 | // WARNING: this object may leak because of Py4J -> do not hold to large objects here. 46 | private var _graph: SerializedGraph = null 47 | private var _graphPath: Option[String] = None 48 | 49 | def initializeLogging(): Unit = initializeLogging("org/tensorframes/log4j.properties") 50 | 51 | /** 52 | * Performs some logging initialization before spark has the time to do it. 53 | * 54 | * Because of the the current implementation of PySpark, Spark thinks it runs as an interactive 55 | * console and makes some mistake when setting up log4j. 56 | */ 57 | private def initializeLogging(file: String): Unit = { 58 | Option(this.getClass.getClassLoader.getResource(file)) match { 59 | case Some(url) => 60 | PropertyConfigurator.configure(url) 61 | case None => 62 | System.err.println(s"$this Could not load logging file $file") 63 | } 64 | } 65 | 66 | /** Setup SQLContext for UDF registeration */ 67 | def sqlContext(ctx: SQLContext): this.type = { 68 | _sqlCtx = ctx 69 | this 70 | } 71 | 72 | /** 73 | * Append shape information to the graph 74 | * @param shapeHintsNames names of graph elements 75 | * @param shapeHintsShapes the corresponding shape of the named graph elements 76 | */ 77 | def shape( 78 | shapeHintsNames: util.ArrayList[String], 79 | shapeHintShapes: util.ArrayList[util.ArrayList[Int]]): this.type = { 80 | val s = shapeHintShapes.asScala.map(_.asScala.toSeq).map(x => Shape(x: _*)) 81 | _shapeHints = _shapeHints.copy(out = shapeHintsNames.asScala.zip(s).toMap) 82 | this 83 | } 84 | 85 | /** 86 | * Fetches (i.e. 
graph elements intended as output) of the graph 87 | * @param fetchNames a list of graph element names indicating the fetches 88 | */ 89 | def fetches(fetchNames: util.ArrayList[String]): this.type = { 90 | _shapeHints = _shapeHints.copy(requestedFetches = fetchNames.asScala) 91 | this 92 | } 93 | 94 | /** 95 | * Create TensorFlow graph from serialzied GraphDef 96 | * @param bytes the serialzied GraphDef 97 | */ 98 | def graph(bytes: Array[Byte]): this.type = { 99 | _graph = SerializedGraph.create(bytes) 100 | this 101 | } 102 | 103 | /** 104 | * Attach graph definition file 105 | * @param filename path to the serialized graph 106 | */ 107 | def graphFromFile(filename: String): this.type = { 108 | _graphPath = Option(filename) 109 | this 110 | } 111 | 112 | /** 113 | * Specify struct field names and corresponding tf.placeholder paths in the Graph 114 | * @param placeholderPaths tf.placeholder paths in the Graph 115 | * @param fieldNames struct field names 116 | */ 117 | def inputs( 118 | placeholderPaths: util.ArrayList[String], 119 | fieldNames: util.ArrayList[String]): this.type = { 120 | val feedPaths = placeholderPaths.asScala 121 | val fields = fieldNames.asScala 122 | require(feedPaths.size == fields.size, 123 | s"placeholder paths and field names must match ${(feedPaths, fields)}") 124 | val feedMap = feedPaths.zip(fields).toMap 125 | _shapeHints = _shapeHints.copy(inputs = feedMap) 126 | this 127 | } 128 | 129 | /** 130 | * Builds a java UDF based on the following input. 131 | * @param udfName the name of the udf 132 | * @param applyBlocks whether the function should be applied per row or a block of rows 133 | * @return UDF 134 | */ 135 | def makeUDF(udfName: String, applyBlocks: Boolean): UserDefinedFunction = { 136 | SqlOps.makeUDF(udfName, buildGraphDef(), _shapeHints, 137 | applyBlocks = applyBlocks, flattenStruct = true) 138 | } 139 | 140 | /** 141 | * Builds a java UDF based on the following input. 142 | * @param udfName the name of the udf 143 | * @param applyBlocks whether the function should be applied per row or a block of rows 144 | * @param flattenstruct whether the returned tensor struct should be flattened to vector 145 | * @return UDF 146 | */ 147 | def makeUDF(udfName: String, applyBlocks: Boolean, flattenStruct: Boolean): UserDefinedFunction = { 148 | SqlOps.makeUDF(udfName, buildGraphDef(), _shapeHints, 149 | applyBlocks = applyBlocks, flattenStruct = flattenStruct) 150 | } 151 | 152 | /** 153 | * Registers a TF UDF under the given name in Spark. 154 | * @param udfName the name of the UDF 155 | * @param blocked indicates that the UDF should be applied block-wise. 
156 | * @return UDF 157 | */ 158 | def registerUDF(udfName: String, blocked: java.lang.Boolean): UserDefinedFunction = { 159 | assert(_sqlCtx != null) 160 | val udf = makeUDF(udfName, blocked) 161 | logger.warn(s"Registering udf $udfName -> $udf to session ${_sqlCtx.sparkSession}") 162 | UDFUtils.registerUDF(_sqlCtx, udfName, udf) 163 | } 164 | 165 | /** 166 | * Create a TensorFlow GraphDef object from the input graph 167 | * Or load the serialized graph bytes from file 168 | */ 169 | private def buildGraphDef(): GraphDef = { 170 | _graphPath match { 171 | case Some(p) => 172 | val path = Paths.get(p) 173 | val bytes = Files.readAllBytes(path) 174 | TensorFlowOps.readGraphSerial(SerializedGraph.create(bytes)) 175 | case None => 176 | assert(_graph != null) 177 | TensorFlowOps.readGraphSerial(_graph) 178 | } 179 | } 180 | 181 | } 182 | -------------------------------------------------------------------------------- /python/tests/image/test_imageIO.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from io import BytesIO 17 | 18 | # 3rd party 19 | import numpy as np 20 | import PIL.Image 21 | 22 | # pyspark 23 | from pyspark.sql.functions import col, udf 24 | from pyspark.sql.types import BinaryType, StringType, StructField, StructType 25 | 26 | from sparkdl.image import imageIO 27 | from ..tests import SparkDLTestCase 28 | 29 | # Create dome fake image data to work with 30 | def create_image_data(): 31 | # Random image-like data 32 | array = np.random.randint(0, 256, (10, 11, 3), 'uint8') 33 | 34 | # Compress as png 35 | imgFile = BytesIO() 36 | PIL.Image.fromarray(array).save(imgFile, 'png') 37 | imgFile.seek(0) 38 | 39 | # Get Png data as stream 40 | pngData = imgFile.read() 41 | return array, pngData 42 | 43 | array, pngData = create_image_data() 44 | 45 | 46 | class BinaryFilesMock(object): 47 | 48 | defaultParallelism = 4 49 | 50 | def __init__(self, sc): 51 | self.sc = sc 52 | 53 | def binaryFiles(self, path, minPartitions=None): 54 | imagesData = [["file/path", pngData], 55 | ["another/file/path", pngData], 56 | ["bad/image", b"badImageData"] 57 | ] 58 | rdd = self.sc.parallelize(imagesData) 59 | if minPartitions is not None: 60 | rdd = rdd.repartition(minPartitions) 61 | return rdd 62 | 63 | 64 | class TestReadImages(SparkDLTestCase): 65 | @classmethod 66 | def setUpClass(cls): 67 | super(TestReadImages, cls).setUpClass() 68 | cls.binaryFilesMock = BinaryFilesMock(cls.sc) 69 | 70 | @classmethod 71 | def tearDownClass(cls): 72 | super(TestReadImages, cls).tearDownClass() 73 | cls.binaryFilesMock = None 74 | 75 | def test_decodeImage(self): 76 | badImg = imageIO._decodeImage(b"xxx") 77 | self.assertIsNone(badImg) 78 | imgRow = imageIO._decodeImage(pngData) 79 | self.assertIsNotNone(imgRow) 80 | self.assertEqual(len(imgRow), len(imageIO.imageSchema.names)) 81 | for n in imageIO.imageSchema.names: 82 | imgRow[n] 83 | 
84 | def test_resize(self): 85 | imgAsRow = imageIO.imageArrayToStruct(array) 86 | smaller = imageIO._resizeFunction([4, 5]) 87 | smallerImg = smaller(imgAsRow) 88 | for n in imageIO.imageSchema.names: 89 | smallerImg[n] 90 | self.assertEqual(smallerImg.height, 4) 91 | self.assertEqual(smallerImg.width, 5) 92 | 93 | sameImage = imageIO._resizeFunction([imgAsRow.height, imgAsRow.width])(imgAsRow) 94 | self.assertEqual(sameImage, sameImage) 95 | 96 | self.assertRaises(ValueError, imageIO._resizeFunction, [1, 2, 3]) 97 | 98 | def test_imageArrayToStruct(self): 99 | SparkMode = imageIO.SparkMode 100 | # Check converting with matching types 101 | height, width, chan = array.shape 102 | imgAsStruct = imageIO.imageArrayToStruct(array) 103 | self.assertEqual(imgAsStruct.height, height) 104 | self.assertEqual(imgAsStruct.width, width) 105 | self.assertEqual(imgAsStruct.data, array.tobytes()) 106 | 107 | # Check casting 108 | imgAsStruct = imageIO.imageArrayToStruct(array, SparkMode.RGB_FLOAT32) 109 | self.assertEqual(imgAsStruct.height, height) 110 | self.assertEqual(imgAsStruct.width, width) 111 | self.assertEqual(len(imgAsStruct.data), array.size * 4) 112 | 113 | # Check channel mismatch 114 | self.assertRaises(ValueError, imageIO.imageArrayToStruct, array, SparkMode.FLOAT32) 115 | 116 | # Check that unsafe cast raises error 117 | floatArray = np.zeros((3, 4, 3), dtype='float32') 118 | self.assertRaises(ValueError, imageIO.imageArrayToStruct, floatArray, SparkMode.RGB) 119 | 120 | def test_image_round_trip(self): 121 | # Test round trip: array -> png -> sparkImg -> array 122 | binarySchema = StructType([StructField("data", BinaryType(), False)]) 123 | df = self.session.createDataFrame([[bytearray(pngData)]], binarySchema) 124 | 125 | # Convert to images 126 | decImg = udf(imageIO._decodeImage, imageIO.imageSchema) 127 | imageDF = df.select(decImg("data").alias("image")) 128 | row = imageDF.first() 129 | 130 | testArray = imageIO.imageStructToArray(row.image) 131 | self.assertEqual(testArray.shape, array.shape) 132 | self.assertEqual(testArray.dtype, array.dtype) 133 | self.assertTrue(np.all(array == testArray)) 134 | 135 | def test_readImages(self): 136 | # Test that reading 137 | imageDF = imageIO._readImages("some/path", 2, self.binaryFilesMock) 138 | self.assertTrue("image" in imageDF.schema.names) 139 | self.assertTrue("filePath" in imageDF.schema.names) 140 | 141 | # The DF should have 2 images and 1 null. 
142 | self.assertEqual(imageDF.count(), 3) 143 | validImages = imageDF.filter(col("image").isNotNull()) 144 | self.assertEqual(validImages.count(), 2) 145 | 146 | img = validImages.first().image 147 | self.assertEqual(img.height, array.shape[0]) 148 | self.assertEqual(img.width, array.shape[1]) 149 | self.assertEqual(imageIO.imageType(img).nChannels, array.shape[2]) 150 | self.assertEqual(img.data, array.tobytes()) 151 | 152 | def test_udf_schema(self): 153 | # Test that utility functions can be used to create a udf that accepts and return 154 | # imageSchema 155 | def do_nothing(imgRow): 156 | imType = imageIO.imageType(imgRow) 157 | array = imageIO.imageStructToArray(imgRow) 158 | return imageIO.imageArrayToStruct(array, imType.sparkMode) 159 | do_nothing_udf = udf(do_nothing, imageIO.imageSchema) 160 | 161 | df = imageIO._readImages("path", 2, self.binaryFilesMock) 162 | df = df.filter(col('image').isNotNull()).withColumn("test", do_nothing_udf('image')) 163 | self.assertEqual(df.first().test.data, array.tobytes()) 164 | df.printSchema() 165 | 166 | def test_filesTODF(self): 167 | df = imageIO.filesToDF(self.binaryFilesMock, "path", 217) 168 | self.assertEqual(df.rdd.getNumPartitions(), 217) 169 | df.schema.fields[0].dataType == StringType() 170 | df.schema.fields[0].dataType == BinaryType() 171 | first = df.first() 172 | self.assertTrue(hasattr(first, "filePath")) 173 | self.assertEqual(type(first.fileData), bytearray) 174 | 175 | 176 | # TODO: make unit tests for arrayToImageRow on arrays of varying shapes, channels, dtypes. 177 | -------------------------------------------------------------------------------- /python/sparkdl/graph/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | import logging 18 | import six 19 | import webbrowser 20 | from tempfile import NamedTemporaryFile 21 | 22 | import tensorflow as tf 23 | 24 | logger = logging.getLogger('sparkdl') 25 | 26 | """ 27 | When working with various pieces of TensorFlow, one is faced with 28 | figuring out providing one of the four variants 29 | (`tensor` OR `operation`, `name` OR `graph element`). 30 | 31 | The various combination makes it hard to figuring out the best way. 32 | We provide some methods to map whatever we have as input to 33 | one of the four target variants. 
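For example (a minimal sketch, mirroring the behaviour exercised in the unit tests):
given a graph ``g`` containing a placeholder created as
``tf.placeholder(tf.float32, shape=[], name='x')``, the helpers below map between
the four variants as follows::

    get_op(g, 'x:0').name                 # -> 'x'   (tensor name to operation)
    get_tensor(g, 'x').name               # -> 'x:0' (op name to tensor)
    op_name(g, get_tensor(g, 'x:0'))      # -> 'x'
    tensor_name(g, get_op(g, 'x'))        # -> 'x:0'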
34 | """ 35 | 36 | def validated_graph(graph): 37 | """ 38 | Check if the input is a valid tf.Graph 39 | 40 | :param graph: tf.Graph, a TensorFlow Graph object 41 | """ 42 | assert isinstance(graph, tf.Graph), 'must provide tf.Graph, but get {}'.format(type(graph)) 43 | return graph 44 | 45 | def get_shape(graph, tfobj_or_name): 46 | """ 47 | Return the shape of the tensor as a list 48 | 49 | :param graph: tf.Graph, a TensorFlow Graph object 50 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 51 | """ 52 | graph = validated_graph(graph) 53 | _shape = get_tensor(graph, tfobj_or_name).get_shape().as_list() 54 | return [-1 if x is None else x for x in _shape] 55 | 56 | def get_op(graph, tfobj_or_name): 57 | """ 58 | Get a tf.Operation object 59 | 60 | :param graph: tf.Graph, a TensorFlow Graph object 61 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 62 | """ 63 | graph = validated_graph(graph) 64 | if isinstance(tfobj_or_name, tf.Operation): 65 | return tfobj_or_name 66 | name = tfobj_or_name 67 | if isinstance(tfobj_or_name, tf.Tensor): 68 | name = tfobj_or_name.name 69 | if not isinstance(name, six.string_types): 70 | raise TypeError('invalid op request for {} of {}'.format(name, type(name))) 71 | _op_name = as_op_name(name) 72 | op = graph.get_operation_by_name(_op_name) 73 | assert op is not None, \ 74 | 'cannot locate op {} in current graph'.format(_op_name) 75 | return op 76 | 77 | def get_tensor(graph, tfobj_or_name): 78 | """ 79 | Get a tf.Tensor object 80 | 81 | :param graph: tf.Graph, a TensorFlow Graph object 82 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 83 | """ 84 | graph = validated_graph(graph) 85 | if isinstance(tfobj_or_name, tf.Tensor): 86 | return tfobj_or_name 87 | name = tfobj_or_name 88 | if isinstance(tfobj_or_name, tf.Operation): 89 | name = tfobj_or_name.name 90 | if not isinstance(name, six.string_types): 91 | raise TypeError('invalid tensor request for {} of {}'.format(name, type(name))) 92 | _tensor_name = as_tensor_name(name) 93 | tnsr = graph.get_tensor_by_name(_tensor_name) 94 | assert tnsr is not None, \ 95 | 'cannot locate tensor {} in current graph'.format(_tensor_name) 96 | return tnsr 97 | 98 | def as_tensor_name(name): 99 | """ 100 | Derive tf.Tensor name from an op/tensor name. 101 | We do not check if the tensor exist (as no graph parameter is passed in). 102 | 103 | :param name: op name or tensor name 104 | """ 105 | assert isinstance(name, six.string_types) 106 | name_parts = name.split(":") 107 | assert len(name_parts) <= 2, name_parts 108 | if len(name_parts) < 2: 109 | name += ":0" 110 | return name 111 | 112 | def as_op_name(name): 113 | """ 114 | Derive tf.Operation name from an op/tensor name 115 | We do not check if the operation exist (as no graph parameter is passed in). 
116 | 117 | :param name: op name or tensor name 118 | """ 119 | assert isinstance(name, six.string_types) 120 | name_parts = name.split(":") 121 | assert len(name_parts) <= 2, name_parts 122 | return name_parts[0] 123 | 124 | def op_name(graph, tfobj_or_name): 125 | """ 126 | Get the name of a tf.Operation 127 | 128 | :param graph: tf.Graph, a TensorFlow Graph object 129 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 130 | """ 131 | graph = validated_graph(graph) 132 | return get_op(graph, tfobj_or_name).name 133 | 134 | def tensor_name(graph, tfobj_or_name): 135 | """ 136 | Get the name of a tf.Tensor 137 | 138 | :param graph: tf.Graph, a TensorFlow Graph object 139 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 140 | """ 141 | graph = validated_graph(graph) 142 | return get_tensor(graph, tfobj_or_name).name 143 | 144 | def validated_output(graph, tfobj_or_name): 145 | """ 146 | Validate and return the output names useable GraphFunction 147 | 148 | :param graph: tf.Graph, a TensorFlow Graph object 149 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 150 | """ 151 | graph = validated_graph(graph) 152 | return op_name(graph, tfobj_or_name) 153 | 154 | def validated_input(graph, tfobj_or_name): 155 | """ 156 | Validate and return the input names useable GraphFunction 157 | 158 | :param graph: tf.Graph, a TensorFlow Graph object 159 | :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either 160 | """ 161 | graph = validated_graph(graph) 162 | name = op_name(graph, tfobj_or_name) 163 | op = graph.get_operation_by_name(name) 164 | assert 'Placeholder' == op.type, \ 165 | ('input must be Placeholder, but get', op.type) 166 | return name 167 | 168 | def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False): 169 | """ 170 | Create a static view of the graph by 171 | 1. Converting all variables into constants 172 | 2. Removing graph elements not reachacble to `fetches` 173 | 174 | :param graph: tf.Graph, the graph to be frozen 175 | :param fetches: list, graph elements representing the outputs of the graph 176 | :param return_graph: bool, if set True, return the graph function object 177 | :return: GraphDef, the GraphDef object with cleanup procedure applied 178 | """ 179 | graph = validated_graph(graph) 180 | should_close_session = False 181 | if not sess: 182 | sess = tf.Session(graph=graph) 183 | should_close_session = True 184 | 185 | gdef_frozen = tf.graph_util.convert_variables_to_constants( 186 | sess, 187 | graph.as_graph_def(add_shapes=True), 188 | [op_name(graph, tnsr) for tnsr in fetches]) 189 | 190 | if should_close_session: 191 | sess.close() 192 | 193 | if return_graph: 194 | g = tf.Graph() 195 | with g.as_default(): 196 | tf.import_graph_def(gdef_frozen, name='') 197 | return g 198 | else: 199 | return gdef_frozen 200 | -------------------------------------------------------------------------------- /python/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 
5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | export PACKAGE_VERSION 11 | 12 | ifndef PYTHONPATH 13 | $(error PYTHONPATH is undefined) 14 | endif 15 | $(info $$PYTHONPATH is [${PYTHONPATH}]) 16 | 17 | # User-friendly check for sphinx-build 18 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 19 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 20 | endif 21 | 22 | # Internal variables. 23 | PAPEROPT_a4 = -D latex_paper_size=a4 24 | PAPEROPT_letter = -D latex_paper_size=letter 25 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 26 | # the i18n builder cannot share the environment and doctrees with the others 27 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 28 | 29 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 30 | 31 | help: 32 | @echo "Please use \`make ' where is one of" 33 | @echo " html to make standalone HTML files" 34 | @echo " dirhtml to make HTML files named index.html in directories" 35 | @echo " singlehtml to make a single large HTML file" 36 | @echo " pickle to make pickle files" 37 | @echo " json to make JSON files" 38 | @echo " htmlhelp to make HTML files and a HTML help project" 39 | @echo " qthelp to make HTML files and a qthelp project" 40 | @echo " devhelp to make HTML files and a Devhelp project" 41 | @echo " epub to make an epub" 42 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 43 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 44 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 45 | @echo " text to make text files" 46 | @echo " man to make manual pages" 47 | @echo " texinfo to make Texinfo files" 48 | @echo " info to make Texinfo files and run them through makeinfo" 49 | @echo " gettext to make PO message catalogs" 50 | @echo " changes to make an overview of all changed/added/deprecated items" 51 | @echo " xml to make Docutils-native XML files" 52 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 53 | @echo " linkcheck to check all external links for integrity" 54 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 55 | 56 | clean: 57 | rm -rf $(BUILDDIR)/* 58 | 59 | html: 60 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 63 | 64 | dirhtml: 65 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 66 | @echo 67 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 68 | 69 | singlehtml: 70 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 71 | @echo 72 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 73 | 74 | pickle: 75 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 76 | @echo 77 | @echo "Build finished; now you can process the pickle files." 78 | 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 
83 | 84 | htmlhelp: 85 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 86 | @echo 87 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 88 | ".hhp project file in $(BUILDDIR)/htmlhelp." 89 | 90 | qthelp: 91 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 92 | @echo 93 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 94 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 95 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/pysparkdl.qhcp" 96 | @echo "To view the help file:" 97 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/pysparkdl.qhc" 98 | 99 | devhelp: 100 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 101 | @echo 102 | @echo "Build finished." 103 | @echo "To view the help file:" 104 | @echo "# mkdir -p $$HOME/.local/share/devhelp/pysparkdl" 105 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/pysparkdl" 106 | @echo "# devhelp" 107 | 108 | epub: 109 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 110 | @echo 111 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 112 | 113 | latex: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo 116 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 117 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 118 | "(use \`make latexpdf' here to do that automatically)." 119 | 120 | latexpdf: 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | @echo "Running LaTeX files through pdflatex..." 123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 125 | 126 | latexpdfja: 127 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 128 | @echo "Running LaTeX files through platex and dvipdfmx..." 129 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 130 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 131 | 132 | text: 133 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 134 | @echo 135 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 136 | 137 | man: 138 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 139 | @echo 140 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 141 | 142 | texinfo: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo 145 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 146 | @echo "Run \`make' in that directory to run these through makeinfo" \ 147 | "(use \`make info' here to do that automatically)." 148 | 149 | info: 150 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 151 | @echo "Running Texinfo files through makeinfo..." 152 | make -C $(BUILDDIR)/texinfo info 153 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 154 | 155 | gettext: 156 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 157 | @echo 158 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 159 | 160 | changes: 161 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 162 | @echo 163 | @echo "The overview file is in $(BUILDDIR)/changes." 164 | 165 | linkcheck: 166 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 167 | @echo 168 | @echo "Link check complete; look for any errors in the above output " \ 169 | "or in $(BUILDDIR)/linkcheck/output.txt." 
170 | 171 | doctest: 172 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 173 | @echo "Testing of doctests in the sources finished, look at the " \ 174 | "results in $(BUILDDIR)/doctest/output.txt." 175 | 176 | xml: 177 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 178 | @echo 179 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 180 | 181 | pseudoxml: 182 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 183 | @echo 184 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 185 | -------------------------------------------------------------------------------- /python/tests/graph/test_pieces.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Databricks, Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from __future__ import print_function 18 | 19 | from glob import glob 20 | import os 21 | from tempfile import NamedTemporaryFile 22 | 23 | import numpy as np 24 | import numpy.random as prng 25 | import tensorflow as tf 26 | import keras.backend as K 27 | from keras.applications import InceptionV3 28 | from keras.applications import inception_v3 as iv3 29 | from keras.applications import Xception 30 | from keras.applications import xception as xcpt 31 | from keras.applications import ResNet50 32 | from keras.applications import resnet50 as rsnt 33 | from keras.preprocessing.image import load_img, img_to_array 34 | 35 | from pyspark import SparkContext 36 | from pyspark.sql import DataFrame, Row 37 | from pyspark.sql.functions import udf 38 | 39 | from sparkdl.image.imageIO import imageArrayToStruct, SparkMode 40 | from sparkdl.graph.builder import IsolatedSession, GraphFunction 41 | import sparkdl.graph.pieces as gfac 42 | import sparkdl.graph.utils as tfx 43 | 44 | from ..tests import SparkDLTestCase 45 | from ..transformers.image_utils import _getSampleJPEGDir, getSampleImagePathsDF 46 | 47 | 48 | class GraphFactoryTest(SparkDLTestCase): 49 | 50 | 51 | def test_spimage_converter_module(self): 52 | """ spimage converter module must preserve original image """ 53 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) 54 | 55 | def exec_gfn_spimg_decode(spimg_dict, img_dtype): 56 | gfn = gfac.buildSpImageConverter(img_dtype) 57 | with IsolatedSession() as issn: 58 | feeds, fetches = issn.importGraphFunction(gfn, prefix="") 59 | feed_dict = dict((tnsr, spimg_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds) 60 | img_out = issn.run(fetches[0], feed_dict=feed_dict) 61 | return img_out 62 | 63 | def check_image_round_trip(img_arr): 64 | spimg_dict = imageArrayToStruct(img_arr).asDict() 65 | spimg_dict['data'] = bytes(spimg_dict['data']) 66 | img_arr_out = exec_gfn_spimg_decode(spimg_dict, spimg_dict['mode']) 67 | self.assertTrue(np.all(img_arr_out == img_arr)) 68 | 69 | for fp in img_fpaths: 70 | img = load_img(fp) 71 | 72 | img_arr_byte = img_to_array(img).astype(np.uint8) 73 | 
check_image_round_trip(img_arr_byte) 74 | 75 | img_arr_float = img_to_array(img).astype(np.float) 76 | check_image_round_trip(img_arr_float) 77 | 78 | img_arr_preproc = iv3.preprocess_input(img_to_array(img)) 79 | check_image_round_trip(img_arr_preproc) 80 | 81 | def test_identity_module(self): 82 | """ identity module should preserve input """ 83 | 84 | with IsolatedSession() as issn: 85 | pred_input = tf.placeholder(tf.float32, [None, None]) 86 | final_output = tf.identity(pred_input, name='output') 87 | gfn = issn.asGraphFunction([pred_input], [final_output]) 88 | 89 | for _ in range(10): 90 | m, n = prng.randint(10, 1000, size=2) 91 | mat = prng.randn(m, n).astype(np.float32) 92 | with IsolatedSession() as issn: 93 | feeds, fetches = issn.importGraphFunction(gfn) 94 | mat_out = issn.run(fetches[0], {feeds[0]: mat}) 95 | 96 | self.assertTrue(np.all(mat_out == mat)) 97 | 98 | def test_flattener_module(self): 99 | """ flattener module should preserve input data """ 100 | 101 | gfn = gfac.buildFlattener() 102 | for _ in range(10): 103 | m, n = prng.randint(10, 1000, size=2) 104 | mat = prng.randn(m, n).astype(np.float32) 105 | with IsolatedSession() as issn: 106 | feeds, fetches = issn.importGraphFunction(gfn) 107 | vec_out = issn.run(fetches[0], {feeds[0]: mat}) 108 | 109 | self.assertTrue(np.all(vec_out == mat.flatten())) 110 | 111 | def test_bare_keras_module(self): 112 | """ Keras GraphFunctions should give the same result as standard Keras models """ 113 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) 114 | 115 | for model_gen, preproc_fn in [(InceptionV3, iv3.preprocess_input), 116 | (Xception, xcpt.preprocess_input), 117 | (ResNet50, rsnt.preprocess_input)]: 118 | 119 | keras_model = model_gen(weights="imagenet") 120 | target_size = tuple(keras_model.input.shape.as_list()[1:-1]) 121 | 122 | _preproc_img_list = [] 123 | for fpath in img_fpaths: 124 | img = load_img(fpath, target_size=target_size) 125 | # WARNING: must apply expand dimensions first, or ResNet50 preprocessor fails 126 | img_arr = np.expand_dims(img_to_array(img), axis=0) 127 | _preproc_img_list.append(preproc_fn(img_arr)) 128 | 129 | imgs_input = np.vstack(_preproc_img_list) 130 | 131 | preds_ref = keras_model.predict(imgs_input) 132 | 133 | gfn_bare_keras = GraphFunction.fromKeras(keras_model) 134 | 135 | with IsolatedSession(using_keras=True) as issn: 136 | K.set_learning_phase(0) 137 | feeds, fetches = issn.importGraphFunction(gfn_bare_keras) 138 | preds_tgt = issn.run(fetches[0], {feeds[0]: imgs_input}) 139 | 140 | self.assertTrue(np.all(preds_tgt == preds_ref)) 141 | 142 | def test_pipeline(self): 143 | """ Pipeline should provide correct function composition """ 144 | img_fpaths = glob(os.path.join(_getSampleJPEGDir(), '*.jpg')) 145 | 146 | xcpt_model = Xception(weights="imagenet") 147 | stages = [('spimage', gfac.buildSpImageConverter(SparkMode.RGB_FLOAT32)), 148 | ('xception', GraphFunction.fromKeras(xcpt_model))] 149 | piped_model = GraphFunction.fromList(stages) 150 | 151 | for fpath in img_fpaths: 152 | target_size = tuple(xcpt_model.input.shape.as_list()[1:-1]) 153 | img = load_img(fpath, target_size=target_size) 154 | img_arr = np.expand_dims(img_to_array(img), axis=0) 155 | img_input = xcpt.preprocess_input(img_arr) 156 | preds_ref = xcpt_model.predict(img_input) 157 | 158 | spimg_input_dict = imageArrayToStruct(img_input).asDict() 159 | spimg_input_dict['data'] = bytes(spimg_input_dict['data']) 160 | with IsolatedSession() as issn: 161 | # Need blank import scope name so that spimg fields 
match the input names 162 | feeds, fetches = issn.importGraphFunction(piped_model, prefix="") 163 | feed_dict = dict((tnsr, spimg_input_dict[tfx.op_name(issn.graph, tnsr)]) for tnsr in feeds) 164 | preds_tgt = issn.run(fetches[0], feed_dict=feed_dict) 165 | # Uncomment the line below to see the graph 166 | # tfx.write_visualization_html(issn.graph, 167 | # NamedTemporaryFile(prefix="gdef", suffix=".html").name) 168 | 169 | self.assertTrue(np.all(preds_tgt == preds_ref)) 170 | -------------------------------------------------------------------------------- /python/tests/transformers/tf_image_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from keras.applications import InceptionV3 17 | from keras.applications.inception_v3 import preprocess_input, decode_predictions 18 | import keras.backend as K 19 | from keras.preprocessing.image import img_to_array, load_img 20 | import numpy as np 21 | import tensorflow as tf 22 | 23 | import sparkdl.graph.utils as tfx 24 | from sparkdl.image.imageIO import imageStructToArray 25 | from sparkdl.transformers.keras_utils import KSessionWrap 26 | from sparkdl.transformers.tf_image import TFImageTransformer 27 | import sparkdl.transformers.utils as utils 28 | from sparkdl.transformers.utils import ImageNetConstants, InceptionV3Constants 29 | from ..tests import SparkDLTestCase 30 | from .image_utils import ImageNetOutputComparisonTestCase 31 | from . import image_utils 32 | 33 | 34 | class TFImageTransformerExamplesTest(SparkDLTestCase, ImageNetOutputComparisonTestCase): 35 | 36 | # Test loading & pre-processing as an example of a simple graph 37 | # NOTE: resizing here/tensorflow and in keras workflow are different, so the 38 | # test would fail with resizing added in. 
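# In short, the pattern exercised by the test below (a condensed, illustrative sketch,
# not an additional test):
#
#     g = tf.Graph()
#     with g.as_default():
#         image_arr = utils.imageInputPlaceholder()
#         preprocessed = preprocess_input(image_arr)
#     transformer = TFImageTransformer(inputCol="image", outputCol="transformed_image",
#                                      graph=g, inputTensor=image_arr,
#                                      outputTensor=preprocessed.name, outputMode="vector")
#     transformed_df = transformer.transform(image_utils.getSampleImageDF().limit(5))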
39 | 40 | def _loadImageViaKeras(self, raw_uri): 41 | uri = raw_uri[5:] if raw_uri.startswith("file:/") else raw_uri 42 | image = img_to_array(load_img(uri)) 43 | image = np.expand_dims(image, axis=0) 44 | return preprocess_input(image) 45 | 46 | def test_load_image_vs_keras(self): 47 | g = tf.Graph() 48 | with g.as_default(): 49 | image_arr = utils.imageInputPlaceholder() 50 | preprocessed = preprocess_input(image_arr) 51 | 52 | output_col = "transformed_image" 53 | transformer = TFImageTransformer(inputCol="image", outputCol=output_col, graph=g, 54 | inputTensor=image_arr, outputTensor=preprocessed.name, 55 | outputMode="vector") 56 | 57 | image_df = image_utils.getSampleImageDF() 58 | df = transformer.transform(image_df.limit(5)) 59 | 60 | for row in df.collect(): 61 | processed = np.array(row[output_col]).astype(np.float32) 62 | # compare to keras loading 63 | images = self._loadImageViaKeras(row["filePath"]) 64 | image = images[0] 65 | image.shape = (1, image.shape[0] * image.shape[1] * image.shape[2]) 66 | keras_processed = image[0] 67 | self.assertTrue( (processed == keras_processed).all() ) 68 | 69 | 70 | # Test full pre-processing for InceptionV3 as an example of a simple computation graph 71 | 72 | def _preprocessingInceptionV3Transformed(self, outputMode, outputCol): 73 | g = tf.Graph() 74 | with g.as_default(): 75 | image_arr = utils.imageInputPlaceholder() 76 | resized_images = tf.image.resize_images(image_arr, InceptionV3Constants.INPUT_SHAPE) 77 | processed_images = preprocess_input(resized_images) 78 | self.assertEqual(processed_images.shape[1], InceptionV3Constants.INPUT_SHAPE[0]) 79 | self.assertEqual(processed_images.shape[2], InceptionV3Constants.INPUT_SHAPE[1]) 80 | 81 | transformer = TFImageTransformer(inputCol="image", outputCol=outputCol, graph=g, 82 | inputTensor=image_arr.name, outputTensor=processed_images, 83 | outputMode=outputMode) 84 | image_df = image_utils.getSampleImageDF() 85 | return transformer.transform(image_df.limit(5)) 86 | 87 | def test_image_output(self): 88 | output_col = "resized_image" 89 | preprocessed_df = self._preprocessingInceptionV3Transformed("image", output_col) 90 | self.assertDfHasCols(preprocessed_df, [output_col]) 91 | for row in preprocessed_df.collect(): 92 | original = row["image"] 93 | processed = row[output_col] 94 | errMsg = "nChannels must match: original {} v.s. processed {}" 95 | errMsg = errMsg.format(original.nChannels, processed.nChannels) 96 | self.assertEqual(original.nChannels, processed.nChannels, errMsg) 97 | self.assertEqual(processed.height, InceptionV3Constants.INPUT_SHAPE[0]) 98 | self.assertEqual(processed.width, InceptionV3Constants.INPUT_SHAPE[1]) 99 | 100 | # TODO: add tests for non-RGB8 images, at least RGB-float32. 101 | 102 | 103 | # Test InceptionV3 prediction as an example of applying a trained model. 
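# Helper overview: _executeTensorflow (below) collects `df` locally, runs the supplied `graph`
# on each image row in a plain tf.Session, and returns two dicts keyed by `id_col`: the raw
# model outputs and the top-5 decoded ImageNet predictions. The prediction test below uses
# these as its reference values.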
104 | 105 | def _executeTensorflow(self, graph, input_tensor_name, output_tensor_name, 106 | df, id_col="filePath", input_col="image"): 107 | with tf.Session(graph=graph) as sess: 108 | output_tensor = graph.get_tensor_by_name(output_tensor_name) 109 | image_collected = df.collect() 110 | values = {} 111 | topK = {} 112 | for img_row in image_collected: 113 | image = np.expand_dims(imageStructToArray(img_row[input_col]), axis=0) 114 | uri = img_row[id_col] 115 | output = sess.run([output_tensor], 116 | feed_dict={ 117 | graph.get_tensor_by_name(input_tensor_name): image 118 | }) 119 | values[uri] = np.array(output[0]) 120 | topK[uri] = decode_predictions(values[uri], top=5)[0] 121 | return values, topK 122 | 123 | def test_prediction_vs_tensorflow_inceptionV3(self): 124 | output_col = "prediction" 125 | image_df = image_utils.getSampleImageDF() 126 | 127 | # An example of how a pre-trained keras model can be used with TFImageTransformer 128 | with KSessionWrap() as (sess, g): 129 | with g.as_default(): 130 | K.set_learning_phase(0) # this is important but it's on the user to call it. 131 | # nChannels needed for input_tensor in the InceptionV3 call below 132 | image_string = utils.imageInputPlaceholder(nChannels = 3) 133 | resized_images = tf.image.resize_images(image_string, 134 | InceptionV3Constants.INPUT_SHAPE) 135 | preprocessed = preprocess_input(resized_images) 136 | model = InceptionV3(input_tensor=preprocessed, weights="imagenet") 137 | graph = tfx.strip_and_freeze_until([model.output], g, sess, return_graph=True) 138 | 139 | transformer = TFImageTransformer(inputCol="image", outputCol=output_col, graph=graph, 140 | inputTensor=image_string, outputTensor=model.output, 141 | outputMode="vector") 142 | transformed_df = transformer.transform(image_df.limit(10)) 143 | self.assertDfHasCols(transformed_df, [output_col]) 144 | collected = transformed_df.collect() 145 | transformer_values, transformer_topK = self.transformOutputToComparables(collected, 146 | "filePath", 147 | output_col) 148 | 149 | tf_values, tf_topK = self._executeTensorflow(graph, image_string.name, model.output.name, 150 | image_df) 151 | self.compareClassSets(tf_topK, transformer_topK) 152 | self.compareClassOrderings(tf_topK, transformer_topK) 153 | self.compareArrays(tf_values, transformer_values) 154 | -------------------------------------------------------------------------------- /python/tests/transformers/named_image_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | 16 | from keras.applications import inception_v3 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from pyspark.ml import Pipeline 21 | from pyspark.ml.classification import LogisticRegression 22 | from pyspark.sql.functions import udf 23 | from pyspark.sql.types import IntegerType, StructType, StructField 24 | 25 | from sparkdl.image import imageIO 26 | from sparkdl.transformers.named_image import (DeepImagePredictor, DeepImageFeaturizer, 27 | _buildTFGraphForName) 28 | from sparkdl.transformers.utils import InceptionV3Constants 29 | from ..tests import SparkDLTestCase 30 | from .image_utils import getSampleImageDF, getSampleImageList 31 | 32 | 33 | class NamedImageTransformerImagenetTest(SparkDLTestCase): 34 | 35 | @classmethod 36 | def setUpClass(cls): 37 | super(NamedImageTransformerImagenetTest, cls).setUpClass() 38 | 39 | # Compute values used by multiple tests. 40 | imgFiles, images = getSampleImageList() 41 | imageArray = np.empty((len(images), 299, 299, 3), 'uint8') 42 | for i, img in enumerate(images): 43 | assert img is not None and img.mode == "RGB" 44 | imageArray[i] = np.array(img.resize((299, 299))) 45 | 46 | # Predict the class probabilities for the images in our test library using the Keras API. 47 | preppedImages = inception_v3.preprocess_input(imageArray.astype('float32')) 48 | model = inception_v3.InceptionV3() 49 | kerasPredict = model.predict(preppedImages) 50 | # These values are used by multiple tests so cache them on class setup. 51 | cls.imageArray = imageArray 52 | cls.kerasPredict = kerasPredict 53 | 54 | def test_buildtfgraphforname(self): 55 | """ 56 | Run the graph produced by _buildTFGraphForName and compare the result to the cached 57 | Keras result above. 58 | """ 59 | imageArray = self.imageArray 60 | kerasPredict = self.kerasPredict 61 | modelGraphInfo = _buildTFGraphForName("InceptionV3", False) 62 | graph = modelGraphInfo["graph"] 63 | sess = tf.Session(graph=graph) 64 | with sess.as_default(): 65 | inputTensor = graph.get_tensor_by_name(modelGraphInfo["inputTensorName"]) 66 | outputTensor = graph.get_tensor_by_name(modelGraphInfo["outputTensorName"]) 67 | tfPredict = sess.run(outputTensor, {inputTensor: imageArray}) 68 | 69 | self.assertEqual(kerasPredict.shape, tfPredict.shape) 70 | np.testing.assert_array_almost_equal(kerasPredict, tfPredict) 71 | 72 | def test_DeepImagePredictorNoReshape(self): 73 | """ 74 | Run the sparkdl InceptionV3 transformer on resized images and compare the result to the 75 | cached Keras result.
76 | """ 77 | imageArray = self.imageArray 78 | kerasPredict = self.kerasPredict 79 | def rowWithImage(img): 80 | # return [imageIO.imageArrayToStruct(img.astype('uint8'), imageType.sparkMode)] 81 | row = imageIO.imageArrayToStruct(img.astype('uint8'), imageIO.SparkMode.RGB) 82 | # re-order row to avoid pyspark bug 83 | return [[getattr(row, field.name) for field in imageIO.imageSchema]] 84 | 85 | # test: predictor vs keras on resized images 86 | rdd = self.sc.parallelize([rowWithImage(img) for img in imageArray]) 87 | dfType = StructType([StructField("image", imageIO.imageSchema)]) 88 | imageDf = rdd.toDF(dfType) 89 | 90 | transformer = DeepImagePredictor(inputCol='image', modelName="InceptionV3", 91 | outputCol="prediction",) 92 | dfPredict = transformer.transform(imageDf).collect() 93 | dfPredict = np.array([i.prediction for i in dfPredict]) 94 | 95 | self.assertEqual(kerasPredict.shape, dfPredict.shape) 96 | np.testing.assert_array_almost_equal(kerasPredict, dfPredict) 97 | 98 | def test_DeepImagePredictor(self): 99 | """ 100 | Run sparkDL inceptionV3 transformer on raw (original size) images and compare result to 101 | above keras (using keras resizing) result. 102 | """ 103 | kerasPredict = self.kerasPredict 104 | transformer = DeepImagePredictor(inputCol='image', modelName="InceptionV3", 105 | outputCol="prediction",) 106 | origImgDf = getSampleImageDF() 107 | fullPredict = transformer.transform(origImgDf).collect() 108 | fullPredict = np.array([i.prediction for i in fullPredict]) 109 | 110 | self.assertEqual(kerasPredict.shape, fullPredict.shape) 111 | # We use a large tolerance below because of differences in the resize step 112 | # TODO: match keras resize step to get closer prediction 113 | np.testing.assert_array_almost_equal(kerasPredict, fullPredict, decimal=6) 114 | 115 | def test_inceptionV3_prediction_decoded(self): 116 | output_col = "prediction" 117 | topK = 10 118 | transformer = DeepImagePredictor(inputCol="image", outputCol=output_col, 119 | modelName="InceptionV3", decodePredictions=True, topK=topK) 120 | 121 | image_df = getSampleImageDF() 122 | transformed_df = transformer.transform(image_df.limit(5)) 123 | 124 | collected = transformed_df.collect() 125 | for row in collected: 126 | predictions = row[output_col] 127 | self.assertEqual(len(predictions), topK) 128 | # TODO: actually check the value of the output to see if they are reasonable 129 | # e.g. -- compare to just running with keras. 130 | 131 | def test_inceptionV3_featurization(self): 132 | output_col = "prediction" 133 | transformer = DeepImageFeaturizer(inputCol="image", outputCol=output_col, 134 | modelName="InceptionV3") 135 | 136 | image_df = getSampleImageDF() 137 | transformed_df = transformer.transform(image_df.limit(5)) 138 | 139 | collected = transformed_df.collect() 140 | for row in collected: 141 | predictions = row[output_col] 142 | self.assertEqual(len(predictions), InceptionV3Constants.NUM_OUTPUT_FEATURES) 143 | # TODO: actually check the value of the output to see if they are reasonable 144 | # e.g. -- compare to just running with keras. 145 | 146 | def test_featurizer_in_pipeline(self): 147 | """ 148 | Tests that the featurizer fits into an MLlib Pipeline. 149 | Does not test how good the featurization is for generalization. 
150 | """ 151 | featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features", 152 | modelName="InceptionV3") 153 | lr = LogisticRegression(maxIter=20, regParam=0.05, elasticNetParam=0.3, labelCol="label") 154 | pipeline = Pipeline(stages=[featurizer, lr]) 155 | 156 | # add arbitrary labels to run logistic regression 157 | # TODO: it's weird that the test fails on some combinations of labels. check why. 158 | label_udf = udf(lambda x: abs(hash(x)) % 2, IntegerType()) 159 | image_df = getSampleImageDF() 160 | train_df = image_df.withColumn("label", label_udf(image_df["filePath"])) 161 | 162 | lrModel = pipeline.fit(train_df) 163 | # see if we at least get the training examples right. 164 | # with 5 examples and 131k features, it ought to. 165 | pred_df_collected = lrModel.transform(train_df).collect() 166 | for row in pred_df_collected: 167 | self.assertEqual(int(row.prediction), row.label) 168 | -------------------------------------------------------------------------------- /python/sparkdl/image/imageIO.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Databricks, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from io import BytesIO 17 | from collections import namedtuple 18 | from warnings import warn 19 | 20 | # 3rd party 21 | import numpy as np 22 | from PIL import Image 23 | 24 | # pyspark 25 | from pyspark import Row 26 | from pyspark import SparkContext 27 | from pyspark.sql.types import (BinaryType, IntegerType, StringType, StructField, StructType) 28 | from pyspark.sql.functions import udf 29 | 30 | 31 | imageSchema = StructType([StructField("mode", StringType(), False), 32 | StructField("height", IntegerType(), False), 33 | StructField("width", IntegerType(), False), 34 | StructField("nChannels", IntegerType(), False), 35 | StructField("data", BinaryType(), False)]) 36 | 37 | 38 | # ImageType class for holding metadata about images stored in DataFrames. 39 | # fields: 40 | # nChannels - number of channels in the image 41 | # dtype - data type of the image's "data" Column, stored as a numpy compatible string. 42 | # channelContent - info about the contents of each channel; currently only "I" (intensity) and 43 | # "RGB" are supported for 1 and 3 channel data respectively. 44 | # pilMode - The mode that should be used to convert to a PIL image. 45 | # sparkMode - Unique identifier string used in spark image representation.
46 | ImageType = namedtuple("ImageType", ["nChannels", 47 | "dtype", 48 | "channelContent", 49 | "pilMode", 50 | "sparkMode", 51 | ]) 52 | class SparkMode(object): 53 | RGB = "RGB" 54 | FLOAT32 = "float32" 55 | RGB_FLOAT32 = "RGB-float32" 56 | 57 | supportedImageTypes = [ 58 | ImageType(3, "uint8", "RGB", "RGB", SparkMode.RGB), 59 | ImageType(1, "float32", "I", "F", SparkMode.FLOAT32), 60 | ImageType(3, "float32", "RGB", None, SparkMode.RGB_FLOAT32), 61 | ] 62 | pilModeLookup = {t.pilMode: t for t in supportedImageTypes 63 | if t.pilMode is not None} 64 | sparkModeLookup = {t.sparkMode: t for t in supportedImageTypes} 65 | 66 | 67 | def imageArrayToStruct(imgArray, sparkMode=None): 68 | """ 69 | Create a row representation of an image from an image array and an (optional) spark mode. 70 | Example usage: 71 | to_image_udf = udf(imageArrayToStruct, imageSchema) 72 | df.withColumn("output_img", to_image_udf(df["np_arr_col"])) 73 | 74 | :param imgArray: ndarray, image data. 75 | :param sparkMode: spark mode, type information for the image, will be inferred from array if 76 | the mode is not provided. See SparkMode for valid modes. 77 | :return: Row, image as a DataFrame Row. 78 | """ 79 | # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists. 80 | if len(imgArray.shape) == 4: 81 | if imgArray.shape[0] != 1: 82 | raise ValueError("The first dimension of a 4-d image array is expected to be 1.") 83 | imgArray = imgArray.reshape(imgArray.shape[1:]) 84 | 85 | if sparkMode is None: 86 | sparkMode = _arrayToSparkMode(imgArray) 87 | imageType = sparkModeLookup[sparkMode] 88 | 89 | height, width, nChannels = imgArray.shape 90 | if imageType.nChannels != nChannels: 91 | msg = "Image of type {} should have {} channels, but array has {} channels." 92 | raise ValueError(msg.format(sparkMode, imageType.nChannels, nChannels)) 93 | 94 | # Convert the array to match the image type. 95 | if not np.can_cast(imgArray, imageType.dtype, 'same_kind'): 96 | msg = "Array of type {} cannot safely be cast to image type {}." 97 | raise ValueError(msg.format(imgArray.dtype, imageType.dtype)) 98 | imgArray = np.array(imgArray, dtype=imageType.dtype, copy=False) 99 | 100 | data = bytearray(imgArray.tobytes()) 101 | return Row(mode=sparkMode, height=height, width=width, nChannels=nChannels, data=data) 102 | 103 | 104 | def imageType(imageRow): 105 | """ 106 | Get type information about the image. 107 | 108 | :param imageRow: spark image row. 109 | :return: ImageType 110 | """ 111 | return sparkModeLookup[imageRow.mode] 112 | 113 | 114 | def imageStructToArray(imageRow): 115 | """ 116 | Convert an image to a numpy array. 117 | 118 | :param imageRow: Row, must use imageSchema. 119 | :return: ndarray, image data.
120 | """ 121 | imType = imageType(imageRow) 122 | shape = (imageRow.height, imageRow.width, imageRow.nChannels) 123 | return np.ndarray(shape, imType.dtype, imageRow.data) 124 | 125 | 126 | def _arrayToSparkMode(arr): 127 | assert len(arr.shape) == 3, "Array should have 3 dimensions but has shape {}".format(arr.shape) 128 | num_channels = arr.shape[2] 129 | if num_channels == 1: 130 | if arr.dtype not in [np.float16, np.float32, np.float64]: 131 | raise ValueError("incompatible dtype (%s) for numpy array for float32 mode" % 132 | arr.dtype) 133 | return SparkMode.FLOAT32 134 | elif num_channels != 3: 135 | raise ValueError("number of channels of the input array (%d) is not supported" % 136 | num_channels) 137 | elif arr.dtype == np.uint8: 138 | return SparkMode.RGB 139 | elif arr.dtype in [np.float16, np.float32, np.float64]: 140 | return SparkMode.RGB_FLOAT32 141 | else: 142 | raise ValueError("did not find a sparkMode for the given array with num_channels = %d " 143 | "and dtype %s" % (num_channels, arr.dtype)) 144 | 145 | 146 | def _resizeFunction(size): 147 | """ Creates a resize function. 148 | 149 | :param size: tuple, size of new image: (height, width). 150 | :return: function: image => image, a function that converts an input image to an image 151 | of the given `size`. 152 | """ 153 | 154 | if len(size) != 2: 155 | raise ValueError("New image size should have the form [height, width] but got {}".format(size)) 156 | 157 | def resizeImageAsRow(imgAsRow): 158 | imgAsArray = imageStructToArray(imgAsRow) 159 | imgType = imageType(imgAsRow) 160 | imgAsPil = Image.fromarray(imgAsArray, imgType.pilMode) 161 | imgAsPil = imgAsPil.resize(size[::-1]) 162 | imgAsArray = np.array(imgAsPil) 163 | return imageArrayToStruct(imgAsArray, imgType.sparkMode) 164 | 165 | return resizeImageAsRow 166 | 167 | 168 | def resizeImage(size): 169 | """ Create a udf for resizing images. 170 | 171 | Example usage: 172 | dataFrame.select(resizeImage((height, width))('imageColumn')) 173 | 174 | :param size: tuple, target size of new image in the form (height, width). 175 | :return: udf, a udf for resizing an image column to `size`. 176 | """ 177 | return udf(_resizeFunction(size), imageSchema) 178 | 179 | 180 | def _decodeImage(imageData): 181 | """ 182 | Decode compressed image data into a DataFrame image row. 183 | 184 | :param imageData: (bytes, bytearray) compressed image data in PIL compatible format. 185 | :return: Row, decoded image. 186 | """ 187 | try: 188 | img = Image.open(BytesIO(imageData)) 189 | except IOError: 190 | return None 191 | 192 | if img.mode in pilModeLookup: 193 | mode = pilModeLookup[img.mode] 194 | else: 195 | msg = "We don't currently support images with mode: {mode}" 196 | warn(msg.format(mode=img.mode)) 197 | return None 198 | imgArray = np.asarray(img) 199 | image = imageArrayToStruct(imgArray, mode.sparkMode) 200 | return image 201 | 202 | # Creating a UDF on import can cause SparkContext issues sometimes. 203 | # decodeImage = udf(_decodeImage, imageSchema) 204 | 205 | def filesToDF(sc, path, numPartitions=None): 206 | """ 207 | Read files from a directory to a DataFrame. 208 | 209 | :param sc: SparkContext. 210 | :param path: str, path to files. 211 | :param numPartitions: int, number of partitions to use for reading files.
212 | :return: DataFrame, with columns: (filePath: str, fileData: BinaryType) 213 | """ 214 | numPartitions = numPartitions or sc.defaultParallelism 215 | schema = StructType([StructField("filePath", StringType(), False), 216 | StructField("fileData", BinaryType(), False)]) 217 | rdd = sc.binaryFiles(path, minPartitions=numPartitions).repartition(numPartitions) 218 | rdd = rdd.map(lambda x: (x[0], bytearray(x[1]))) 219 | return rdd.toDF(schema) 220 | 221 | 222 | def readImages(imageDirectory, numPartition=None): 223 | """ 224 | Read a directory of images (or a single image) into a DataFrame. 225 | 226 | Uses the active SparkContext (via SparkContext.getOrCreate()) to read the files. 227 | :param imageDirectory: str, file path. 228 | :param numPartition: int, number of partitions to use for reading files. 229 | :return: DataFrame, with columns: (filePath: str, image: imageSchema). 230 | """ 231 | return _readImages(imageDirectory, numPartition, SparkContext.getOrCreate()) 232 | 233 | 234 | def _readImages(imageDirectory, numPartition, sc): 235 | decodeImage = udf(_decodeImage, imageSchema) 236 | imageData = filesToDF(sc, imageDirectory, numPartitions=numPartition) 237 | return imageData.select("filePath", decodeImage("fileData").alias("image")) 238 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Learning Pipelines for Apache Spark 2 | 3 | Deep Learning Pipelines provides high-level APIs for scalable deep learning in Python. The 4 | library comes from Databricks and leverages Spark for its two strongest facets: 5 | 1. In the spirit of Spark and Spark MLlib, it provides easy-to-use APIs that enable deep learning 6 | in very few lines of code. 7 | 2. It uses Spark's powerful distributed engine to scale out deep learning on massive datasets. 8 | 9 | Currently, TensorFlow and TensorFlow-backed Keras workflows are supported, with a focus on model 10 | application and transfer learning on image data at scale, with hyper-parameter tuning in the works. 11 | Furthermore, it provides tools for data scientists and machine learning experts to turn deep 12 | learning models into SQL functions that can be used by a much wider group of users. It does not 13 | perform single-model distributed training - this is an area of active research, and here we aim to 14 | provide the most practical solutions for the majority of deep learning use cases. 15 | 16 | For an overview of the library, see the Databricks [blog post](https://databricks.com/blog/2017/06/06/databricks-vision-simplify-large-scale-deep-learning.html?preview=true) introducing Deep Learning Pipelines. 17 | For the various use cases the package serves, see the [Quick user guide](#quick-user-guide) section below. 18 | 19 | The library is in its early days, and we welcome everyone's feedback and contribution. 20 | 21 | Authors: Bago Amirbekian, Joseph Bradley, Sue Ann Hong, Tim Hunter, Philip Yang 22 | 23 | 24 | ## Building and running unit tests 25 | 26 | To compile this project, run `build/sbt assembly` from the project home directory. 27 | This will also run the Scala unit tests. 28 | 29 | To run the Python unit tests, run the `run-tests.sh` script from the `python/` directory. 30 | You will need to set a few environment variables, e.g.
31 | ```bash 32 | sparkdl$ SPARK_HOME=/usr/local/lib/spark-2.1.1-bin-hadoop2.7 PYSPARK_PYTHON=python2 SCALA_VERSION=2.11.8 SPARK_VERSION=2.1.1 ./python/run-tests.sh 33 | ``` 34 | 35 | 36 | ## Spark version compatibility 37 | 38 | Spark 2.1.1 and Python 2.7 are recommended. 39 | 40 | 41 | 42 | ## Quick user guide 43 | 44 | The current version of Deep Learning Pipelines provides a suite of tools around working with and 45 | processing images using deep learning. The tools can be categorized as: 46 | * [Working with images in Spark](#working-with-images-in-spark) : natively in Spark DataFrames 47 | * [Transfer learning](#transfer-learning) : a super quick way to leverage deep learning 48 | * [Applying deep learning models at scale](#applying-deep-learning-models-at-scale) : apply your 49 | own or known popular models to image data to make predictions or transform them into features 50 | * Deploying models as SQL functions : empower everyone by making deep learning available in SQL (coming soon) 51 | * Distributed hyper-parameter tuning : via Spark MLlib Pipelines (coming soon) 52 | 53 | To try running the examples below, check out the Databricks notebook 54 | [Deep Learning Pipelines on Databricks](https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/5669198905533692/3647723071348946/3983381308530741/latest.html). 55 | 56 | 57 | ### Working with images in Spark 58 | The first step to applying deep learning on images is the ability to load the images. Deep Learning 59 | Pipelines includes utility functions that can load millions of images into a Spark DataFrame and 60 | decode them automatically in a distributed fashion, allowing manipulation at scale. 61 | 62 | ```python 63 | from sparkdl import readImages 64 | image_df = readImages("/data/myimages") 65 | ``` 66 | 67 | The resulting DataFrame contains a string column named "filePath" containing the path to each image 68 | file, and an image struct ("`SpImage`") column named "image" containing the decoded image data. 69 | 70 | ```python 71 | image_df.show() 72 | ``` 73 | 74 | The goal is to add support for more data types, such as text and time series, as there is interest. 75 | 76 | 77 | ### Transfer learning 78 | Deep Learning Pipelines provides utilities to perform 79 | [transfer learning](https://en.wikipedia.org/wiki/Transfer_learning) on images, which is one of 80 | the fastest (code and run-time-wise) ways to start using deep learning. Using Deep Learning 81 | Pipelines, it can be done in just a few lines of code.
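The pipeline below assumes a training DataFrame `train_images_df` with an image column and a numeric `label` column. As a rough sketch (the per-class directories shown here are hypothetical placeholders, not paths the library defines), such a DataFrame could be assembled from two labeled image folders:

```python
from pyspark.sql.functions import lit
from sparkdl import readImages

# Hypothetical per-class image directories; substitute your own paths and labels.
label_one_df = readImages("/data/myimages/class_one").withColumn("label", lit(1))
label_zero_df = readImages("/data/myimages/class_zero").withColumn("label", lit(0))
train_images_df = label_one_df.unionAll(label_zero_df)
```

With a labeled DataFrame in hand, the transfer learning itself takes only a few more lines: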
82 | 83 | ```python 84 | from pyspark.ml.classification import LogisticRegression 85 | from pyspark.ml.evaluation import MulticlassClassificationEvaluator 86 | from pyspark.ml import Pipeline 87 | from sparkdl import DeepImageFeaturizer 88 | 89 | featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features", modelName="InceptionV3") 90 | lr = LogisticRegression(maxIter=20, regParam=0.05, elasticNetParam=0.3, labelCol="label") 91 | p = Pipeline(stages=[featurizer, lr]) 92 | 93 | model = p.fit(train_images_df) # train_images_df is a dataset of images (SpImage) and labels 94 | 95 | # Inspect training error 96 | df = model.transform(train_images_df.limit(10)).select("image", "probability", "uri", "prediction", "label") 97 | predictionAndLabels = df.select("prediction", "label") 98 | evaluator = MulticlassClassificationEvaluator(metricName="accuracy") 99 | print("Training set accuracy = " + str(evaluator.evaluate(predictionAndLabels))) 100 | ``` 101 | 102 | 103 | ### Applying deep learning models at scale 104 | Spark DataFrames are a natural construct for applying deep learning models to a large-scale dataset. 105 | Deep Learning Pipelines provides a set of (Spark MLlib) Transformers for applying TensorFlow Graphs 106 | and TensorFlow-backed Keras Models at scale. In addition, popular image models can be applied out 107 | of the box, without requiring any TensorFlow or Keras code. The Transformers, backed by the 108 | TensorFrames library, efficiently handle the distribution of models and data to Spark workers. 109 | 110 | #### Applying popular image models 111 | There are many well-known deep learning models for images. If the task at hand is very similar to 112 | what the models provide (e.g. object recognition with ImageNet classes), or for pure exploration, 113 | one can use the Transformer `DeepImagePredictor` by simply specifying the model name. 114 | 115 | ```python 116 | from sparkdl import readImages, DeepImagePredictor 117 | 118 | predictor = DeepImagePredictor(inputCol="image", outputCol="predicted_labels", 119 | modelName="InceptionV3", decodePredictions=True, topK=10) 120 | image_df = readImages("/data/myimages") 121 | predictions_df = predictor.transform(image_df) 122 | ``` 123 | 124 | #### For TensorFlow users 125 | Deep Learning Pipelines provides a Transformer that will apply the given TensorFlow Graph to a 126 | DataFrame containing a column of images (e.g. loaded using the utilities described in the previous 127 | section). In practice, the TensorFlow Graph will likely be restored from a file before calling 128 | `TFImageTransformer`; one way that restore step might look is sketched below. A very simple example 129 | of building a graph in memory and using it with the Transformer then follows.
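As a minimal sketch of that restore step (the file name and the tensor names below are assumed placeholders, not identifiers defined by the library), a frozen graph saved to disk could be loaded back and handed to the transformer roughly as follows:

```python
import tensorflow as tf
from sparkdl import TFImageTransformer

# Hypothetical frozen-graph file produced earlier; adjust the path and node names to your graph.
graph_def = tf.GraphDef()
with tf.gfile.FastGFile("/tmp/my_frozen_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

g = tf.Graph()
with g.as_default():
    tf.import_graph_def(graph_def, name="")
    input_tensor = g.get_tensor_by_name("input_placeholder:0")    # assumed input node name
    output_tensor = g.get_tensor_by_name("output_predictions:0")  # assumed output node name

transformer = TFImageTransformer(inputCol="image", outputCol="predictions", graph=g,
                                 inputTensor=input_tensor, outputTensor=output_tensor,
                                 outputMode="vector")
```

The simple example below instead builds its graph directly in memory: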
130 | 131 | ```python 132 | from sparkdl import readImages, TFImageTransformer 133 | from sparkdl.transformers import utils 134 | import tensorflow as tf 135 | 136 | g = tf.Graph() 137 | with g.as_default(): 138 | image_arr = utils.imageInputPlaceholder() 139 | resized_images = tf.image.resize_images(image_arr, (299, 299)) 140 | # the following step is not necessary for this graph, but can be for graphs with variables, etc 141 | frozen_graph = utils.stripAndFreezeGraph(g.as_graph_def(add_shapes=True), tf.Session(graph=g), 142 | [resized_images]) 143 | 144 | transformer = TFImageTransformer(inputCol="image", outputCol="predictions", graph=frozen_graph, 145 | inputTensor=image_arr, outputTensor=resized_images, 146 | outputMode="image") 147 | image_df = readImages("/data/myimages") 148 | processed_image_df = transformer.transform(image_df) 149 | ``` 150 | 151 | 152 | 153 | #### For Keras users 154 | For applying Keras models in a distributed manner using Spark, `KerasImageFileTransformer` 155 | works on TensorFlow-backed Keras models. It: 156 | * Internally creates a DataFrame containing a column of images by applying the user-specified image 157 | loading and processing function to the input DataFrame containing a column of image URIs 158 | * Loads a Keras model from the given model file path 159 | * Applies the model to the image DataFrame 160 | 161 | The difference in the API from `TFImageTransformer` above stems from the fact that usual Keras 162 | workflows have very specific ways to load and resize images that are not part of the TensorFlow Graph. 163 | 164 | 165 | To use the transformer, we first need to have a Keras model stored as a file. For this example we'll 166 | just save the Keras built-in InceptionV3 model instead of training one. 167 | 168 | ```python 169 | from keras.applications import InceptionV3 170 | 171 | model = InceptionV3(weights="imagenet") 172 | model.save('/tmp/model-full.h5') 173 | ``` 174 | 175 | Now on the prediction side, we can do: 176 | 177 | ```python 178 | from keras.applications.inception_v3 import preprocess_input 179 | from keras.preprocessing.image import img_to_array, load_img 180 | import numpy as np 181 | import os 182 | from pyspark.sql.types import StringType 183 | from sparkdl import KerasImageFileTransformer 184 | 185 | def loadAndPreprocessKerasInceptionV3(uri): 186 | # this is a typical way to load and prep images in keras 187 | image = img_to_array(load_img(uri, target_size=(299, 299))) 188 | image = np.expand_dims(image, axis=0) 189 | return preprocess_input(image) 190 | 191 | transformer = KerasImageFileTransformer(inputCol="uri", outputCol="predictions", 192 | modelFile="/tmp/model-full.h5", 193 | imageLoader=loadAndPreprocessKerasInceptionV3, 194 | outputMode="vector") 195 | 196 | files = [os.path.abspath(os.path.join("/data/myimages", f)) for f in os.listdir("/data/myimages") if f.endswith('.jpg')] 197 | uri_df = sqlContext.createDataFrame(files, StringType()).toDF("uri") 198 | 199 | final_df = transformer.transform(uri_df) 200 | ``` 201 | 202 | 203 | ## Releases: 204 | 205 | **TBA** 206 | --------------------------------------------------------------------------------